Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions physicsnemo/diffusion/multi_diffusion/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,29 @@
from physicsnemo.diffusion.utils.utils import apply_loss_weight


def _unwrap_multi_diffusion(model: torch.nn.Module) -> MultiDiffusionModel2D:
    """Strip DDP / ``torch.compile`` wrappers until the underlying
    :class:`MultiDiffusionModel2D` is exposed.

    Arbitrary nesting is supported: ``torch.compile`` wrappers are
    unwrapped via ``OptimizedModule._orig_mod`` and
    ``DistributedDataParallel`` wrappers via the ``module`` attribute.

    Raises
    ------
    TypeError
        If a leaf module is reached that is not a
        :class:`MultiDiffusionModel2D`.
    """
    current = model
    while True:
        if isinstance(current, MultiDiffusionModel2D):
            return current
        if isinstance(current, torch._dynamo.eval_frame.OptimizedModule):
            # torch.compile wrapper: the wrapped module lives in _orig_mod.
            current = current._orig_mod
        elif hasattr(current, "module"):
            # DDP-style wrapper: the wrapped module lives in .module.
            current = current.module
        else:
            raise TypeError(
                f"Could not unwrap a MultiDiffusionModel2D from "
                f"{type(model).__name__}. Found leaf type "
                f"{type(current).__name__}."
            )


class _CompiledPatchX:
"""Cached ``torch.compile``-d wrapper around
:meth:`~MultiDiffusionModel2D.patch_x`.
Expand Down Expand Up @@ -233,8 +256,9 @@ def __init__(
reduction: Literal["none", "mean", "sum"] = "mean",
) -> None:
self.model = model
self._md_model = _unwrap_multi_diffusion(model)
self.noise_scheduler = noise_scheduler
self._compiled_patch_x = _CompiledPatchX(model)
self._compiled_patch_x = _CompiledPatchX(self._md_model)

if prediction_type == "x0":
self._to_x0 = lambda prediction, x_t, t: prediction
Expand Down Expand Up @@ -301,7 +325,7 @@ def __call__(
:math:`(P \times B, C, H_p, W_p)`. Otherwise a scalar tensor.
"""
if reset_patch_indices:
self.model.reset_patch_indices()
self._md_model.reset_patch_indices()

# Patch x0 and sample per-patch noise
x0_patched = self._compiled_patch_x(x0) # (P*B, C, Hp, Wp)
Expand Down Expand Up @@ -480,8 +504,9 @@ def __init__(
reduction: Literal["none", "mean", "sum"] = "mean",
) -> None:
self.model = model
self._md_model = _unwrap_multi_diffusion(model)
self.noise_scheduler = noise_scheduler
self._compiled_patch_x = _CompiledPatchX(model)
self._compiled_patch_x = _CompiledPatchX(self._md_model)

if prediction_type == "x0":
self._to_x0 = lambda prediction, x_t, t: prediction
Expand Down Expand Up @@ -558,7 +583,7 @@ def __call__(
)

if reset_patch_indices:
self.model.reset_patch_indices()
self._md_model.reset_patch_indices()

# Patch x0 and weight, then sample per-patch noise
x0_patched = self._compiled_patch_x(x0) # (P*B, C, Hp, Wp)
Expand Down
133 changes: 92 additions & 41 deletions physicsnemo/diffusion/multi_diffusion/patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""Utilities for multi-diffusion (patching and fusion)."""

import math
import random
import warnings
from abc import ABC, abstractmethod
from typing import Optional, Tuple, Union
Expand Down Expand Up @@ -150,13 +149,20 @@ def global_index(
def _compute_global_index(self) -> Int[Tensor, "P 2 Hp Wp"]:
r"""Compute the global-index tensor from current patch positions.

The ``global_index`` tensor is created and computed on the same device
as ``patch_indices`` (if the buffer exists) to avoid cross-device
indexing errors.

Returns
-------
Tensor
Integer tensor of shape :math:`(P, 2, H_p, W_p)`.
"""
Ny = torch.arange(self.img_shape[0]).int()
Nx = torch.arange(self.img_shape[1]).int()
device = None
if hasattr(self, "patch_indices") and isinstance(self.patch_indices, Tensor):
device = self.patch_indices.device
Ny = torch.arange(self.img_shape[0], device=device).int()
Nx = torch.arange(self.img_shape[1], device=device).int()
grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0).unsqueeze(
0
) # (1, 2, H, W)
Expand Down Expand Up @@ -237,10 +243,20 @@ def set_patch_num(self, value: int) -> None:
self._patch_num = value
self.reset_patch_indices()

def reset_patch_indices(self) -> None:
def reset_patch_indices(
self,
*,
generator: torch.Generator | None = None,
) -> None:
r"""Re-draw random upper-left corner positions for all patches.

Also refreshes the cached ``_global_index`` buffer.
The cached ``_global_index`` buffer is invalidated and will be
lazily recomputed on the next call to :meth:`global_index`.

