diff --git a/physicsnemo/diffusion/multi_diffusion/losses.py b/physicsnemo/diffusion/multi_diffusion/losses.py index cca84686cb..8aeb9fadb7 100644 --- a/physicsnemo/diffusion/multi_diffusion/losses.py +++ b/physicsnemo/diffusion/multi_diffusion/losses.py @@ -29,6 +29,29 @@ from physicsnemo.diffusion.utils.utils import apply_loss_weight +def _unwrap_multi_diffusion(model: torch.nn.Module) -> MultiDiffusionModel2D: + """Peel off DDP / torch.compile wrappers to reach the underlying + :class:`MultiDiffusionModel2D`. + + The unwrapping order handles arbitrary nesting of + ``DistributedDataParallel`` (``model.module``) and ``torch.compile`` + (``OptimizedModule._orig_mod``). + """ + m = model + while not isinstance(m, MultiDiffusionModel2D): + if isinstance(m, torch._dynamo.eval_frame.OptimizedModule): + m = m._orig_mod + elif hasattr(m, "module"): + m = m.module + else: + raise TypeError( + f"Could not unwrap a MultiDiffusionModel2D from " + f"{type(model).__name__}. Found leaf type " + f"{type(m).__name__}." + ) + return m + + class _CompiledPatchX: """Cached ``torch.compile``-d wrapper around :meth:`~MultiDiffusionModel2D.patch_x`. @@ -233,8 +256,9 @@ def __init__( reduction: Literal["none", "mean", "sum"] = "mean", ) -> None: self.model = model + self._md_model = _unwrap_multi_diffusion(model) self.noise_scheduler = noise_scheduler - self._compiled_patch_x = _CompiledPatchX(model) + self._compiled_patch_x = _CompiledPatchX(self._md_model) if prediction_type == "x0": self._to_x0 = lambda prediction, x_t, t: prediction @@ -301,7 +325,7 @@ def __call__( :math:`(P \times B, C, H_p, W_p)`. Otherwise a scalar tensor. """ if reset_patch_indices: - self.model.reset_patch_indices() + self._md_model.reset_patch_indices() # Patch x0 and sample per-patch noise x0_patched = self._compiled_patch_x(x0) # (P*B, C, Hp, Wp) @@ -480,8 +504,9 @@ def __init__( reduction: Literal["none", "mean", "sum"] = "mean", ) -> None: self.model = model + self._md_model = _unwrap_multi_diffusion(model) self.noise_scheduler = noise_scheduler - self._compiled_patch_x = _CompiledPatchX(model) + self._compiled_patch_x = _CompiledPatchX(self._md_model) if prediction_type == "x0": self._to_x0 = lambda prediction, x_t, t: prediction @@ -558,7 +583,7 @@ def __call__( ) if reset_patch_indices: - self.model.reset_patch_indices() + self._md_model.reset_patch_indices() # Patch x0 and weight, then sample per-patch noise x0_patched = self._compiled_patch_x(x0) # (P*B, C, Hp, Wp) diff --git a/physicsnemo/diffusion/multi_diffusion/patching.py b/physicsnemo/diffusion/multi_diffusion/patching.py index c1b3805990..48442228d6 100644 --- a/physicsnemo/diffusion/multi_diffusion/patching.py +++ b/physicsnemo/diffusion/multi_diffusion/patching.py @@ -17,7 +17,6 @@ """Utilities for multi-diffusion (patching and fusion).""" import math -import random import warnings from abc import ABC, abstractmethod from typing import Optional, Tuple, Union @@ -150,13 +149,20 @@ def global_index( def _compute_global_index(self) -> Int[Tensor, "P 2 Hp Wp"]: r"""Compute the global-index tensor from current patch positions. + The ``global_index`` tensor is created and computed on the same device + as ``patch_indices`` (if the buffer exists) to avoid cross-device + indexing errors. + Returns ------- Tensor Integer tensor of shape :math:`(P, 2, H_p, W_p)`. """ - Ny = torch.arange(self.img_shape[0]).int() - Nx = torch.arange(self.img_shape[1]).int() + device = None + if hasattr(self, "patch_indices") and isinstance(self.patch_indices, Tensor): + device = self.patch_indices.device + Ny = torch.arange(self.img_shape[0], device=device).int() + Nx = torch.arange(self.img_shape[1], device=device).int() grid = torch.stack(torch.meshgrid(Ny, Nx, indexing="ij"), dim=0).unsqueeze( 0 ) # (1, 2, H, W) @@ -237,10 +243,20 @@ def set_patch_num(self, value: int) -> None: self._patch_num = value self.reset_patch_indices() - def reset_patch_indices(self) -> None: + def reset_patch_indices( + self, + *, + generator: torch.Generator | None = None, + ) -> None: r"""Re-draw random upper-left corner positions for all patches. - Also refreshes the cached ``_global_index`` buffer. + The cached ``_global_index`` buffer is invalidated and will be + lazily recomputed on the next call to :meth:`global_index`. + + Parameters + ---------- + generator : torch.Generator, optional + Pseudo-random number generator for reproducible sampling. """ has_buffer = hasattr(self, "patch_indices") and isinstance( self.patch_indices, Tensor @@ -249,36 +265,65 @@ def reset_patch_indices(self) -> None: max_y = self.img_shape[0] - self.patch_shape[0] max_x = self.img_shape[1] - self.patch_shape[1] - # TODO: use torch.randint instead of random.randint to create - # patch indices directly on right device. Note: this will break - # non-regression tests because torch.randint and random.randint do not - # use the same random number generation process. But for an object taht - # is deliberately designed to be random, breaking these non-regression - # tests might not be a problem. - new_indices = torch.tensor( - [ - (random.randint(0, max_y), random.randint(0, max_x)) - for _ in range(self.patch_num) - ], + + py = torch.randint( + 0, + max_y + 1, + (self.patch_num,), dtype=torch.long, device=device, + generator=generator, ) + px = torch.randint( + 0, + max_x + 1, + (self.patch_num,), + dtype=torch.long, + device=device, + generator=generator, + ) + new_indices = torch.stack([py, px], dim=1) if has_buffer and new_indices.shape == self.patch_indices.shape: self.patch_indices.copy_(new_indices) else: self.register_buffer("patch_indices", new_indices, persistent=False) - # Refresh cached global index - new_global_index = self._compute_global_index() - if ( - hasattr(self, "_global_index") - and isinstance(self._global_index, Tensor) - and new_global_index.shape == self._global_index.shape - ): - self._global_index.copy_(new_global_index) - else: - self.register_buffer("_global_index", new_global_index, persistent=False) + self._global_index_needs_update = True + + def global_index( + self, batch_size: int = 1, device: Union[torch.device, str] = "cpu" + ) -> Int[Tensor, "P 2 Hp Wp"]: + r"""Return global :math:`(y, x)` grid coordinates for each patch. + Recomputes lazily if patch positions have changed since the last call. + + Parameters + ---------- + batch_size : int, default=1 + Kept for backward compatibility. Ignored. + device : Union[torch.device, str], default="cpu" + Kept for backward compatibility. The buffer follows the module + device (use ``.to(device)`` to move the module). + + Returns + ------- + Tensor + Integer tensor of shape :math:`(P, 2, H_p, W_p)`. + """ + if getattr(self, "_global_index_needs_update", True): + new_global_index = self._compute_global_index() + if ( + hasattr(self, "_global_index") + and isinstance(self._global_index, Tensor) + and new_global_index.shape == self._global_index.shape + ): + self._global_index.copy_(new_global_index) + else: + self.register_buffer( + "_global_index", new_global_index, persistent=False + ) + self._global_index_needs_update = False + return self._global_index.clone() def forward( self, @@ -286,23 +331,29 @@ def forward( additional_input: Optional[Float[Tensor, "B C_add H_add W_add"]] = None, ) -> Float[Tensor, "P_times_B C_out Hp Wp"]: r"""Extract random patches from the input tensor.""" - B = input.shape[0] + B, C, H, W = input.shape Hp, Wp = self.patch_shape P = self.patch_num - - # Unfold creates a view of all stride-1 patches (no copy) - patches = input.unfold(2, Hp, 1).unfold(3, Wp, 1) - # patches: (B, C, H-Hp+1, W-Wp+1, Hp, Wp) - - # Gather the P patches at stored random positions - py = self.patch_indices[:, 0] # (P,) - px = self.patch_indices[:, 1] # (P,) - gathered = patches[:, :, py, px] # (B, C, P, Hp, Wp) - - # Reorder to patch-major layout - out = gathered.permute(2, 0, 1, 3, 4).reshape( - P * B, -1, Hp, Wp - ) # (P*B, C, Hp, Wp) + K = Hp * Wp + + patch_indices = self.patch_indices.to(input.device) + py = patch_indices[:, 0] # (P,) + px = patch_indices[:, 1] # (P,) + + dy = torch.arange(Hp, device=input.device) + dx = torch.arange(Wp, device=input.device) + base = (py * W + px).reshape(P, 1, 1) # (P, 1, 1) + rel = (dy[:, None] * W + dx[None, :]).reshape(1, 1, K) # (1, 1, K) + idx = (base + rel).expand(P, B, K) # (P, B, Hp*Wp) + + x_flat = input.reshape(B, C, H * W) # (B, C, HW) + gathered = torch.gather( + x_flat.unsqueeze(0).expand(P, B, C, H * W), + dim=3, + index=idx.unsqueeze(2).expand(P, B, C, K), + ) # (P, B, C, Hp*Wp) + + out = gathered.reshape(P * B, C, Hp, Wp) if input.is_contiguous(memory_format=torch.channels_last): out = out.to(memory_format=torch.channels_last) diff --git a/physicsnemo/diffusion/utils/model_wrappers.py b/physicsnemo/diffusion/utils/model_wrappers.py index a7b9f30dbf..9837f54d3c 100644 --- a/physicsnemo/diffusion/utils/model_wrappers.py +++ b/physicsnemo/diffusion/utils/model_wrappers.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import warnings from typing import Any diff --git a/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_score_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_score_sq_train.pth index 1e9b9a338c..307127deb5 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_score_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_score_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_x0_ns_train.pth b/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_x0_ns_train.pth index cfeaeff2f7..acdd0c18be 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_x0_ns_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_x0_ns_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_x0_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_x0_sq_train.pth index 1cac9a7cde..969978e385 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_x0_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_mse_cond_patch_x0_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_mse_cond_vec_x0_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_mse_cond_vec_x0_sq_train.pth index 16e9298b1e..10fd08cda5 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_mse_cond_vec_x0_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_mse_cond_vec_x0_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_mse_posembd_x0_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_mse_posembd_x0_sq_train.pth index c3110a014b..c81faaca90 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_mse_posembd_x0_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_mse_posembd_x0_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_score_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_score_sq_train.pth index 0126ad093e..889fa109b7 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_score_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_score_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_x0_ns_train.pth b/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_x0_ns_train.pth index 6e6725375b..391d2988b1 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_x0_ns_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_x0_ns_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_x0_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_x0_sq_train.pth index 5f7c3564eb..44a0576153 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_x0_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_mse_uncond_x0_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_score_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_score_sq_train.pth index 5787d92e5f..c239cf938d 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_score_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_score_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_x0_ns_train.pth b/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_x0_ns_train.pth index 2b3039dc69..091a0cc900 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_x0_ns_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_x0_ns_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_x0_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_x0_sq_train.pth index 07d2a2aa2d..bf8887d43f 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_x0_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_patch_x0_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_vec_x0_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_vec_x0_sq_train.pth index e5f5103da1..8bc23049fc 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_vec_x0_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_wmse_cond_vec_x0_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_wmse_posembd_x0_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_wmse_posembd_x0_sq_train.pth index 9cf4f37f5c..f8ca956f5d 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_wmse_posembd_x0_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_wmse_posembd_x0_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_score_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_score_sq_train.pth index b813a23f08..720d742d92 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_score_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_score_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_x0_ns_train.pth b/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_x0_ns_train.pth index 370e1456fc..9ef5bf97ac 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_x0_ns_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_x0_ns_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_x0_sq_train.pth b/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_x0_sq_train.pth index 6c7cc33b7c..a7332b28df 100644 Binary files a/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_x0_sq_train.pth and b/test/diffusion/data/test_multi_diffusion_losses_wmse_uncond_x0_sq_train.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_interp.mdlus b/test/diffusion/data/test_multi_diffusion_models_cond_interp.mdlus index b3f6cebc1b..c4baf2dae0 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_interp.mdlus and b/test/diffusion/data/test_multi_diffusion_models_cond_interp.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_interp_fwd_rand.pth b/test/diffusion/data/test_multi_diffusion_models_cond_interp_fwd_rand.pth index 4cc5bb13dd..0d2a7a6545 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_interp_fwd_rand.pth and b/test/diffusion/data/test_multi_diffusion_models_cond_interp_fwd_rand.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_interp_fwd_rand_prepatched.pth b/test/diffusion/data/test_multi_diffusion_models_cond_interp_fwd_rand_prepatched.pth index 86ac4aa538..8f84dabec7 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_interp_fwd_rand_prepatched.pth and b/test/diffusion/data/test_multi_diffusion_models_cond_interp_fwd_rand_prepatched.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_interp_patch_x.pth b/test/diffusion/data/test_multi_diffusion_models_cond_interp_patch_x.pth index e4adac678e..6ea26db4f7 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_interp_patch_x.pth and b/test/diffusion/data/test_multi_diffusion_models_cond_interp_patch_x.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_patch.mdlus b/test/diffusion/data/test_multi_diffusion_models_cond_patch.mdlus index 91ee364d54..dd54061dd0 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_patch.mdlus and b/test/diffusion/data/test_multi_diffusion_models_cond_patch.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_patch_fwd_rand.pth b/test/diffusion/data/test_multi_diffusion_models_cond_patch_fwd_rand.pth index 0d075ea6b9..a6119f6eaa 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_patch_fwd_rand.pth and b/test/diffusion/data/test_multi_diffusion_models_cond_patch_fwd_rand.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_patch_fwd_rand_prepatched.pth b/test/diffusion/data/test_multi_diffusion_models_cond_patch_fwd_rand_prepatched.pth index beebec3a87..224c0389e6 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_patch_fwd_rand_prepatched.pth and b/test/diffusion/data/test_multi_diffusion_models_cond_patch_fwd_rand_prepatched.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_patch_patch_x.pth b/test/diffusion/data/test_multi_diffusion_models_cond_patch_patch_x.pth index b2f0d835fb..2c5e883e95 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_patch_patch_x.pth and b/test/diffusion/data/test_multi_diffusion_models_cond_patch_patch_x.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_vec_img.mdlus b/test/diffusion/data/test_multi_diffusion_models_cond_vec_img.mdlus index 0d95ae3f0f..602261f791 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_vec_img.mdlus and b/test/diffusion/data/test_multi_diffusion_models_cond_vec_img.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_fwd_rand.pth b/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_fwd_rand.pth index 946a41ee6c..2cdf1f8bae 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_fwd_rand.pth and b/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_fwd_rand.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_fwd_rand_prepatched.pth b/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_fwd_rand_prepatched.pth index be6069019f..7dfd7981de 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_fwd_rand_prepatched.pth and b/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_fwd_rand_prepatched.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_patch_x.pth b/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_patch_x.pth index ea7a4f0dde..ce87c1bded 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_patch_x.pth and b/test/diffusion/data/test_multi_diffusion_models_cond_vec_img_patch_x.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_posembd_learn.mdlus b/test/diffusion/data/test_multi_diffusion_models_posembd_learn.mdlus index 0db28e36e3..553360c5fd 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_posembd_learn.mdlus and b/test/diffusion/data/test_multi_diffusion_models_posembd_learn.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_posembd_learn_fwd_rand.pth b/test/diffusion/data/test_multi_diffusion_models_posembd_learn_fwd_rand.pth index 797c838ede..a18800baed 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_posembd_learn_fwd_rand.pth and b/test/diffusion/data/test_multi_diffusion_models_posembd_learn_fwd_rand.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_posembd_learn_fwd_rand_prepatched.pth b/test/diffusion/data/test_multi_diffusion_models_posembd_learn_fwd_rand_prepatched.pth index bfbeed50aa..f40d5297fe 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_posembd_learn_fwd_rand_prepatched.pth and b/test/diffusion/data/test_multi_diffusion_models_posembd_learn_fwd_rand_prepatched.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_posembd_learn_patch_x.pth b/test/diffusion/data/test_multi_diffusion_models_posembd_learn_patch_x.pth index 8a9b1b82ca..d2e2cf0b15 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_posembd_learn_patch_x.pth and b/test/diffusion/data/test_multi_diffusion_models_posembd_learn_patch_x.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_posembd_sin.mdlus b/test/diffusion/data/test_multi_diffusion_models_posembd_sin.mdlus index a45bb19ba1..ece97562fc 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_posembd_sin.mdlus and b/test/diffusion/data/test_multi_diffusion_models_posembd_sin.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_posembd_sin_fwd_rand.pth b/test/diffusion/data/test_multi_diffusion_models_posembd_sin_fwd_rand.pth index 8dd8a8d300..4fcef0eeca 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_posembd_sin_fwd_rand.pth and b/test/diffusion/data/test_multi_diffusion_models_posembd_sin_fwd_rand.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_posembd_sin_fwd_rand_prepatched.pth b/test/diffusion/data/test_multi_diffusion_models_posembd_sin_fwd_rand_prepatched.pth index f3fb8e9d31..c16e9aad74 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_posembd_sin_fwd_rand_prepatched.pth and b/test/diffusion/data/test_multi_diffusion_models_posembd_sin_fwd_rand_prepatched.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_posembd_sin_patch_x.pth b/test/diffusion/data/test_multi_diffusion_models_posembd_sin_patch_x.pth index d59056f39b..5594650f60 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_posembd_sin_patch_x.pth and b/test/diffusion/data/test_multi_diffusion_models_posembd_sin_patch_x.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_uncond.mdlus b/test/diffusion/data/test_multi_diffusion_models_uncond.mdlus index 9418a4b76e..e90f1d7328 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_uncond.mdlus and b/test/diffusion/data/test_multi_diffusion_models_uncond.mdlus differ diff --git a/test/diffusion/data/test_multi_diffusion_models_uncond_fwd_rand.pth b/test/diffusion/data/test_multi_diffusion_models_uncond_fwd_rand.pth index 4d54dc399b..5dc22f77f9 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_uncond_fwd_rand.pth and b/test/diffusion/data/test_multi_diffusion_models_uncond_fwd_rand.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_uncond_fwd_rand_prepatched.pth b/test/diffusion/data/test_multi_diffusion_models_uncond_fwd_rand_prepatched.pth index 35e9694552..faf9d44aa8 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_uncond_fwd_rand_prepatched.pth and b/test/diffusion/data/test_multi_diffusion_models_uncond_fwd_rand_prepatched.pth differ diff --git a/test/diffusion/data/test_multi_diffusion_models_uncond_patch_x.pth b/test/diffusion/data/test_multi_diffusion_models_uncond_patch_x.pth index da09616668..6265309931 100644 Binary files a/test/diffusion/data/test_multi_diffusion_models_uncond_patch_x.pth and b/test/diffusion/data/test_multi_diffusion_models_uncond_patch_x.pth differ diff --git a/test/diffusion/data/test_patching_random_apply_after_reset.pth b/test/diffusion/data/test_patching_random_apply_after_reset.pth index 9f7dcc2d42..89fe971ebb 100644 Binary files a/test/diffusion/data/test_patching_random_apply_after_reset.pth and b/test/diffusion/data/test_patching_random_apply_after_reset.pth differ diff --git a/test/diffusion/data/test_patching_random_global_index_after_reset.pth b/test/diffusion/data/test_patching_random_global_index_after_reset.pth index 0a569f6657..f0b439c790 100644 Binary files a/test/diffusion/data/test_patching_random_global_index_after_reset.pth and b/test/diffusion/data/test_patching_random_global_index_after_reset.pth differ diff --git a/test/diffusion/data/test_patching_random_rand_16x16_2p_apply.pth b/test/diffusion/data/test_patching_random_rand_16x16_2p_apply.pth index cd952a41a8..cce3f81dd2 100644 Binary files a/test/diffusion/data/test_patching_random_rand_16x16_2p_apply.pth and b/test/diffusion/data/test_patching_random_rand_16x16_2p_apply.pth differ diff --git a/test/diffusion/data/test_patching_random_rand_16x16_2p_apply_add.pth b/test/diffusion/data/test_patching_random_rand_16x16_2p_apply_add.pth index c034fd787d..474335bd41 100644 Binary files a/test/diffusion/data/test_patching_random_rand_16x16_2p_apply_add.pth and b/test/diffusion/data/test_patching_random_rand_16x16_2p_apply_add.pth differ diff --git a/test/diffusion/data/test_patching_random_rand_16x16_2p_global_index.pth b/test/diffusion/data/test_patching_random_rand_16x16_2p_global_index.pth index 88c294aab8..e60138c122 100644 Binary files a/test/diffusion/data/test_patching_random_rand_16x16_2p_global_index.pth and b/test/diffusion/data/test_patching_random_rand_16x16_2p_global_index.pth differ diff --git a/test/diffusion/data/test_patching_random_rand_8x12_4p_apply.pth b/test/diffusion/data/test_patching_random_rand_8x12_4p_apply.pth index 307ec82039..a48189afce 100644 Binary files a/test/diffusion/data/test_patching_random_rand_8x12_4p_apply.pth and b/test/diffusion/data/test_patching_random_rand_8x12_4p_apply.pth differ diff --git a/test/diffusion/data/test_patching_random_rand_8x12_4p_apply_add.pth b/test/diffusion/data/test_patching_random_rand_8x12_4p_apply_add.pth index 572921ea8e..b8f173a318 100644 Binary files a/test/diffusion/data/test_patching_random_rand_8x12_4p_apply_add.pth and b/test/diffusion/data/test_patching_random_rand_8x12_4p_apply_add.pth differ diff --git a/test/diffusion/data/test_patching_random_rand_8x12_4p_global_index.pth b/test/diffusion/data/test_patching_random_rand_8x12_4p_global_index.pth index 62c4c315a6..f28d4771c3 100644 Binary files a/test/diffusion/data/test_patching_random_rand_8x12_4p_global_index.pth and b/test/diffusion/data/test_patching_random_rand_8x12_4p_global_index.pth differ