
Fixes for multi-diffusion #1560

Merged
CharlelieLrt merged 9 commits into NVIDIA:main from CharlelieLrt:multi-diffusion-fixes
Apr 11, 2026

Conversation

@CharlelieLrt
Collaborator

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor
an indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment; they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Apr 10, 2026

Greptile Summary

This PR fixes several device-compatibility and DDP/torch.compile compatibility issues in the multi-diffusion training path. Key changes:

  1. reset_patch_indices and _CompiledPatchX now operate on the unwrapped MultiDiffusionModel2D, peeling off DDP/compiled wrappers via the new _unwrap_multi_diffusion helper.
  2. _compute_global_index propagates the device from patch_indices to avoid cross-device errors.
  3. random.randint is replaced by torch.randint for GPU-native patch sampling with optional generator support.
  4. RandomPatching2D.forward is rewritten from unfold + advanced indexing to a torch.gather-based approach that is more compatible with torch.compile.

All remaining findings are P2 suggestions.

Important Files Changed

Filename — Overview

physicsnemo/diffusion/multi_diffusion/losses.py — Adds the _unwrap_multi_diffusion helper to peel DDP/compiled wrappers, and uses the unwrapped model for reset_patch_indices and _CompiledPatchX; constructor type annotations still declare MultiDiffusionModel2D, but the logic now accepts wrapped variants.

physicsnemo/diffusion/multi_diffusion/patching.py — Replaces random.randint with torch.randint (GPU-compatible, supports a PRNG generator), propagates the device from patch_indices in _compute_global_index, defers _global_index recomputation lazily via a plain Python flag, and rewrites forward from unfold-based to torch.gather-based indexing.

physicsnemo/diffusion/utils/model_wrappers.py — Trivial: adds a blank line after the license header. No functional change.
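The two patching changes (device-aware torch.randint sampling and gather-based extraction) can be sketched as follows. The function and parameter names here are assumptions for illustration, not the actual RandomPatching2D API.

```python
import torch


def sample_patch_origins(batch, h, w, patch, device, generator=None):
    # GPU-native replacement for random.randint: draws per-sample top-left
    # corners directly on the target device, with optional PRNG generator
    top = torch.randint(0, h - patch + 1, (batch,), device=device, generator=generator)
    left = torch.randint(0, w - patch + 1, (batch,), device=device, generator=generator)
    return top, left


def extract_patches(x, top, left, patch):
    # x: (B, C, H, W) -> (B, C, patch, patch) via torch.gather, which tends
    # to trace more cleanly under torch.compile than unfold + advanced indexing
    B, C, H, W = x.shape
    rows = top[:, None] + torch.arange(patch, device=x.device)   # (B, patch)
    cols = left[:, None] + torch.arange(patch, device=x.device)  # (B, patch)
    # gather the patch rows: index shape (B, C, patch, W)
    row_idx = rows[:, None, :, None].expand(B, C, patch, W)
    x = torch.gather(x, 2, row_idx)
    # gather the patch columns: index shape (B, C, patch, patch)
    col_idx = cols[:, None, None, :].expand(B, C, patch, patch)
    return torch.gather(x, 3, col_idx)
```

Because both the random indices and the gather indices live on the same device as the input, no host-to-device synchronization is forced during patch sampling.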

Comments Outside Diff (2)

  1. physicsnemo/diffusion/multi_diffusion/losses.py, line 228-229 (link)

    P2 Type annotation doesn't match new behaviour

    The model parameter is still typed as MultiDiffusionModel2D, but _unwrap_multi_diffusion was introduced precisely because this constructor should now also accept DistributedDataParallel and torch.compile-wrapped models. Passing a DDP-wrapped model currently satisfies the runtime logic but violates the declared type, which will mislead type checkers and users reading the docstring.

    The same applies to the docstring Parameters section (model : MultiDiffusionModel2D) — updating it to model : torch.nn.Module or model : MultiDiffusionModel2D | torch.nn.Module would keep the contract accurate.

  2. physicsnemo/diffusion/multi_diffusion/losses.py, line 451-452 (link)

    P2 Type annotation doesn't match new behaviour

    Same annotation mismatch as in MultiDiffusionMSEDSMLoss: the parameter is typed MultiDiffusionModel2D but the class now unwraps DDP/compiled wrappers. Consider updating to torch.nn.Module.
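The annotation fix both comments suggest would look roughly like this. The class body is a hypothetical skeleton (the real constructors do more); only the widened `torch.nn.Module` annotation is the point.

```python
import torch


class MultiDiffusionMSEDSMLoss:
    """Hypothetical skeleton illustrating the suggested annotation change."""

    def __init__(self, model: torch.nn.Module) -> None:
        # Typed as torch.nn.Module so that DDP-wrapped and torch.compile-wrapped
        # models satisfy the declared contract; the real code would unwrap here.
        self.model = model
```

With `torch.nn.Module` as the declared type, passing a DistributedDataParallel or compiled wrapper no longer violates the signature that type checkers and readers see.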

Reviews (1): Last reviewed commit: "Fixes for multi-diffusion"

Collaborator

@pzharrington left a comment


LGTM!

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
@CharlelieLrt enabled auto-merge April 10, 2026 05:47
@CharlelieLrt
Collaborator Author

/blossom-ci

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
@CharlelieLrt
Collaborator Author

/blossom-ci

@CharlelieLrt
Collaborator Author

/blossom-ci

@CharlelieLrt
Collaborator Author

/blosson-ci

@CharlelieLrt
Collaborator Author

/blossom-ci

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
@CharlelieLrt
Collaborator Author

/blossom-ci

@CharlelieLrt added this pull request to the merge queue Apr 11, 2026
Merged via the queue into NVIDIA:main with commit a102e53 Apr 11, 2026
4 checks passed