Parameters
----------
generator : torch.Generator, optional
Pseudo-random number generator for reproducible sampling.
"""
has_buffer = hasattr(self, "patch_indices") and isinstance(
self.patch_indices, Tensor
Expand All @@ -249,60 +265,95 @@ def reset_patch_indices(self) -> None:

max_y = self.img_shape[0] - self.patch_shape[0]
max_x = self.img_shape[1] - self.patch_shape[1]
# TODO: use torch.randint instead of random.randint to create
# patch indices directly on right device. Note: this will break
# non-regression tests because torch.randint and random.randint do not
# use the same random number generation process. But for an object that
# is deliberately designed to be random, breaking these non-regression
# tests might not be a problem.
new_indices = torch.tensor(
[
(random.randint(0, max_y), random.randint(0, max_x))
for _ in range(self.patch_num)
],

py = torch.randint(
0,
max_y + 1,
(self.patch_num,),
dtype=torch.long,
device=device,
generator=generator,
)
px = torch.randint(
0,
max_x + 1,
(self.patch_num,),
dtype=torch.long,
device=device,
generator=generator,
)
new_indices = torch.stack([py, px], dim=1)

if has_buffer and new_indices.shape == self.patch_indices.shape:
self.patch_indices.copy_(new_indices)
else:
self.register_buffer("patch_indices", new_indices, persistent=False)

# Refresh cached global index
new_global_index = self._compute_global_index()
if (
hasattr(self, "_global_index")
and isinstance(self._global_index, Tensor)
and new_global_index.shape == self._global_index.shape
):
self._global_index.copy_(new_global_index)
else:
self.register_buffer("_global_index", new_global_index, persistent=False)
self._global_index_needs_update = True

def global_index(
    self, batch_size: int = 1, device: Union[torch.device, str] = "cpu"
) -> Int[Tensor, "P 2 Hp Wp"]:
    r"""Return global :math:`(y, x)` grid coordinates for each patch.
    Recomputes lazily if patch positions have changed since the last call.

    Parameters
    ----------
    batch_size : int, default=1
        Kept for backward compatibility. Ignored.
    device : Union[torch.device, str], default="cpu"
        Kept for backward compatibility. The buffer follows the module
        device (use ``.to(device)`` to move the module).

    Returns
    -------
    Tensor
        Integer tensor of shape :math:`(P, 2, H_p, W_p)`.
    """
    # Stale (or never-computed) cache: rebuild it before returning.
    if getattr(self, "_global_index_needs_update", True):
        fresh = self._compute_global_index()
        cached = getattr(self, "_global_index", None)
        if isinstance(cached, Tensor) and cached.shape == fresh.shape:
            # Refresh the existing buffer in place so its identity
            # (and any external references to it) is preserved.
            cached.copy_(fresh)
        else:
            self.register_buffer("_global_index", fresh, persistent=False)
        self._global_index_needs_update = False
    return self._global_index.clone()

def forward(
self,
input: Float[Tensor, "B C H W"],
additional_input: Optional[Float[Tensor, "B C_add H_add W_add"]] = None,
) -> Float[Tensor, "P_times_B C_out Hp Wp"]:
r"""Extract random patches from the input tensor."""
B = input.shape[0]
B, C, H, W = input.shape
Hp, Wp = self.patch_shape
P = self.patch_num

# Unfold creates a view of all stride-1 patches (no copy)
patches = input.unfold(2, Hp, 1).unfold(3, Wp, 1)
# patches: (B, C, H-Hp+1, W-Wp+1, Hp, Wp)

# Gather the P patches at stored random positions
py = self.patch_indices[:, 0] # (P,)
px = self.patch_indices[:, 1] # (P,)
gathered = patches[:, :, py, px] # (B, C, P, Hp, Wp)

# Reorder to patch-major layout
out = gathered.permute(2, 0, 1, 3, 4).reshape(
P * B, -1, Hp, Wp
) # (P*B, C, Hp, Wp)
K = Hp * Wp

patch_indices = self.patch_indices.to(input.device)
py = patch_indices[:, 0] # (P,)
px = patch_indices[:, 1] # (P,)

dy = torch.arange(Hp, device=input.device)
dx = torch.arange(Wp, device=input.device)
base = (py * W + px).reshape(P, 1, 1) # (P, 1, 1)
rel = (dy[:, None] * W + dx[None, :]).reshape(1, 1, K) # (1, 1, K)
idx = (base + rel).expand(P, B, K) # (P, B, Hp*Wp)

x_flat = input.reshape(B, C, H * W) # (B, C, HW)
gathered = torch.gather(
x_flat.unsqueeze(0).expand(P, B, C, H * W),
dim=3,
index=idx.unsqueeze(2).expand(P, B, C, K),
) # (P, B, C, Hp*Wp)

out = gathered.reshape(P * B, C, Hp, Wp)

if input.is_contiguous(memory_format=torch.channels_last):
out = out.to(memory_format=torch.channels_last)
Expand Down
1 change: 1 addition & 0 deletions physicsnemo/diffusion/utils/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Any

Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified test/diffusion/data/test_multi_diffusion_models_uncond.mdlus
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified test/diffusion/data/test_patching_random_apply_after_reset.pth
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading