diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1bf35d9d0a..62f0a93245 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -37,6 +37,7 @@ asv.conf.json @loliverhennigh # Coding standards CODING_STANDARDS/ @CharlelieLrt @coreyjadams @loliverhennigh +CODING_STANDARDS/FUNCTIONAL_APIS.md @loliverhennigh # Interrogate "grandfathered exceptions" list ./test/ci_tests/interrogate_baseline.txt @ktangsali @coreyjadams @nickgeneva @peterdsharpe diff --git a/CHANGELOG.md b/CHANGELOG.md index 090fb6a4a1..5909c927e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 creating Voronoi regions around seed points. BVH-accelerated. - Added support for 1D, 2D, and 3D neighborhood attention (natten) via `physicsnemo.nn.functional` interface, with full `ShardTensor` support. +- Added derivative functionals in `physicsnemo.nn.functional` for + `uniform_grid_gradient`, `rectilinear_grid_gradient`, + `spectral_grid_gradient`, `meshless_fd_derivatives`, `mesh_lsq_gradient`, + and `mesh_green_gauss_gradient`. +- Added Warp custom-op backward support for `mesh_lsq_gradient` and + `mesh_green_gauss_gradient` gradients with respect to both geometry points + and field values. ### Changed diff --git a/CODING_STANDARDS/FUNCTIONAL_APIS.md b/CODING_STANDARDS/FUNCTIONAL_APIS.md index fcd375a874..024f5c4765 100644 --- a/CODING_STANDARDS/FUNCTIONAL_APIS.md +++ b/CODING_STANDARDS/FUNCTIONAL_APIS.md @@ -61,6 +61,7 @@ This document is structured in two main sections: | [`FNC-005`](#fnc-005-benchmarking-hooks) | Benchmarking hooks | Implementing `make_inputs_forward`/`make_inputs_backward`/`compare_forward` | | [`FNC-006`](#fnc-006-testing-functionals) | Testing functionals | Adding functional tests | | [`FNC-007`](#fnc-007-benchmark-registry) | Benchmark registry | Adding a functional to ASV | +| [`FNC-008`](#fnc-008-warp-integration-must-use-torch-custom-ops) | Warp integration must use torch custom ops | Adding/refactoring Warp-backed functionals | --- @@ -501,3 +502,58 @@ FUNCTIONAL_SPECS = (KNN, RadiusSearch) # Adding a functional before input generators are implemented. FUNCTIONAL_SPECS = (MyFunctionalWithoutInputs,) ``` + +--- + +### FNC-008: Warp integration must use torch custom ops + +**Description:** + +Warp-backed functionals in `physicsnemo/nn/functional/**` must be integrated +into PyTorch using `torch.library.custom_op`, `register_fake`, and (when +backward is supported) `register_autograd`. Do not use +`torch.autograd.Function` wrappers for Warp-backed functionals. + +If a functional has no meaningful backward path, `register_autograd` is not +required; otherwise, the custom op must register a backward implementation. + +**Rationale:** + +`torch.library.custom_op` provides a consistent integration path for eager, +`torch.compile`, fake tensor propagation, and runtime dispatch behavior. +Avoiding per-functional `torch.autograd.Function` wrappers also keeps backend +integration uniform across functionals. + +**Example:** + +```python +@torch.library.custom_op("physicsnemo::my_warp_op", mutates_args=()) +def my_warp_op_impl(x: torch.Tensor) -> torch.Tensor: + ... + return y + +@my_warp_op_impl.register_fake +def _my_warp_op_impl_fake(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) + +def setup_my_warp_op_context(ctx, inputs, output): + ... + +def backward_my_warp_op(ctx, grad_output): + ... 
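+    # Recover tensors saved in setup_my_warp_op_context and compute grad_x (sketch).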
+ return grad_x + +my_warp_op_impl.register_autograd( + backward_my_warp_op, + setup_context=setup_my_warp_op_context, +) +``` + +**Anti-pattern:** + +```python +class _MyWarpAutograd(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ... +``` diff --git a/benchmarks/physicsnemo/nn/functional/registry.py b/benchmarks/physicsnemo/nn/functional/registry.py index 0e959f90a2..6c0fda4e8d 100644 --- a/benchmarks/physicsnemo/nn/functional/registry.py +++ b/benchmarks/physicsnemo/nn/functional/registry.py @@ -16,6 +16,14 @@ """Registry of FunctionSpec classes to benchmark with ASV.""" +from physicsnemo.nn.functional.derivatives import ( + MeshGreenGaussGradient, + MeshlessFDDerivatives, + MeshLSQGradient, + RectilinearGridGradient, + SpectralGridGradient, + UniformGridGradient, +) from physicsnemo.nn.functional.fourier_spectral import ( IRFFT, IRFFT2, @@ -42,6 +50,13 @@ # Neighbor queries. KNN, RadiusSearch, + # Derivatives. + UniformGridGradient, + RectilinearGridGradient, + MeshLSQGradient, + MeshGreenGaussGradient, + SpectralGridGradient, + MeshlessFDDerivatives, # Geometry. SignedDistanceField, # Interpolation. diff --git a/docs/api/nn/functionals/derivatives.rst b/docs/api/nn/functionals/derivatives.rst new file mode 100644 index 0000000000..0e132c76a7 --- /dev/null +++ b/docs/api/nn/functionals/derivatives.rst @@ -0,0 +1,34 @@ +Derivative Functionals +====================== + +.. autofunction:: physicsnemo.nn.functional.uniform_grid_gradient + +.. figure:: /img/nn/functional/derivatives/uniform_grid_gradient.png + :alt: Uniform grid gradient example + :width: 100% + +.. autofunction:: physicsnemo.nn.functional.rectilinear_grid_gradient + +.. figure:: /img/nn/functional/derivatives/rectilinear_grid_gradient.png + :alt: Rectilinear grid gradient example + :width: 100% + +.. autofunction:: physicsnemo.nn.functional.mesh_lsq_gradient + +.. figure:: /img/nn/functional/derivatives/mesh_lsq_gradient.png + :alt: Mesh LSQ gradient example + :width: 100% + +.. autofunction:: physicsnemo.nn.functional.mesh_green_gauss_gradient + +.. figure:: /img/nn/functional/derivatives/mesh_green_gauss_gradient.png + :alt: Mesh Green-Gauss gradient example + :width: 100% + +.. autofunction:: physicsnemo.nn.functional.spectral_grid_gradient + +.. figure:: /img/nn/functional/derivatives/spectral_grid_gradient.png + :alt: Spectral grid gradient example + :width: 100% + +.. autofunction:: physicsnemo.nn.functional.meshless_fd_derivatives diff --git a/docs/api/physicsnemo.nn.functionals.rst b/docs/api/physicsnemo.nn.functionals.rst index 6608736d45..2a8d43f03e 100644 --- a/docs/api/physicsnemo.nn.functionals.rst +++ b/docs/api/physicsnemo.nn.functionals.rst @@ -19,6 +19,7 @@ in the documentation for performance comparisons. 
:name: PhysicsNeMo Functionals nn/functionals/neighbors + nn/functionals/derivatives nn/functionals/geometry nn/functionals/fourier_spectral nn/functionals/regularization_parameterization diff --git a/docs/img/nn/functional/derivatives/mesh_green_gauss_gradient.png b/docs/img/nn/functional/derivatives/mesh_green_gauss_gradient.png new file mode 100644 index 0000000000..dd7c7c20ed Binary files /dev/null and b/docs/img/nn/functional/derivatives/mesh_green_gauss_gradient.png differ diff --git a/docs/img/nn/functional/derivatives/mesh_lsq_gradient.png b/docs/img/nn/functional/derivatives/mesh_lsq_gradient.png new file mode 100644 index 0000000000..1c1312af09 Binary files /dev/null and b/docs/img/nn/functional/derivatives/mesh_lsq_gradient.png differ diff --git a/docs/img/nn/functional/derivatives/rectilinear_grid_gradient.png b/docs/img/nn/functional/derivatives/rectilinear_grid_gradient.png new file mode 100644 index 0000000000..c0b3e495f6 Binary files /dev/null and b/docs/img/nn/functional/derivatives/rectilinear_grid_gradient.png differ diff --git a/docs/img/nn/functional/derivatives/spectral_grid_gradient.png b/docs/img/nn/functional/derivatives/spectral_grid_gradient.png new file mode 100644 index 0000000000..966826890e Binary files /dev/null and b/docs/img/nn/functional/derivatives/spectral_grid_gradient.png differ diff --git a/docs/img/nn/functional/derivatives/uniform_grid_gradient.png b/docs/img/nn/functional/derivatives/uniform_grid_gradient.png new file mode 100644 index 0000000000..cbc52f52e7 Binary files /dev/null and b/docs/img/nn/functional/derivatives/uniform_grid_gradient.png differ diff --git a/physicsnemo/mesh/calculus/_lsq_reconstruction.py b/physicsnemo/mesh/calculus/_lsq_reconstruction.py index f94c5200f7..e146a7235c 100644 --- a/physicsnemo/mesh/calculus/_lsq_reconstruction.py +++ b/physicsnemo/mesh/calculus/_lsq_reconstruction.py @@ -29,109 +29,19 @@ import torch -from physicsnemo.mesh.utilities._tolerances import safe_eps - if TYPE_CHECKING: from physicsnemo.mesh.mesh import Mesh -def _solve_batched_lsq_gradients( - positions: torch.Tensor, # shape: (n_entities, n_spatial_dims) - values: torch.Tensor, # shape: (n_entities, ...) - adjacency, # Adjacency object - weight_power: float, - min_neighbors: int = 0, +def _to_mesh_gradient_layout( + gradients: torch.Tensor, + values: torch.Tensor, ) -> torch.Tensor: - """Core batched LSQ gradient solver (shared by point and cell versions). - - For each entity (point or cell), solves a weighted least-squares problem: - min_{∇φ} Σ_neighbors w_i ||∇φ·(x_i - x_0) - (φ_i - φ_0)||² - - Parameters - ---------- - positions : torch.Tensor - Entity positions (points or cell centroids) - values : torch.Tensor - Values at entities (scalars or tensor fields) - adjacency - Adjacency structure (entity-to-entity neighbors) - weight_power : float - Exponent for inverse distance weighting - min_neighbors : int - Minimum neighbors required for gradient computation - - Returns - ------- - torch.Tensor - Gradients at entities, shape (n_entities, n_spatial_dims) for scalars, - or (n_entities, n_spatial_dims, ...) for tensor fields. - Entities with insufficient neighbors have zero gradients. 
- """ - n_entities = len(positions) - n_spatial_dims = positions.shape[1] - device = positions.device - dtype = values.dtype - - ### Determine output shape - is_scalar = values.ndim == 1 - if is_scalar: - gradient_shape = (n_entities, n_spatial_dims) - else: - gradient_shape = (n_entities, n_spatial_dims) + values.shape[1:] - - gradients = torch.zeros(gradient_shape, dtype=dtype, device=device) - - ### Process each neighbor-count group in parallel - from physicsnemo.mesh.calculus._neighborhoods import iter_neighborhood_batches - - for batch in iter_neighborhood_batches( - positions, adjacency, min_neighbors=min_neighbors - ): - entity_indices = batch.entity_indices - neighbors_flat = batch.neighbor_indices - A = batch.relative_positions # (n_group, n_neighbors, n_spatial_dims) - n_group = len(entity_indices) - n_neighbors = batch.n_neighbors - - ### Entities with no neighbors retain their zero-initialized gradient - if n_neighbors == 0: - continue - - ### Function differences (b vector) - b = values[neighbors_flat] - values[entity_indices].unsqueeze(1) - - ### Compute inverse-distance weights - distances = torch.linalg.vector_norm(A, dim=-1) # (n_group, n_neighbors) - weights = 1.0 / distances.pow(weight_power).clamp(min=safe_eps(distances.dtype)) - - ### Apply sqrt-weights to system - sqrt_w = weights.sqrt().unsqueeze(-1) # (n_group, n_neighbors, 1) - A_weighted = sqrt_w * A # (n_group, n_neighbors, n_spatial_dims) - - ### Solve batched least-squares - if is_scalar: - b_weighted = sqrt_w.squeeze(-1) * b # (n_group, n_neighbors) - solution = torch.linalg.lstsq( - A_weighted, b_weighted.unsqueeze(-1), rcond=None - ).solution.squeeze(-1) # (n_group, n_spatial_dims) - - gradients[entity_indices] = solution - else: - # Tensor field: flatten extra dims, solve, reshape back - b_weighted = sqrt_w * b # (n_group, n_neighbors, ...) - orig_shape = b.shape[2:] - b_flat = b_weighted.reshape(n_group, n_neighbors, -1) - - solution = torch.linalg.lstsq( - A_weighted, b_flat, rcond=None - ).solution # (n_group, n_spatial_dims, n_components) - - solution_reshaped = solution.reshape(n_group, n_spatial_dims, *orig_shape) - # Permute spatial_dims to second position - perm = [0] + list(range(2, solution_reshaped.ndim)) + [1] - gradients[entity_indices] = solution_reshaped.permute(*perm) - - return gradients + """Convert functional layout ``(n, dims, ...)`` to mesh layout ``(n, ..., dims)``.""" + if values.ndim == 1: + return gradients + perm = [0] + list(range(2, gradients.ndim)) + [1] + return gradients.permute(*perm) def compute_point_gradient_lsq( @@ -184,14 +94,21 @@ def compute_point_gradient_lsq( ### Get point-to-point adjacency adjacency = mesh.get_point_to_points_adjacency() - ### Use shared batched LSQ solver - return _solve_batched_lsq_gradients( - positions=mesh.points, + ### Delegate LSQ solve to the functional API using the torch backend. 
+ from physicsnemo.nn.functional.derivatives.mesh_lsq_gradient import ( + mesh_lsq_gradient, + ) + + gradients = mesh_lsq_gradient( + points=mesh.points, values=point_values, - adjacency=adjacency, + neighbor_offsets=adjacency.offsets, + neighbor_indices=adjacency.indices, weight_power=weight_power, min_neighbors=min_neighbors, + implementation="torch", ) + return _to_mesh_gradient_layout(gradients, point_values) def compute_cell_gradient_lsq( @@ -230,11 +147,18 @@ def compute_cell_gradient_lsq( ### Get cell centroids cell_centroids = mesh.cell_centroids # (n_cells, n_spatial_dims) - ### Use shared batched LSQ solver - return _solve_batched_lsq_gradients( - positions=cell_centroids, + ### Delegate LSQ solve to the functional API using the torch backend. + from physicsnemo.nn.functional.derivatives.mesh_lsq_gradient import ( + mesh_lsq_gradient, + ) + + gradients = mesh_lsq_gradient( + points=cell_centroids, values=cell_values, - adjacency=adjacency, + neighbor_offsets=adjacency.offsets, + neighbor_indices=adjacency.indices, weight_power=weight_power, - min_neighbors=0, # Cells may have fewer neighbors than points + min_neighbors=0, # Cells may have fewer neighbors than points. + implementation="torch", ) + return _to_mesh_gradient_layout(gradients, cell_values) diff --git a/physicsnemo/nn/functional/__init__.py b/physicsnemo/nn/functional/__init__.py index fe3b64cb46..e2ac5e036f 100644 --- a/physicsnemo/nn/functional/__init__.py +++ b/physicsnemo/nn/functional/__init__.py @@ -14,6 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .derivatives import ( + mesh_green_gauss_gradient, + mesh_lsq_gradient, + meshless_fd_derivatives, + rectilinear_grid_gradient, + spectral_grid_gradient, + uniform_grid_gradient, +) from .equivariant_ops import ( legendre_polynomials, polar_and_dipole_basis, @@ -36,17 +44,23 @@ "interpolation", "knn", "legendre_polynomials", + "mesh_green_gauss_gradient", + "meshless_fd_derivatives", + "mesh_lsq_gradient", "na1d", "na2d", "na3d", "polar_and_dipole_basis", "radius_search", "real", + "rectilinear_grid_gradient", "rfft", "rfft2", "signed_distance_field", "smooth_log", + "spectral_grid_gradient", "spherical_basis", + "uniform_grid_gradient", "vector_project", "view_as_complex", "weight_fact", diff --git a/physicsnemo/nn/functional/derivatives/__init__.py b/physicsnemo/nn/functional/derivatives/__init__.py new file mode 100644 index 0000000000..0306f0e6c8 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/__init__.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .mesh_green_gauss_gradient import MeshGreenGaussGradient, mesh_green_gauss_gradient +from .mesh_lsq_gradient import MeshLSQGradient, mesh_lsq_gradient +from .meshless_finite_difference import ( + MeshlessFDDerivatives, + meshless_fd_derivatives, +) +from .rectilinear_grid_gradient import ( + RectilinearGridGradient, + rectilinear_grid_gradient, +) +from .spectral_grid_gradient import SpectralGridGradient, spectral_grid_gradient +from .uniform_grid_gradient import UniformGridGradient, uniform_grid_gradient + +__all__ = [ + "MeshGreenGaussGradient", + "MeshlessFDDerivatives", + "MeshLSQGradient", + "RectilinearGridGradient", + "SpectralGridGradient", + "UniformGridGradient", + "mesh_green_gauss_gradient", + "meshless_fd_derivatives", + "mesh_lsq_gradient", + "rectilinear_grid_gradient", + "spectral_grid_gradient", + "uniform_grid_gradient", +] diff --git a/physicsnemo/nn/functional/derivatives/_request_utils.py b/physicsnemo/nn/functional/derivatives/_request_utils.py new file mode 100644 index 0000000000..ae028ae1c6 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/_request_utils.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from collections.abc import Sequence + +_SUPPORTED_DERIVATIVE_ORDERS = (1, 2) + + +def _normalize_orders_core( + derivative_orders: int | Sequence[int], + *, + function_name: str, +) -> tuple[int, ...]: + """Normalize a derivative-order request into canonical order.""" + if isinstance(derivative_orders, bool): + raise TypeError( + f"{function_name} derivative_orders must be int or sequence[int], got bool" + ) + + if isinstance(derivative_orders, int): + orders = (int(derivative_orders),) + elif isinstance(derivative_orders, Sequence): + orders = tuple(derivative_orders) + else: + raise TypeError( + f"{function_name} derivative_orders must be int or sequence[int], " + f"got {type(derivative_orders)}" + ) + + if len(orders) == 0: + raise ValueError(f"{function_name} derivative_orders cannot be empty") + + normalized: set[int] = set() + for order in orders: + if isinstance(order, bool) or not isinstance(order, int): + raise TypeError( + f"{function_name} derivative_orders entries must be integers, got {type(order)}" + ) + if order not in _SUPPORTED_DERIVATIVE_ORDERS: + raise ValueError( + f"{function_name} supports derivative orders {list(_SUPPORTED_DERIVATIVE_ORDERS)}, " + f"got {order}" + ) + normalized.add(order) + + return tuple(order for order in _SUPPORTED_DERIVATIVE_ORDERS if order in normalized) + + +def normalize_derivative_orders( + *, + derivative_orders: int | Sequence[int], + function_name: str, +) -> tuple[int, ...]: + """Resolve derivative-order request to canonical tuple.""" + return _normalize_orders_core(derivative_orders, function_name=function_name) + + +def normalize_include_mixed( + *, + include_mixed: bool, + function_name: str, +) -> bool: + """Validate mixed-derivative flag.""" + if not isinstance(include_mixed, bool): + raise TypeError(f"{function_name} include_mixed must be a bool") + return include_mixed + + +def validate_mixed_request( + *, + derivative_orders: tuple[int, ...], + include_mixed: bool, + ndim: int, + function_name: str, +) -> None: + """Validate that mixed-derivative requests are structurally valid.""" + if include_mixed and 2 not in derivative_orders: + raise ValueError( + f"{function_name} include_mixed is only valid when requesting 2nd derivatives" + ) + if include_mixed and ndim < 2: + raise ValueError( + f"{function_name} mixed derivatives require at least 2D inputs" + ) diff --git a/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/__init__.py b/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/__init__.py new file mode 100644 index 0000000000..f829c083c9 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .mesh_green_gauss_gradient import MeshGreenGaussGradient, mesh_green_gauss_gradient + +__all__ = ["MeshGreenGaussGradient", "mesh_green_gauss_gradient"] diff --git a/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/_torch_impl.py b/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/_torch_impl.py new file mode 100644 index 0000000000..68a064f751 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/_torch_impl.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch + +from .utils import validate_inputs + + +def mesh_green_gauss_gradient_torch( + points: torch.Tensor, + cells: torch.Tensor, + neighbors: torch.Tensor, + values: torch.Tensor, +) -> torch.Tensor: + """Compute Green-Gauss cell gradients with eager PyTorch tensor ops.""" + ### Validate mesh/value tensors and geometry compatibility. + validate_inputs(points=points, cells=cells, neighbors=neighbors, values=values) + + n_cells = cells.shape[0] + dims = points.shape[1] + n_faces = cells.shape[1] + value_shape = values.shape[1:] + values_flat = values.reshape(n_cells, -1) + n_components = values_flat.shape[1] + + cells_i64 = cells.to(dtype=torch.int64) + neighbors_i64 = neighbors.to(dtype=torch.int64) + + cell_points = points[cells_i64] + centroids = cell_points.mean(dim=1) + + if dims == 2: + p0, p1, p2 = cell_points[:, 0], cell_points[:, 1], cell_points[:, 2] + cell_volume = 0.5 * torch.abs( + (p1[:, 0] - p0[:, 0]) * (p2[:, 1] - p0[:, 1]) + - (p1[:, 1] - p0[:, 1]) * (p2[:, 0] - p0[:, 0]) + ) + else: + p0, p1, p2, p3 = ( + cell_points[:, 0], + cell_points[:, 1], + cell_points[:, 2], + cell_points[:, 3], + ) + cell_volume = ( + torch.abs( + torch.einsum("bi,bi->b", p1 - p0, torch.cross(p2 - p0, p3 - p0, dim=-1)) + ) + / 6.0 + ) + cell_volume = torch.clamp(cell_volume, min=1.0e-12) + + grad_flat = torch.zeros( + (n_cells, dims, n_components), + device=values.device, + dtype=values.dtype, + ) + + ### Accumulate Green-Gauss face fluxes into per-cell gradients. 
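+    # Face ``face_idx`` is the simplex face opposite local vertex ``face_idx``; the
+    # centroid-to-face-center sign test below orients each face normal outward.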
+ for face_idx in range(n_faces): + face_local = [idx for idx in range(n_faces) if idx != face_idx] + verts = cell_points[:, face_local, :] + + if dims == 2: + va = verts[:, 0] + vb = verts[:, 1] + edge = vb - va + normal = torch.stack((edge[:, 1], -edge[:, 0]), dim=-1) + face_center = 0.5 * (va + vb) + else: + va = verts[:, 0] + vb = verts[:, 1] + vc = verts[:, 2] + normal = 0.5 * torch.cross(vb - va, vc - va, dim=-1) + face_center = (va + vb + vc) / 3.0 + + to_face = face_center - centroids + sign = torch.where( + torch.einsum("bi,bi->b", normal, to_face) >= 0.0, + 1.0, + -1.0, + ).unsqueeze(-1) + coeff = (sign * normal) / cell_volume.unsqueeze(-1) + coeff = coeff.to(dtype=values.dtype) + + neigh = neighbors_i64[:, face_idx] + face_values = values_flat + interior = neigh >= 0 + if torch.any(interior): + face_values = values_flat.clone() + face_values[interior] = 0.5 * ( + values_flat[interior] + values_flat[neigh[interior]] + ) + + grad_flat = grad_flat + coeff.unsqueeze(-1) * face_values.unsqueeze(1) + + ### Restore gradient output layout. + return grad_flat.reshape(n_cells, dims, *value_shape) diff --git a/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/_warp_impl.py b/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/_warp_impl.py new file mode 100644 index 0000000000..ffb76efa51 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/_warp_impl.py @@ -0,0 +1,548 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math + +import torch +import warp as wp + +from physicsnemo.core.function_spec import FunctionSpec + +from .utils import validate_inputs + +### Warp runtime initialization for custom kernels. 
+wp.init() +wp.config.quiet = True + + +@wp.func +def _point2(points: wp.array2d(dtype=wp.float32), idx: int) -> wp.vec2f: + return wp.vec2f(points[idx, 0], points[idx, 1]) + + +@wp.func +def _point3(points: wp.array2d(dtype=wp.float32), idx: int) -> wp.vec3f: + return wp.vec3f(points[idx, 0], points[idx, 1], points[idx, 2]) + + +@wp.func +def _triangle_face_coeff( + points: wp.array2d(dtype=wp.float32), + cells: wp.array2d(dtype=wp.int32), + cell_idx: int, + face_idx: int, +) -> wp.vec2f: + i0 = cells[cell_idx, 0] + i1 = cells[cell_idx, 1] + i2 = cells[cell_idx, 2] + + p0 = _point2(points, i0) + p1 = _point2(points, i1) + p2 = _point2(points, i2) + + centroid = (p0 + p1 + p2) / 3.0 + area = 0.5 * wp.abs( + (p1[0] - p0[0]) * (p2[1] - p0[1]) - (p1[1] - p0[1]) * (p2[0] - p0[0]) + ) + inv_area = 1.0 / wp.max(area, 1.0e-12) + + va = p1 + vb = p2 + if face_idx == 1: + va = p0 + vb = p2 + elif face_idx == 2: + va = p0 + vb = p1 + + edge = vb - va + normal = wp.vec2f(edge[1], -edge[0]) + face_center = 0.5 * (va + vb) + to_face = face_center - centroid + if wp.dot(normal, to_face) < 0.0: + normal = -normal + + return normal * inv_area + + +@wp.func +def _tetra_face_coeff( + points: wp.array2d(dtype=wp.float32), + cells: wp.array2d(dtype=wp.int32), + cell_idx: int, + face_idx: int, +) -> wp.vec3f: + i0 = cells[cell_idx, 0] + i1 = cells[cell_idx, 1] + i2 = cells[cell_idx, 2] + i3 = cells[cell_idx, 3] + + p0 = _point3(points, i0) + p1 = _point3(points, i1) + p2 = _point3(points, i2) + p3 = _point3(points, i3) + + centroid = 0.25 * (p0 + p1 + p2 + p3) + volume = wp.abs(wp.dot(p1 - p0, wp.cross(p2 - p0, p3 - p0))) / 6.0 + inv_volume = 1.0 / wp.max(volume, 1.0e-12) + + va = p1 + vb = p2 + vc = p3 + if face_idx == 1: + va = p0 + vb = p2 + vc = p3 + elif face_idx == 2: + va = p0 + vb = p1 + vc = p3 + elif face_idx == 3: + va = p0 + vb = p1 + vc = p2 + + normal = 0.5 * wp.cross(vb - va, vc - va) + face_center = (va + vb + vc) / 3.0 + to_face = face_center - centroid + if wp.dot(normal, to_face) < 0.0: + normal = -normal + + return normal * inv_volume + + +@wp.kernel +def _mesh_green_gauss_2d_forward_kernel( + points: wp.array2d(dtype=wp.float32), + cells: wp.array2d(dtype=wp.int32), + neighbors: wp.array2d(dtype=wp.int32), + values: wp.array2d(dtype=wp.float32), + gradients: wp.array3d(dtype=wp.float32), +): + i, comp = wp.tid() + + vi = values[i, comp] + gx = float(0.0) + gy = float(0.0) + for f in range(3): + coeff = _triangle_face_coeff(points, cells, i, f) + j = neighbors[i, f] + + phi_f = vi + if j >= 0: + phi_f = 0.5 * (vi + values[j, comp]) + + gx = gx + coeff[0] * phi_f + gy = gy + coeff[1] * phi_f + + gradients[i, 0, comp] = gx + gradients[i, 1, comp] = gy + + +@wp.kernel +def _mesh_green_gauss_3d_forward_kernel( + points: wp.array2d(dtype=wp.float32), + cells: wp.array2d(dtype=wp.int32), + neighbors: wp.array2d(dtype=wp.int32), + values: wp.array2d(dtype=wp.float32), + gradients: wp.array3d(dtype=wp.float32), +): + i, comp = wp.tid() + + vi = values[i, comp] + gx = float(0.0) + gy = float(0.0) + gz = float(0.0) + for f in range(4): + coeff = _tetra_face_coeff(points, cells, i, f) + j = neighbors[i, f] + + phi_f = vi + if j >= 0: + phi_f = 0.5 * (vi + values[j, comp]) + + gx = gx + coeff[0] * phi_f + gy = gy + coeff[1] * phi_f + gz = gz + coeff[2] * phi_f + + gradients[i, 0, comp] = gx + gradients[i, 1, comp] = gy + gradients[i, 2, comp] = gz + + +@wp.kernel +def _mesh_green_gauss_2d_backward_kernel( + points: wp.array2d(dtype=wp.float32), + cells: wp.array2d(dtype=wp.int32), + neighbors: 
wp.array2d(dtype=wp.int32), + grad_output: wp.array3d(dtype=wp.float32), + grad_values: wp.array2d(dtype=wp.float32), +): + i, f, comp = wp.tid() + j = neighbors[i, f] + coeff = _triangle_face_coeff(points, cells, i, f) + + dot_go = grad_output[i, 0, comp] * coeff[0] + grad_output[i, 1, comp] * coeff[1] + owner_contrib = dot_go + if j >= 0: + owner_contrib = 0.5 * dot_go + + wp.atomic_add(grad_values, i, comp, owner_contrib) + if j >= 0: + wp.atomic_add(grad_values, j, comp, 0.5 * dot_go) + + +@wp.kernel +def _mesh_green_gauss_3d_backward_kernel( + points: wp.array2d(dtype=wp.float32), + cells: wp.array2d(dtype=wp.int32), + neighbors: wp.array2d(dtype=wp.int32), + grad_output: wp.array3d(dtype=wp.float32), + grad_values: wp.array2d(dtype=wp.float32), +): + i, f, comp = wp.tid() + j = neighbors[i, f] + coeff = _tetra_face_coeff(points, cells, i, f) + + dot_go = ( + grad_output[i, 0, comp] * coeff[0] + + grad_output[i, 1, comp] * coeff[1] + + grad_output[i, 2, comp] * coeff[2] + ) + owner_contrib = dot_go + if j >= 0: + owner_contrib = 0.5 * dot_go + + wp.atomic_add(grad_values, i, comp, owner_contrib) + if j >= 0: + wp.atomic_add(grad_values, j, comp, 0.5 * dot_go) + + +def _launch_forward( + *, + points_fp32: torch.Tensor, + cells_i32: torch.Tensor, + neighbors_i32: torch.Tensor, + values_flat_fp32: torch.Tensor, + grads_flat: torch.Tensor, + dims: int, + wp_device, + wp_stream, +) -> None: + ### Launch a single forward kernel across cells and value components. + kernel = ( + _mesh_green_gauss_2d_forward_kernel + if dims == 2 + else _mesh_green_gauss_3d_forward_kernel + ) + + n_cells = values_flat_fp32.shape[0] + n_components = values_flat_fp32.shape[1] + + with wp.ScopedStream(wp_stream): + wp.launch( + kernel=kernel, + dim=(n_cells, n_components), + inputs=[ + wp.from_torch(points_fp32, dtype=wp.float32), + wp.from_torch(cells_i32, dtype=wp.int32), + wp.from_torch(neighbors_i32, dtype=wp.int32), + wp.from_torch(values_flat_fp32, dtype=wp.float32), + wp.from_torch(grads_flat, dtype=wp.float32), + ], + device=wp_device, + stream=wp_stream, + ) + + +def _launch_backward( + *, + points_fp32: torch.Tensor, + cells_i32: torch.Tensor, + neighbors_i32: torch.Tensor, + grad_output_components_fp32: torch.Tensor, + grad_values_flat: torch.Tensor, + dims: int, + wp_device, + wp_stream, +) -> None: + ### Launch a single backward kernel across cells, faces, and value components. + kernel = ( + _mesh_green_gauss_2d_backward_kernel + if dims == 2 + else _mesh_green_gauss_3d_backward_kernel + ) + + n_cells = grad_output_components_fp32.shape[0] + n_faces = neighbors_i32.shape[1] + n_components = grad_output_components_fp32.shape[2] + + with wp.ScopedStream(wp_stream): + wp.launch( + kernel=kernel, + dim=(n_cells, n_faces, n_components), + inputs=[ + wp.from_torch(points_fp32, dtype=wp.float32), + wp.from_torch(cells_i32, dtype=wp.int32), + wp.from_torch(neighbors_i32, dtype=wp.int32), + wp.from_torch(grad_output_components_fp32, dtype=wp.float32), + wp.from_torch(grad_values_flat, dtype=wp.float32), + ], + device=wp_device, + stream=wp_stream, + ) + + +def _launch_backward_with_tape( + *, + points_fp32: torch.Tensor, + cells_i32: torch.Tensor, + neighbors_i32: torch.Tensor, + values_flat_fp32: torch.Tensor, + grad_output_components_fp32: torch.Tensor, + dims: int, + needs_points: bool, + needs_values: bool, + wp_device, + wp_stream, +) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """Run Warp Tape autodiff for gradients w.r.t. 
points and/or values.""" + kernel = ( + _mesh_green_gauss_2d_forward_kernel + if dims == 2 + else _mesh_green_gauss_3d_forward_kernel + ) + + n_cells = values_flat_fp32.shape[0] + n_components = values_flat_fp32.shape[1] + grads_flat = torch.empty( + (n_cells, dims, n_components), + device=values_flat_fp32.device, + dtype=torch.float32, + ) + + with wp.ScopedStream(wp_stream): + with wp.Tape() as tape: + wp_points = wp.from_torch( + points_fp32, dtype=wp.float32, requires_grad=needs_points + ) + wp_cells = wp.from_torch(cells_i32, dtype=wp.int32) + wp_neighbors = wp.from_torch(neighbors_i32, dtype=wp.int32) + wp_values = wp.from_torch( + values_flat_fp32, dtype=wp.float32, requires_grad=needs_values + ) + wp_grads = wp.from_torch(grads_flat, dtype=wp.float32, requires_grad=True) + + wp.launch( + kernel=kernel, + dim=(n_cells, n_components), + inputs=[ + wp_points, + wp_cells, + wp_neighbors, + wp_values, + wp_grads, + ], + device=wp_device, + stream=wp_stream, + ) + + grad_map = { + wp_grads: wp.from_torch(grad_output_components_fp32, dtype=wp.float32) + } + tape.backward(grads=grad_map) + + grad_points = None + grad_values = None + if needs_points: + grad_points = wp.to_torch(tape.gradients[wp_points]) + if needs_values: + grad_values = wp.to_torch(tape.gradients[wp_values]) + return grad_points, grad_values + + +@torch.library.custom_op( + "physicsnemo::mesh_green_gauss_gradient_warp_impl", mutates_args=() +) +def mesh_green_gauss_gradient_impl( + points: torch.Tensor, + cells: torch.Tensor, + neighbors: torch.Tensor, + values: torch.Tensor, +) -> torch.Tensor: + """Compute Green-Gauss cell-centered gradients with Warp kernels.""" + validate_inputs(points=points, cells=cells, neighbors=neighbors, values=values) + + points_fp32 = points.to(dtype=torch.float32).contiguous() + cells_i32 = cells.to(dtype=torch.int32).contiguous() + neighbors_i32 = neighbors.to(dtype=torch.int32).contiguous() + values_fp32 = values.to(dtype=torch.float32).contiguous() + + n_cells = values.shape[0] + dims = points.shape[1] + value_shape = values.shape[1:] + values_flat_fp32 = values_fp32.reshape(n_cells, -1).contiguous() + + grads_flat = torch.empty( + (n_cells, dims, values_flat_fp32.shape[1]), + device=values.device, + dtype=torch.float32, + ) + + wp_device, wp_stream = FunctionSpec.warp_launch_context(values_fp32) + _launch_forward( + points_fp32=points_fp32, + cells_i32=cells_i32, + neighbors_i32=neighbors_i32, + values_flat_fp32=values_flat_fp32, + grads_flat=grads_flat, + dims=dims, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + output = grads_flat.reshape(n_cells, dims, *value_shape) + if output.dtype != values.dtype: + output = output.to(dtype=values.dtype) + return output + + +@mesh_green_gauss_gradient_impl.register_fake +def _mesh_green_gauss_gradient_impl_fake( + points: torch.Tensor, + cells: torch.Tensor, + neighbors: torch.Tensor, + values: torch.Tensor, +) -> torch.Tensor: + """Fake tensor propagation for Green-Gauss custom op.""" + _ = (cells, neighbors) + dims = points.shape[1] + return torch.empty( + (values.shape[0], dims, *values.shape[1:]), + device=values.device, + dtype=values.dtype, + ) + + +def setup_mesh_green_gauss_gradient_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: torch.Tensor +) -> None: + """Store backward context for Green-Gauss custom-op autograd.""" + points, cells, neighbors, values = inputs + _ = output + values_fp32 = values.to(dtype=torch.float32).contiguous() + n_cells = values_fp32.shape[0] + ctx.save_for_backward( + 
points.to(dtype=torch.float32).contiguous(), + cells.to(dtype=torch.int32).contiguous(), + neighbors.to(dtype=torch.int32).contiguous(), + values_fp32.reshape(n_cells, -1).contiguous(), + ) + ctx.points_dtype = points.dtype + ctx.value_shape = values.shape + ctx.values_dtype = values.dtype + ctx.dims = points.shape[1] + + +def backward_mesh_green_gauss_gradient( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None, torch.Tensor | None]: + """Backward pass for Green-Gauss custom op.""" + needs_points = ctx.needs_input_grad[0] + needs_values = ctx.needs_input_grad[3] + if grad_output is None or (not needs_points and not needs_values): + return None, None, None, None + + points_fp32, cells_i32, neighbors_i32, values_flat_fp32 = ctx.saved_tensors + grad_output_fp32 = grad_output.to(dtype=torch.float32).contiguous() + + values_shape = ctx.value_shape + n_cells = values_shape[0] + value_shape = values_shape[1:] + n_components = math.prod(value_shape) if value_shape else 1 + + grad_output_components = grad_output_fp32.reshape(n_cells, ctx.dims, n_components) + grad_output_components = grad_output_components.contiguous() + + grad_points = None + grad_values_flat = torch.zeros( + (n_cells, n_components), + device=grad_output.device, + dtype=torch.float32, + ) + wp_device, wp_stream = FunctionSpec.warp_launch_context(grad_output_fp32) + if needs_points: + grad_points_fp32, grad_values_fp32 = _launch_backward_with_tape( + points_fp32=points_fp32, + cells_i32=cells_i32, + neighbors_i32=neighbors_i32, + values_flat_fp32=values_flat_fp32, + grad_output_components_fp32=grad_output_components, + dims=ctx.dims, + needs_points=needs_points, + needs_values=needs_values, + wp_device=wp_device, + wp_stream=wp_stream, + ) + if grad_points_fp32 is not None: + grad_points = grad_points_fp32 + if grad_points.dtype != ctx.points_dtype: + grad_points = grad_points.to(dtype=ctx.points_dtype) + if needs_values and grad_values_fp32 is not None: + grad_values_flat = grad_values_fp32 + elif needs_values: + _launch_backward( + points_fp32=points_fp32, + cells_i32=cells_i32, + neighbors_i32=neighbors_i32, + grad_output_components_fp32=grad_output_components, + grad_values_flat=grad_values_flat, + dims=ctx.dims, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + grad_values = None + if needs_values: + grad_values = grad_values_flat.reshape(values_shape) + if grad_values.dtype != ctx.values_dtype: + grad_values = grad_values.to(dtype=ctx.values_dtype) + return grad_points, None, None, grad_values + + +mesh_green_gauss_gradient_impl.register_autograd( + backward_mesh_green_gauss_gradient, + setup_context=setup_mesh_green_gauss_gradient_context, +) + + +def mesh_green_gauss_gradient_warp( + points: torch.Tensor, + cells: torch.Tensor, + neighbors: torch.Tensor, + values: torch.Tensor, +) -> torch.Tensor: + """Compute Green-Gauss cell gradients with Warp kernels. + + Notes + ----- + Warp kernels compute in ``float32`` internally. Inputs in wider floating + dtypes are accepted and cast to ``float32`` for compute. 
+ """ + return mesh_green_gauss_gradient_impl(points, cells, neighbors, values) diff --git a/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/mesh_green_gauss_gradient.py b/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/mesh_green_gauss_gradient.py new file mode 100644 index 0000000000..c248355d5e --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/mesh_green_gauss_gradient.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch + +from physicsnemo.core.function_spec import FunctionSpec + +from ._torch_impl import mesh_green_gauss_gradient_torch +from ._warp_impl import mesh_green_gauss_gradient_warp +from .utils import build_neighbors + + +class MeshGreenGaussGradient(FunctionSpec): + r"""Compute cell-centered gradients using Green-Gauss face flux balances. + + This functional reconstructs gradients from cell-centered values on + simplicial meshes (2D triangles or 3D tetrahedra) using: + + .. math:: + + \nabla \phi_i \approx \frac{1}{V_i} + \sum_{f \in \partial i} \phi_f \, \mathbf{A}_{i,f} + + where :math:`V_i` is cell volume/area, :math:`\mathbf{A}_{i,f}` is outward + face-area vector, and face value :math:`\phi_f` uses centered interpolation + on interior faces: + + .. math:: + + \phi_f = \tfrac{1}{2}(\phi_i + \phi_j) + + while boundary faces use :math:`\phi_f=\phi_i`. + + Parameters + ---------- + points : torch.Tensor + Mesh point coordinates with shape ``(n_points, dims)`` for ``dims`` in + ``{2, 3}``. + cells : torch.Tensor + Simplicial connectivity with shape ``(n_cells, dims+1)``. + neighbors : torch.Tensor + Precomputed cell-neighbor indices with shape ``(n_cells, n_faces)``, + where boundary faces are marked with ``-1``. + values : torch.Tensor + Cell-centered values with shape ``(n_cells,)`` or ``(n_cells, ...)``. + implementation : {"warp", "torch"} or None + Explicit backend selection. When ``None``, dispatch selects by rank. + + Returns + ------- + torch.Tensor + Reconstructed gradients with shape ``(n_cells, dims)`` for scalar + values or ``(n_cells, dims, ...)`` for tensor values. + """ + + ### Benchmark input presets (small -> large workload). + _BENCHMARK_CASES = ( + ("2d-tri-24x24-scalar", 24, 24, False), + ("2d-tri-36x36-scalar", 36, 36, False), + ("2d-tri-36x36-vector", 36, 36, True), + ) + + _COMPARE_ATOL = 3e-4 + _COMPARE_RTOL = 3e-4 + _COMPARE_BACKWARD_ATOL = 8e-3 + _COMPARE_BACKWARD_RTOL = 8e-3 + + @FunctionSpec.register(name="warp", required_imports=("warp>=0.6.0",), rank=0) + def warp_forward( + points: torch.Tensor, + cells: torch.Tensor, + neighbors: torch.Tensor, + values: torch.Tensor, + ) -> torch.Tensor: + """Dispatch Green-Gauss gradients to the Warp backend.""" + ### Warp backend implementation. 
+ return mesh_green_gauss_gradient_warp( + points=points, + cells=cells, + neighbors=neighbors, + values=values, + ) + + @FunctionSpec.register(name="torch", rank=1, baseline=True) + def torch_forward( + points: torch.Tensor, + cells: torch.Tensor, + neighbors: torch.Tensor, + values: torch.Tensor, + ) -> torch.Tensor: + """Dispatch Green-Gauss gradients to eager PyTorch.""" + ### PyTorch backend implementation. + return mesh_green_gauss_gradient_torch( + points=points, + cells=cells, + neighbors=neighbors, + values=values, + ) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield representative forward benchmark and parity input cases.""" + device = torch.device(device) + + ### Build deterministic triangulated meshes and scalar/vector cell values. + for label, nx, ny, vector_values in cls._BENCHMARK_CASES: + ### Construct a structured point set and split each quad into two triangles. + x = torch.linspace(0.0, 1.0, nx, device=device, dtype=torch.float32) + y = torch.linspace(0.0, 1.0, ny, device=device, dtype=torch.float32) + xx, yy = torch.meshgrid(x, y, indexing="ij") + points = torch.stack((xx.reshape(-1), yy.reshape(-1)), dim=-1) + + ### Add small deterministic interior jitter for non-degenerate geometry. + rng = torch.Generator(device=device) + rng.manual_seed(2027 + nx + ny) + perturb = 0.01 * ( + torch.rand( + points.shape, generator=rng, device=device, dtype=points.dtype + ) + - 0.5 + ) + border = ( + (points[:, 0] <= 0.0) + | (points[:, 0] >= 1.0) + | (points[:, 1] <= 0.0) + | (points[:, 1] >= 1.0) + ) + perturb[border] = 0.0 + points = torch.clamp(points + perturb, 0.0, 1.0).contiguous() + + cells = [] + for i in range(nx - 1): + for j in range(ny - 1): + p00 = i * ny + j + p10 = (i + 1) * ny + j + p01 = i * ny + (j + 1) + p11 = (i + 1) * ny + (j + 1) + cells.append((p00, p10, p11)) + cells.append((p00, p11, p01)) + cells = torch.tensor(cells, device=device, dtype=torch.int64).contiguous() + neighbors = build_neighbors(cells).to(dtype=torch.int64).contiguous() + centroids = points[cells].mean(dim=1) + base = ( + torch.sin(2.0 * torch.pi * centroids[:, 0]) + + 0.4 * torch.cos(2.0 * torch.pi * centroids[:, 1]) + ).to(torch.float32) + + if vector_values: + values = torch.stack( + ( + base, + torch.cos(2.0 * torch.pi * centroids[:, 0] - 0.2), + torch.sin(2.0 * torch.pi * centroids[:, 1] + 0.3), + ), + dim=-1, + ).to(torch.float32) + else: + values = base + + yield ( + label, + ( + points, + cells, + neighbors, + values, + ), + {}, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield representative backward benchmark and parity input cases.""" + device = torch.device(device) + + ### Build differentiable cell-value inputs for backward parity checks. + backward_cases = ( + ("backward-2d-tri-24x24-scalar", 24, 24, False), + ("backward-2d-tri-32x32-vector", 32, 32, True), + ) + + for label, nx, ny, vector_values in backward_cases: + ### Construct a structured point set and split each quad into two triangles. + x = torch.linspace(0.0, 1.0, nx, device=device, dtype=torch.float32) + y = torch.linspace(0.0, 1.0, ny, device=device, dtype=torch.float32) + xx, yy = torch.meshgrid(x, y, indexing="ij") + points = torch.stack((xx.reshape(-1), yy.reshape(-1)), dim=-1) + + ### Add small deterministic interior jitter for non-degenerate geometry. 
+ rng = torch.Generator(device=device) + rng.manual_seed(2027 + nx + ny) + perturb = 0.01 * ( + torch.rand( + points.shape, generator=rng, device=device, dtype=points.dtype + ) + - 0.5 + ) + border = ( + (points[:, 0] <= 0.0) + | (points[:, 0] >= 1.0) + | (points[:, 1] <= 0.0) + | (points[:, 1] >= 1.0) + ) + perturb[border] = 0.0 + points = torch.clamp(points + perturb, 0.0, 1.0).contiguous() + + cells = [] + for i in range(nx - 1): + for j in range(ny - 1): + p00 = i * ny + j + p10 = (i + 1) * ny + j + p01 = i * ny + (j + 1) + p11 = (i + 1) * ny + (j + 1) + cells.append((p00, p10, p11)) + cells.append((p00, p11, p01)) + cells = torch.tensor(cells, device=device, dtype=torch.int64).contiguous() + neighbors = build_neighbors(cells).to(dtype=torch.int64).contiguous() + centroids = points[cells].mean(dim=1) + base = ( + torch.sin(2.0 * torch.pi * centroids[:, 0]) + + 0.4 * torch.cos(2.0 * torch.pi * centroids[:, 1]) + ).to(torch.float32) + + if vector_values: + values = torch.stack( + ( + base, + torch.cos(2.0 * torch.pi * centroids[:, 0] - 0.2), + torch.sin(2.0 * torch.pi * centroids[:, 1] + 0.3), + ), + dim=-1, + ).to(torch.float32) + else: + values = base + + values = values.detach().clone().requires_grad_(True) + yield ( + label, + ( + points, + cells, + neighbors, + values, + ), + {}, + ) + + @classmethod + def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare forward outputs across implementations.""" + ### Validate forward parity across backends. + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_ATOL, + rtol=cls._COMPARE_RTOL, + ) + + @classmethod + def compare_backward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare backward gradients across implementations.""" + ### Validate backward parity across backends. + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_BACKWARD_ATOL, + rtol=cls._COMPARE_BACKWARD_RTOL, + ) + + +mesh_green_gauss_gradient = MeshGreenGaussGradient.make_function( + "mesh_green_gauss_gradient" +) + + +__all__ = ["MeshGreenGaussGradient", "mesh_green_gauss_gradient"] diff --git a/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/utils.py b/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/utils.py new file mode 100644 index 0000000000..8d8dd0b6f9 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/mesh_green_gauss_gradient/utils.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch + + +def validate_inputs( + points: torch.Tensor, + cells: torch.Tensor, + neighbors: torch.Tensor, + values: torch.Tensor, +) -> None: + """Validate mesh, adjacency, and value tensors for Green-Gauss reconstruction.""" + ### Validate mesh point coordinates and supported spatial dimensionality. 
+ if points.ndim != 2: + raise ValueError( + f"points must have shape (n_points, dims), got points.shape={points.shape}" + ) + if points.shape[1] not in (2, 3): + raise ValueError( + f"mesh_green_gauss_gradient supports dims in {{2, 3}}, got {points.shape[1]}" + ) + if not torch.is_floating_point(points): + raise TypeError("points must be a floating-point tensor") + + ### Validate simplicial connectivity and compatibility with spatial dims. + if cells.ndim != 2: + raise ValueError( + f"cells must have shape (n_cells, n_vertices), got cells.shape={cells.shape}" + ) + expected_vertices = points.shape[1] + 1 + if cells.shape[1] != expected_vertices: + raise ValueError( + f"cells must contain {expected_vertices} vertices per simplex for dims={points.shape[1]}, " + f"got {cells.shape[1]}" + ) + if cells.dtype not in (torch.int32, torch.int64): + raise TypeError("cells must be int32 or int64") + + ### Validate precomputed cell-neighbor tensor. + if neighbors.ndim != 2: + raise ValueError( + "neighbors must have shape (n_cells, n_faces), " + f"got neighbors.shape={neighbors.shape}" + ) + n_cells = cells.shape[0] + n_faces = cells.shape[1] + if neighbors.shape != (n_cells, n_faces): + raise ValueError( + "neighbors shape must match (n_cells, n_faces): " + f"expected ({n_cells}, {n_faces}), got {tuple(neighbors.shape)}" + ) + if neighbors.dtype not in (torch.int32, torch.int64): + raise TypeError("neighbors must be int32 or int64") + + ### Validate cell-centered values tensor. + if values.ndim < 1: + raise ValueError( + f"values must have shape (n_cells, ...), got values.shape={values.shape}" + ) + if values.shape[0] != n_cells: + raise ValueError( + f"values leading dimension must match n_cells: {values.shape[0]} != {n_cells}" + ) + if not torch.is_floating_point(values): + raise TypeError("values must be a floating-point tensor") + + ### Validate co-located tensors and index range invariants. + if ( + points.device != cells.device + or points.device != neighbors.device + or points.device != values.device + ): + raise ValueError( + "points, cells, neighbors, and values must be on the same device" + ) + if cells.numel() > 0: + idx_min = int(cells.min().item()) + idx_max = int(cells.max().item()) + if idx_min < 0 or idx_max >= points.shape[0]: + raise ValueError( + f"cells indices must satisfy 0 <= index < n_points ({points.shape[0]})" + ) + if neighbors.numel() > 0: + neigh_min = int(neighbors.min().item()) + neigh_max = int(neighbors.max().item()) + if neigh_min < -1 or neigh_max >= n_cells: + raise ValueError( + "neighbors entries must satisfy -1 <= index < n_cells " + f"({n_cells}); got [{neigh_min}, {neigh_max}]" + ) + + +def build_neighbors(cells: torch.Tensor) -> torch.Tensor: + """Build simplicial face-neighbor adjacency as ``(n_cells, n_faces)``. + + Face ``f`` corresponds to the simplex face opposite local vertex ``f``. + Boundary faces are marked with ``-1``. 
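+
+    The pairing is computed on the CPU by hashing sorted face-vertex tuples, so
+    construction cost scales with ``n_cells * n_vertices``.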
+ """ + if cells.ndim != 2: + raise ValueError( + f"cells must have shape (n_cells, n_vertices), got {cells.shape=}" + ) + if cells.dtype not in (torch.int32, torch.int64): + raise TypeError("cells must be int32 or int64") + + n_cells, n_vertices = cells.shape + neighbors = torch.full( + (n_cells, n_vertices), -1, device=cells.device, dtype=torch.int64 + ) + + open_faces: dict[tuple[int, ...], tuple[int, int]] = {} + cells_cpu = cells.to(dtype=torch.int64).detach().cpu().tolist() + + for cell_idx, cell in enumerate(cells_cpu): + for face_idx in range(n_vertices): + face_verts = tuple(int(cell[v]) for v in range(n_vertices) if v != face_idx) + key = tuple(sorted(face_verts)) + if key in open_faces: + other_cell, other_face = open_faces.pop(key) + neighbors[cell_idx, face_idx] = other_cell + neighbors[other_cell, other_face] = cell_idx + else: + open_faces[key] = (cell_idx, face_idx) + + return neighbors diff --git a/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/__init__.py b/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/__init__.py new file mode 100644 index 0000000000..95424c6ec2 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .mesh_lsq_gradient import MeshLSQGradient, mesh_lsq_gradient + +__all__ = ["MeshLSQGradient", "mesh_lsq_gradient"] diff --git a/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/_torch_impl.py b/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/_torch_impl.py new file mode 100644 index 0000000000..f5b33aaaf8 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/_torch_impl.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import torch + +from .utils import resolve_safe_epsilon, validate_inputs + + +def mesh_lsq_gradient_torch( + points: torch.Tensor, + values: torch.Tensor, + neighbor_offsets: torch.Tensor, + neighbor_indices: torch.Tensor, + weight_power: float = 2.0, + min_neighbors: int = 0, + safe_epsilon: float | None = None, +) -> torch.Tensor: + """Compute weighted LSQ mesh gradients with PyTorch tensor ops.""" + ### Validate inputs before building LSQ systems. + validate_inputs( + points=points, + values=values, + neighbor_offsets=neighbor_offsets, + neighbor_indices=neighbor_indices, + min_neighbors=min_neighbors, + ) + + ### Normalize dtypes/layout for stable downstream linear algebra. + points = points.contiguous() + values = values.contiguous() + neighbor_offsets = neighbor_offsets.to( + dtype=torch.int64, device=points.device + ).contiguous() + neighbor_indices = neighbor_indices.to( + dtype=torch.int64, device=points.device + ).contiguous() + + n_entities = points.shape[0] + n_dims = points.shape[1] + value_shape = values.shape[1:] + counts = neighbor_offsets[1:] - neighbor_offsets[:-1] + + ### Flatten component dimensions so scalar and tensor fields share one solve path. + values_flat = values.reshape(n_entities, -1) + n_components = values_flat.shape[1] + gradients_flat = torch.zeros( + (n_entities, n_dims, n_components), + dtype=values.dtype, + device=values.device, + ) + + points_cast = points.to(dtype=values.dtype) + dist_eps = resolve_safe_epsilon(safe_epsilon=safe_epsilon, dtype=points_cast.dtype) + + ### Process one dense batch per neighbor-count group (mesh-module strategy). + unique_counts = torch.unique(counts) + for count_tensor in unique_counts: + n_neighbors = int(count_tensor.item()) + if n_neighbors < min_neighbors or n_neighbors == 0: + continue + + entity_indices = torch.where(counts == count_tensor)[0] + if entity_indices.numel() == 0: + continue + + offsets_group = neighbor_offsets[entity_indices] + col_range = torch.arange(n_neighbors, device=points.device, dtype=torch.int64) + flat_indices = offsets_group.unsqueeze(1) + col_range.unsqueeze(0) + neighbors = neighbor_indices[flat_indices].to(torch.long) + + center_points = points_cast[entity_indices] + relative = points_cast[neighbors] - center_points.unsqueeze(1) + + values_center = values_flat[entity_indices] + delta_values = values_flat[neighbors] - values_center.unsqueeze(1) + + dist2 = (relative * relative).sum(dim=-1).clamp_min(dist_eps) + sqrt_w = dist2.pow(-0.25 * weight_power).unsqueeze(-1) + + A_weighted = sqrt_w * relative + b_weighted = sqrt_w * delta_values + + solution = torch.linalg.lstsq( + A_weighted, + b_weighted, + rcond=None, + ).solution + gradients_flat[entity_indices] = solution + + ### Restore gradient output shape. + gradients = gradients_flat.reshape(n_entities, n_dims, *value_shape) + + return gradients diff --git a/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/_warp_impl.py b/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/_warp_impl.py new file mode 100644 index 0000000000..64b58e421f --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/_warp_impl.py @@ -0,0 +1,872 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import warp as wp + +from physicsnemo.core.function_spec import FunctionSpec + +from .utils import resolve_safe_epsilon, validate_inputs + +### Warp runtime initialization for custom kernels. +wp.init() +wp.config.quiet = True + + +@wp.kernel +def _mesh_lsq_gradient_1d_kernel( + points: wp.array2d(dtype=wp.float32), + values: wp.array(dtype=wp.float32), + offsets: wp.array(dtype=wp.int32), + indices: wp.array(dtype=wp.int32), + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, + gradients: wp.array2d(dtype=wp.float32), +): + i = wp.tid() + + # Read the CSR neighbor segment for this entity. + start = offsets[i] + end = offsets[i + 1] + count = end - start + if count < min_neighbors: + gradients[i, 0] = 0.0 + return + + # Gather center state and initialize normal-equation accumulators. + px = points[i, 0] + pval = values[i] + + m00 = float(reg_eps) + b0 = float(0.0) + + # Accumulate A^T W A and A^T W b over neighbors. + for p in range(start, end): + n = indices[p] + dx = points[n, 0] - px + dphi = values[n] - pval + + dist2 = dx * dx + dist_eps + w = wp.pow(dist2, -0.5 * weight_power) + + m00 = m00 + w * dx * dx + b0 = b0 + w * dx * dphi + + # Solve the 1x1 normal equation with a numerical floor. + gx = float(0.0) + if m00 > dist_eps: + gx = b0 / m00 + + gradients[i, 0] = gx + + +@wp.kernel +def _mesh_lsq_gradient_2d_kernel( + points: wp.array2d(dtype=wp.float32), + values: wp.array(dtype=wp.float32), + offsets: wp.array(dtype=wp.int32), + indices: wp.array(dtype=wp.int32), + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, + gradients: wp.array2d(dtype=wp.float32), +): + i = wp.tid() + + # Read the CSR neighbor segment for this entity. + start = offsets[i] + end = offsets[i + 1] + count = end - start + if count < min_neighbors: + gradients[i, 0] = 0.0 + gradients[i, 1] = 0.0 + return + + # Gather center state and initialize normal-equation accumulators. + px = points[i, 0] + py = points[i, 1] + pval = values[i] + + m00 = float(reg_eps) + m01 = float(0.0) + m11 = float(reg_eps) + b0 = float(0.0) + b1 = float(0.0) + + # Accumulate A^T W A and A^T W b over neighbors. + for p in range(start, end): + n = indices[p] + dx = points[n, 0] - px + dy = points[n, 1] - py + dphi = values[n] - pval + + dist2 = dx * dx + dy * dy + dist_eps + w = wp.pow(dist2, -0.5 * weight_power) + + m00 = m00 + w * dx * dx + m01 = m01 + w * dx * dy + m11 = m11 + w * dy * dy + + b0 = b0 + w * dx * dphi + b1 = b1 + w * dy * dphi + + # Solve the 2x2 system analytically with determinant-based conditioning. 
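+ # For the symmetric normal matrix M = [[m00, m01], [m01, m11]], the closed-form
+ # inverse is adj(M) / det with det = m00 * m11 - m01**2, so g = M^{-1} b.
+ # The guard below compares |det| against a scale on the order of m00 * m11 to
+ # skip near-singular (e.g. collinear-neighbor) systems.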
+ det = m00 * m11 - m01 * m01 + + gx = float(0.0) + gy = float(0.0) + stability_scale = m00 * m11 + dist_eps + if wp.abs(det) > 1.0e-6 * stability_scale: + inv00 = m11 / det + inv01 = -m01 / det + inv11 = m00 / det + gx = inv00 * b0 + inv01 * b1 + gy = inv01 * b0 + inv11 * b1 + + gradients[i, 0] = gx + gradients[i, 1] = gy + + +@wp.kernel +def _mesh_lsq_gradient_3d_kernel( + points: wp.array2d(dtype=wp.float32), + values: wp.array(dtype=wp.float32), + offsets: wp.array(dtype=wp.int32), + indices: wp.array(dtype=wp.int32), + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, + gradients: wp.array2d(dtype=wp.float32), +): + i = wp.tid() + + # Read the CSR neighbor segment for this entity. + start = offsets[i] + end = offsets[i + 1] + count = end - start + if count < min_neighbors: + gradients[i, 0] = 0.0 + gradients[i, 1] = 0.0 + gradients[i, 2] = 0.0 + return + + # Gather center state and initialize normal-equation accumulators. + px = points[i, 0] + py = points[i, 1] + pz = points[i, 2] + pval = values[i] + + m00 = float(reg_eps) + m01 = float(0.0) + m02 = float(0.0) + m11 = float(reg_eps) + m12 = float(0.0) + m22 = float(reg_eps) + + b0 = float(0.0) + b1 = float(0.0) + b2 = float(0.0) + + # Accumulate A^T W A and A^T W b over neighbors. + for p in range(start, end): + n = indices[p] + dx = points[n, 0] - px + dy = points[n, 1] - py + dz = points[n, 2] - pz + dphi = values[n] - pval + + dist2 = dx * dx + dy * dy + dz * dz + dist_eps + w = wp.pow(dist2, -0.5 * weight_power) + + m00 = m00 + w * dx * dx + m01 = m01 + w * dx * dy + m02 = m02 + w * dx * dz + m11 = m11 + w * dy * dy + m12 = m12 + w * dy * dz + m22 = m22 + w * dz * dz + + b0 = b0 + w * dx * dphi + b1 = b1 + w * dy * dphi + b2 = b2 + w * dz * dphi + + # Build cofactors and solve the 3x3 system analytically. 
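+ # c00..c22 are cofactors of the symmetric normal matrix M, so adj(M) / det is
+ # M^{-1}; det is the first-row cofactor expansion m00*c00 + m01*c01 + m02*c02.
+ # The determinant guard is scaled by trace(M)**3 so the cutoff tracks the
+ # magnitude of the accumulated weights rather than a fixed absolute threshold.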
+ c00 = m11 * m22 - m12 * m12 + c01 = -(m01 * m22 - m12 * m02) + c02 = m01 * m12 - m11 * m02 + c11 = m00 * m22 - m02 * m02 + c12 = -(m00 * m12 - m01 * m02) + c22 = m00 * m11 - m01 * m01 + + det = m00 * c00 + m01 * c01 + m02 * c02 + + gx = float(0.0) + gy = float(0.0) + gz = float(0.0) + trace = m00 + m11 + m22 + stability_scale = trace * trace * trace + dist_eps + if wp.abs(det) > 1.0e-8 * stability_scale: + inv_det = 1.0 / det + inv00 = c00 * inv_det + inv01 = c01 * inv_det + inv02 = c02 * inv_det + inv11 = c11 * inv_det + inv12 = c12 * inv_det + inv22 = c22 * inv_det + + gx = inv00 * b0 + inv01 * b1 + inv02 * b2 + gy = inv01 * b0 + inv11 * b1 + inv12 * b2 + gz = inv02 * b0 + inv12 * b1 + inv22 * b2 + + gradients[i, 0] = gx + gradients[i, 1] = gy + gradients[i, 2] = gz + + +@wp.kernel +def _mesh_lsq_gradient_1d_backward_kernel( + points: wp.array2d(dtype=wp.float32), + offsets: wp.array(dtype=wp.int32), + indices: wp.array(dtype=wp.int32), + grad_output: wp.array2d(dtype=wp.float32), + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, + grad_values: wp.array(dtype=wp.float32), +): + i = wp.tid() + + start = offsets[i] + end = offsets[i + 1] + count = end - start + if count < min_neighbors: + return + + px = points[i, 0] + m00 = float(reg_eps) + for p in range(start, end): + n = indices[p] + dx = points[n, 0] - px + dist2 = dx * dx + dist_eps + w = wp.pow(dist2, -0.5 * weight_power) + m00 = m00 + w * dx * dx + + p0 = float(0.0) + if m00 > dist_eps: + p0 = grad_output[i, 0] / m00 + + self_contrib = float(0.0) + for p in range(start, end): + n = indices[p] + dx = points[n, 0] - px + dist2 = dx * dx + dist_eps + w = wp.pow(dist2, -0.5 * weight_power) + c = w * p0 * dx + wp.atomic_add(grad_values, n, c) + self_contrib = self_contrib - c + + wp.atomic_add(grad_values, i, self_contrib) + + +@wp.kernel +def _mesh_lsq_gradient_2d_backward_kernel( + points: wp.array2d(dtype=wp.float32), + offsets: wp.array(dtype=wp.int32), + indices: wp.array(dtype=wp.int32), + grad_output: wp.array2d(dtype=wp.float32), + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, + grad_values: wp.array(dtype=wp.float32), +): + i = wp.tid() + + start = offsets[i] + end = offsets[i + 1] + count = end - start + if count < min_neighbors: + return + + px = points[i, 0] + py = points[i, 1] + m00 = float(reg_eps) + m01 = float(0.0) + m11 = float(reg_eps) + for p in range(start, end): + n = indices[p] + dx = points[n, 0] - px + dy = points[n, 1] - py + dist2 = dx * dx + dy * dy + dist_eps + w = wp.pow(dist2, -0.5 * weight_power) + m00 = m00 + w * dx * dx + m01 = m01 + w * dx * dy + m11 = m11 + w * dy * dy + + p0 = float(0.0) + p1 = float(0.0) + det = m00 * m11 - m01 * m01 + stability_scale = m00 * m11 + dist_eps + if wp.abs(det) > 1.0e-6 * stability_scale: + inv00 = m11 / det + inv01 = -m01 / det + inv11 = m00 / det + go0 = grad_output[i, 0] + go1 = grad_output[i, 1] + p0 = inv00 * go0 + inv01 * go1 + p1 = inv01 * go0 + inv11 * go1 + + self_contrib = float(0.0) + for p in range(start, end): + n = indices[p] + dx = points[n, 0] - px + dy = points[n, 1] - py + dist2 = dx * dx + dy * dy + dist_eps + w = wp.pow(dist2, -0.5 * weight_power) + c = w * (p0 * dx + p1 * dy) + wp.atomic_add(grad_values, n, c) + self_contrib = self_contrib - c + + wp.atomic_add(grad_values, i, self_contrib) + + +@wp.kernel +def _mesh_lsq_gradient_3d_backward_kernel( + points: wp.array2d(dtype=wp.float32), + offsets: wp.array(dtype=wp.int32), + indices: wp.array(dtype=wp.int32), + grad_output: 
wp.array2d(dtype=wp.float32), + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, + grad_values: wp.array(dtype=wp.float32), +): + i = wp.tid() + + start = offsets[i] + end = offsets[i + 1] + count = end - start + if count < min_neighbors: + return + + px = points[i, 0] + py = points[i, 1] + pz = points[i, 2] + m00 = float(reg_eps) + m01 = float(0.0) + m02 = float(0.0) + m11 = float(reg_eps) + m12 = float(0.0) + m22 = float(reg_eps) + for p in range(start, end): + n = indices[p] + dx = points[n, 0] - px + dy = points[n, 1] - py + dz = points[n, 2] - pz + dist2 = dx * dx + dy * dy + dz * dz + dist_eps + w = wp.pow(dist2, -0.5 * weight_power) + m00 = m00 + w * dx * dx + m01 = m01 + w * dx * dy + m02 = m02 + w * dx * dz + m11 = m11 + w * dy * dy + m12 = m12 + w * dy * dz + m22 = m22 + w * dz * dz + + c00 = m11 * m22 - m12 * m12 + c01 = -(m01 * m22 - m12 * m02) + c02 = m01 * m12 - m11 * m02 + c11 = m00 * m22 - m02 * m02 + c12 = -(m00 * m12 - m01 * m02) + c22 = m00 * m11 - m01 * m01 + det = m00 * c00 + m01 * c01 + m02 * c02 + + p0 = float(0.0) + p1 = float(0.0) + p2 = float(0.0) + trace = m00 + m11 + m22 + stability_scale = trace * trace * trace + dist_eps + if wp.abs(det) > 1.0e-8 * stability_scale: + inv_det = 1.0 / det + inv00 = c00 * inv_det + inv01 = c01 * inv_det + inv02 = c02 * inv_det + inv11 = c11 * inv_det + inv12 = c12 * inv_det + inv22 = c22 * inv_det + go0 = grad_output[i, 0] + go1 = grad_output[i, 1] + go2 = grad_output[i, 2] + p0 = inv00 * go0 + inv01 * go1 + inv02 * go2 + p1 = inv01 * go0 + inv11 * go1 + inv12 * go2 + p2 = inv02 * go0 + inv12 * go1 + inv22 * go2 + + self_contrib = float(0.0) + for p in range(start, end): + n = indices[p] + dx = points[n, 0] - px + dy = points[n, 1] - py + dz = points[n, 2] - pz + dist2 = dx * dx + dy * dy + dz * dz + dist_eps + w = wp.pow(dist2, -0.5 * weight_power) + c = w * (p0 * dx + p1 * dy + p2 * dz) + wp.atomic_add(grad_values, n, c) + self_contrib = self_contrib - c + + wp.atomic_add(grad_values, i, self_contrib) + + +def _launch_forward( + *, + points_fp32: torch.Tensor, + values_flat_fp32: torch.Tensor, + offsets_i32: torch.Tensor, + indices_i32: torch.Tensor, + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, + grads_components: torch.Tensor, + wp_device, + wp_stream, +) -> None: + ### Launch one LSQ forward kernel per value component. 
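+ # The kernel is specialized per spatial dimension (1D/2D/3D); geometry arrays
+ # are wrapped once and reused, while each value component gets its own launch
+ # against the same CSR adjacency.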
+ n_dims = points_fp32.shape[1] + n_entities = points_fp32.shape[0] + n_components = values_flat_fp32.shape[1] + kernel = ( + _mesh_lsq_gradient_1d_kernel + if n_dims == 1 + else _mesh_lsq_gradient_2d_kernel + if n_dims == 2 + else _mesh_lsq_gradient_3d_kernel + ) + + wp_points = wp.from_torch(points_fp32, dtype=wp.float32) + wp_offsets = wp.from_torch(offsets_i32, dtype=wp.int32) + wp_indices = wp.from_torch(indices_i32, dtype=wp.int32) + + with wp.ScopedStream(wp_stream): + for comp in range(n_components): + wp.launch( + kernel=kernel, + dim=n_entities, + inputs=[ + wp_points, + wp.from_torch( + values_flat_fp32[:, comp].contiguous(), dtype=wp.float32 + ), + wp_offsets, + wp_indices, + float(weight_power), + int(min_neighbors), + float(reg_eps), + float(dist_eps), + wp.from_torch(grads_components[comp], dtype=wp.float32), + ], + device=wp_device, + stream=wp_stream, + ) + + +def _launch_backward( + *, + points_fp32: torch.Tensor, + offsets_i32: torch.Tensor, + indices_i32: torch.Tensor, + grad_output_components_fp32: torch.Tensor, + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, + grad_values_flat: torch.Tensor, + wp_device, + wp_stream, +) -> None: + ### Launch one LSQ backward kernel per value component. + n_dims = points_fp32.shape[1] + n_entities = points_fp32.shape[0] + n_components = grad_output_components_fp32.shape[0] + kernel = ( + _mesh_lsq_gradient_1d_backward_kernel + if n_dims == 1 + else _mesh_lsq_gradient_2d_backward_kernel + if n_dims == 2 + else _mesh_lsq_gradient_3d_backward_kernel + ) + + wp_points = wp.from_torch(points_fp32, dtype=wp.float32) + wp_offsets = wp.from_torch(offsets_i32, dtype=wp.int32) + wp_indices = wp.from_torch(indices_i32, dtype=wp.int32) + + with wp.ScopedStream(wp_stream): + for comp in range(n_components): + comp_grad_values = torch.zeros( + (n_entities,), + device=grad_values_flat.device, + dtype=torch.float32, + ) + wp.launch( + kernel=kernel, + dim=n_entities, + inputs=[ + wp_points, + wp_offsets, + wp_indices, + wp.from_torch(grad_output_components_fp32[comp], dtype=wp.float32), + float(weight_power), + int(min_neighbors), + float(reg_eps), + float(dist_eps), + wp.from_torch(comp_grad_values, dtype=wp.float32), + ], + device=wp_device, + stream=wp_stream, + ) + grad_values_flat[:, comp] = comp_grad_values + + +def _launch_backward_with_tape( + *, + points_fp32: torch.Tensor, + values_flat_fp32: torch.Tensor, + offsets_i32: torch.Tensor, + indices_i32: torch.Tensor, + grad_output_components_fp32: torch.Tensor, + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, + needs_points: bool, + needs_values: bool, + wp_device, + wp_stream, +) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """Run Warp Tape autodiff for gradients w.r.t. 
points and/or values.""" + n_dims = points_fp32.shape[1] + n_entities = points_fp32.shape[0] + n_components = values_flat_fp32.shape[1] + kernel = ( + _mesh_lsq_gradient_1d_kernel + if n_dims == 1 + else _mesh_lsq_gradient_2d_kernel + if n_dims == 2 + else _mesh_lsq_gradient_3d_kernel + ) + + values_components = values_flat_fp32.transpose(0, 1).contiguous() + grads_components = torch.empty( + (n_components, n_entities, n_dims), + dtype=torch.float32, + device=points_fp32.device, + ) + + with wp.ScopedStream(wp_stream): + with wp.Tape() as tape: + wp_points = wp.from_torch( + points_fp32, dtype=wp.float32, requires_grad=needs_points + ) + wp_offsets = wp.from_torch(offsets_i32, dtype=wp.int32) + wp_indices = wp.from_torch(indices_i32, dtype=wp.int32) + wp_values_components: list[wp.array] = [] + wp_grads_components: list[wp.array] = [] + for comp in range(n_components): + wp_values = wp.from_torch( + values_components[comp], + dtype=wp.float32, + requires_grad=needs_values, + ) + wp_grads = wp.from_torch( + grads_components[comp], dtype=wp.float32, requires_grad=True + ) + wp.launch( + kernel=kernel, + dim=n_entities, + inputs=[ + wp_points, + wp_values, + wp_offsets, + wp_indices, + float(weight_power), + int(min_neighbors), + float(reg_eps), + float(dist_eps), + wp_grads, + ], + device=wp_device, + stream=wp_stream, + ) + wp_values_components.append(wp_values) + wp_grads_components.append(wp_grads) + + grad_map = { + wp_grads_components[comp]: wp.from_torch( + grad_output_components_fp32[comp].contiguous(), dtype=wp.float32 + ) + for comp in range(n_components) + } + tape.backward(grads=grad_map) + + grad_points = None + grad_values = None + if needs_points: + grad_points = wp.to_torch(tape.gradients[wp_points]) + if needs_values: + grad_values_components = torch.empty_like(values_components) + for comp in range(n_components): + grad_values_components[comp] = wp.to_torch( + tape.gradients[wp_values_components[comp]] + ) + grad_values = grad_values_components.transpose(0, 1).contiguous() + return grad_points, grad_values + + +@torch.library.custom_op("physicsnemo::mesh_lsq_gradient_warp_impl", mutates_args=()) +def mesh_lsq_gradient_impl( + points: torch.Tensor, + values: torch.Tensor, + neighbor_offsets: torch.Tensor, + neighbor_indices: torch.Tensor, + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, +) -> torch.Tensor: + """Compute weighted LSQ gradients with Warp kernels.""" + validate_inputs( + points=points, + values=values, + neighbor_offsets=neighbor_offsets, + neighbor_indices=neighbor_indices, + min_neighbors=int(min_neighbors), + ) + points_fp32 = points.to(dtype=torch.float32).contiguous() + values_fp32 = values.to(dtype=torch.float32).contiguous() + offsets_i32 = neighbor_offsets.to( + dtype=torch.int32, device=points.device + ).contiguous() + indices_i32 = neighbor_indices.to( + dtype=torch.int32, device=points.device + ).contiguous() + + n_entities = points_fp32.shape[0] + n_dims = points_fp32.shape[1] + value_shape = values.shape[1:] + values_flat = values_fp32.reshape(n_entities, -1) + n_components = values_flat.shape[1] + + ### Store component-wise output as (C, N, dims) for contiguous warp writes. 
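+ # Each per-component launch writes into a contiguous (n_entities, dims) slice;
+ # the permute/reshape after the launches restores (n_entities, dims, *value_shape).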
+ grads_components = torch.empty( + (n_components, n_entities, n_dims), + dtype=torch.float32, + device=points.device, + ) + + wp_device, wp_stream = FunctionSpec.warp_launch_context(points_fp32) + _launch_forward( + points_fp32=points_fp32, + values_flat_fp32=values_flat, + offsets_i32=offsets_i32, + indices_i32=indices_i32, + weight_power=float(weight_power), + min_neighbors=int(min_neighbors), + reg_eps=float(reg_eps), + dist_eps=float(dist_eps), + grads_components=grads_components, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + output = grads_components.permute(1, 2, 0).reshape(n_entities, n_dims, *value_shape) + if output.dtype != values.dtype: + output = output.to(dtype=values.dtype) + return output + + +@mesh_lsq_gradient_impl.register_fake +def _mesh_lsq_gradient_impl_fake( + points: torch.Tensor, + values: torch.Tensor, + neighbor_offsets: torch.Tensor, + neighbor_indices: torch.Tensor, + weight_power: float, + min_neighbors: int, + reg_eps: float, + dist_eps: float, +) -> torch.Tensor: + """Fake tensor propagation for LSQ custom op.""" + _ = ( + neighbor_offsets, + neighbor_indices, + weight_power, + min_neighbors, + reg_eps, + dist_eps, + ) + return torch.empty( + (values.shape[0], points.shape[1], *values.shape[1:]), + device=values.device, + dtype=values.dtype, + ) + + +def setup_mesh_lsq_gradient_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: torch.Tensor +) -> None: + """Store backward context for LSQ custom-op autograd.""" + ( + points, + values, + neighbor_offsets, + neighbor_indices, + weight_power, + min_neighbors, + reg_eps, + dist_eps, + ) = inputs + _ = output + values_fp32 = values.to(dtype=torch.float32).contiguous() + n_entities = values_fp32.shape[0] + ctx.save_for_backward( + points.to(dtype=torch.float32).contiguous(), + values_fp32.reshape(n_entities, -1).contiguous(), + neighbor_offsets.to(dtype=torch.int32, device=points.device).contiguous(), + neighbor_indices.to(dtype=torch.int32, device=points.device).contiguous(), + ) + ctx.points_dtype = points.dtype + ctx.value_shape = values.shape + ctx.values_dtype = values.dtype + ctx.weight_power = float(weight_power) + ctx.min_neighbors = int(min_neighbors) + ctx.reg_eps = float(reg_eps) + ctx.dist_eps = float(dist_eps) + + +def backward_mesh_lsq_gradient( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[ + torch.Tensor | None, torch.Tensor | None, None, None, None, None, None, None +]: + """Backward pass for LSQ custom op.""" + needs_points = ctx.needs_input_grad[0] + needs_values = ctx.needs_input_grad[1] + if grad_output is None or (not needs_points and not needs_values): + return None, None, None, None, None, None, None, None + + points_fp32, values_flat_fp32, offsets_i32, indices_i32 = ctx.saved_tensors + grad_output_fp32 = grad_output.to(dtype=torch.float32).contiguous() + values_shape = ctx.value_shape + n_entities = values_shape[0] + value_shape = values_shape[1:] + n_components = int(torch.tensor(value_shape).prod().item()) if value_shape else 1 + + grad_output_components = grad_output_fp32.reshape( + n_entities, grad_output_fp32.shape[1], n_components + ) + grad_output_components = grad_output_components.permute(2, 0, 1).contiguous() + grad_points = None + grad_values_flat = torch.empty( + (n_entities, n_components), + device=grad_output.device, + dtype=torch.float32, + ) + wp_device, wp_stream = FunctionSpec.warp_launch_context(grad_output_fp32) + if needs_points: + grad_points_fp32, grad_values_fp32 = _launch_backward_with_tape( + 
points_fp32=points_fp32, + values_flat_fp32=values_flat_fp32, + offsets_i32=offsets_i32, + indices_i32=indices_i32, + grad_output_components_fp32=grad_output_components, + weight_power=ctx.weight_power, + min_neighbors=ctx.min_neighbors, + reg_eps=ctx.reg_eps, + dist_eps=ctx.dist_eps, + needs_points=needs_points, + needs_values=needs_values, + wp_device=wp_device, + wp_stream=wp_stream, + ) + if grad_points_fp32 is not None: + grad_points = grad_points_fp32 + if grad_points.dtype != ctx.points_dtype: + grad_points = grad_points.to(dtype=ctx.points_dtype) + if needs_values and grad_values_fp32 is not None: + grad_values_flat = grad_values_fp32 + elif needs_values: + _launch_backward( + points_fp32=points_fp32, + offsets_i32=offsets_i32, + indices_i32=indices_i32, + grad_output_components_fp32=grad_output_components, + weight_power=ctx.weight_power, + min_neighbors=ctx.min_neighbors, + reg_eps=ctx.reg_eps, + dist_eps=ctx.dist_eps, + grad_values_flat=grad_values_flat, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + grad_values = None + if needs_values: + grad_values = grad_values_flat.reshape(values_shape) + if grad_values.dtype != ctx.values_dtype: + grad_values = grad_values.to(dtype=ctx.values_dtype) + return grad_points, grad_values, None, None, None, None, None, None + + +mesh_lsq_gradient_impl.register_autograd( + backward_mesh_lsq_gradient, + setup_context=setup_mesh_lsq_gradient_context, +) + + +def mesh_lsq_gradient_warp( + points: torch.Tensor, + values: torch.Tensor, + neighbor_offsets: torch.Tensor, + neighbor_indices: torch.Tensor, + weight_power: float = 2.0, + min_neighbors: int = 0, + reg_eps: float = 1.0e-6, + safe_epsilon: float | None = None, +) -> torch.Tensor: + """Compute weighted LSQ mesh gradients with Warp kernels. + + Notes + ----- + Warp kernels compute in ``float32`` internally. Inputs in wider floating + dtypes are accepted and cast to ``float32`` for compute. + """ + dist_eps = resolve_safe_epsilon(safe_epsilon=safe_epsilon, dtype=torch.float32) + return mesh_lsq_gradient_impl( + points, + values, + neighbor_offsets, + neighbor_indices, + float(weight_power), + int(min_neighbors), + float(reg_eps), + float(dist_eps), + ) diff --git a/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/mesh_lsq_gradient.py b/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/mesh_lsq_gradient.py new file mode 100644 index 0000000000..8fb6c1de71 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/mesh_lsq_gradient.py @@ -0,0 +1,290 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
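+
+# Illustrative usage (a minimal sketch; variable names and the adjacency source,
+# e.g. a KNN search, are placeholders rather than part of this module):
+#
+#     grads = mesh_lsq_gradient(points, values, neighbor_offsets, neighbor_indices)
+#     # grads: (n_entities, dims) for scalar values, (n_entities, dims, ...) otherwise.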
+ +from __future__ import annotations + +import torch + +from physicsnemo.core.function_spec import FunctionSpec + +from ._torch_impl import mesh_lsq_gradient_torch +from ._warp_impl import mesh_lsq_gradient_warp + + +class MeshLSQGradient(FunctionSpec): + r"""Weighted least-squares gradient reconstruction on unstructured entities. + + This functional computes gradients from unstructured neighborhoods provided + as CSR adjacency (`neighbor_offsets`, `neighbor_indices`). + + For each entity :math:`i`, it solves the weighted least-squares problem: + + .. math:: + + \nabla \phi_i = \arg\min_g + \sum_{j \in \mathcal{N}(i)} w_{ij} \left(g^T(x_j - x_i) - (\phi_j - \phi_i)\right)^2 + + with inverse-distance weighting: + + .. math:: + + w_{ij} = ||x_j - x_i||^{-\alpha} + + where :math:`\alpha` is ``weight_power``. + + Parameters + ---------- + points : torch.Tensor + Entity coordinates with shape ``(n_entities, dims)``. + values : torch.Tensor + Scalar or tensor values with shape ``(n_entities,)`` or + ``(n_entities, ...)``. + neighbor_offsets : torch.Tensor + CSR offsets with shape ``(n_entities + 1,)``. + neighbor_indices : torch.Tensor + CSR flattened neighbor indices with shape ``(nnz,)``. + weight_power : float, optional + Inverse-distance exponent used for weighting. + min_neighbors : int, optional + Entities with fewer than this count get zero gradients. + safe_epsilon : float | None, optional + Positive floor applied to squared neighbor distances before + inverse-distance weighting. When ``None``, a dtype-derived default + is used by each backend. + implementation : {"warp", "torch"} or None + Explicit backend selection. When ``None``, dispatch selects by rank. + + Returns + ------- + torch.Tensor + Gradients with shape ``(n_entities, dims)`` for scalar values or + ``(n_entities, dims, ...)`` for tensor values. + """ + + ### Benchmark input presets (small -> large workload). + _BENCHMARK_CASES = ( + ("small-1d-scalar-n2048-k16", 2048, 1, 16, False), + ("small-2d-scalar-n1024-k16", 1024, 2, 16, False), + ("medium-3d-scalar-n2048-k16", 2048, 3, 16, False), + ("medium-3d-vector-n2048-k16", 2048, 3, 16, True), + ) + + _COMPARE_ATOL = 5e-3 + _COMPARE_RTOL = 5e-3 + _COMPARE_BACKWARD_ATOL = 8e-3 + _COMPARE_BACKWARD_RTOL = 8e-3 + + @FunctionSpec.register(name="warp", required_imports=("warp>=0.6.0",), rank=0) + def warp_forward( + points: torch.Tensor, + values: torch.Tensor, + neighbor_offsets: torch.Tensor, + neighbor_indices: torch.Tensor, + weight_power: float = 2.0, + min_neighbors: int = 0, + safe_epsilon: float | None = None, + ) -> torch.Tensor: + """Dispatch mesh LSQ gradients to the Warp backend.""" + ### Warp backend implementation. + return mesh_lsq_gradient_warp( + points=points, + values=values, + neighbor_offsets=neighbor_offsets, + neighbor_indices=neighbor_indices, + weight_power=weight_power, + min_neighbors=min_neighbors, + safe_epsilon=safe_epsilon, + ) + + @FunctionSpec.register(name="torch", rank=1, baseline=True) + def torch_forward( + points: torch.Tensor, + values: torch.Tensor, + neighbor_offsets: torch.Tensor, + neighbor_indices: torch.Tensor, + weight_power: float = 2.0, + min_neighbors: int = 0, + safe_epsilon: float | None = None, + ) -> torch.Tensor: + """Dispatch mesh LSQ gradients to eager PyTorch.""" + ### PyTorch backend implementation. 
+ return mesh_lsq_gradient_torch( + points=points, + values=values, + neighbor_offsets=neighbor_offsets, + neighbor_indices=neighbor_indices, + weight_power=weight_power, + min_neighbors=min_neighbors, + safe_epsilon=safe_epsilon, + ) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield representative forward benchmark and parity input cases.""" + device = torch.device(device) + + ### Build deterministic synthetic CSR neighborhoods and signal fields. + for ( + label, + n_entities, + n_dims, + k_neighbors, + vector_values, + ) in cls._BENCHMARK_CASES: + generator = torch.Generator(device=device) + generator.manual_seed(1234 + n_entities + n_dims) + + ### Generate point cloud and fixed-k CSR adjacency. + points = torch.rand( + (n_entities, n_dims), generator=generator, device=device + ) + dists = torch.cdist(points, points) + knn = torch.topk(dists, k=k_neighbors + 1, largest=False, dim=1).indices[ + :, 1: + ] + offsets = torch.arange( + 0, + n_entities * k_neighbors + 1, + k_neighbors, + device=device, + dtype=torch.int64, + ) + indices = knn.reshape(-1).to(torch.int64) + + ### Build scalar/vector fields from analytic trigonometric signals. + if vector_values: + values = torch.stack( + [ + torch.sin(2.0 * torch.pi * points[:, 0]), + torch.cos(2.0 * torch.pi * points[:, 1]), + torch.sin(2.0 * torch.pi * points[:, -1]), + ], + dim=-1, + ).to(torch.float32) + else: + values = ( + torch.sin(2.0 * torch.pi * points[:, 0]) + + 0.5 * torch.cos(2.0 * torch.pi * points[:, -1]) + ).to(torch.float32) + + ### Yield the labeled functional input case. + yield ( + label, + ( + points.to(torch.float32), + values, + offsets, + indices, + ), + { + "weight_power": 2.0, + "min_neighbors": 0, + }, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield representative backward benchmark and parity input cases.""" + device = torch.device(device) + + ### Build representative scalar/vector LSQ inputs for backward parity. + backward_cases = ( + ("backward-2d-scalar-n512-k12", 512, 2, 12, False), + ("backward-3d-scalar-n768-k12", 768, 3, 12, False), + ("backward-3d-vector-n768-k12", 768, 3, 12, True), + ) + + for label, n_entities, n_dims, k_neighbors, vector_values in backward_cases: + generator = torch.Generator(device=device) + generator.manual_seed(8411 + n_entities + n_dims) + + ### Build deterministic KNN-CSR adjacency. + points = torch.rand( + (n_entities, n_dims), generator=generator, device=device + ) + dists = torch.cdist(points, points) + knn = torch.topk(dists, k=k_neighbors + 1, largest=False, dim=1).indices[ + :, 1: + ] + offsets = torch.arange( + 0, + n_entities * k_neighbors + 1, + k_neighbors, + device=device, + dtype=torch.int64, + ) + indices = knn.reshape(-1).to(torch.int64) + + ### Build differentiable scalar/vector field values. + if vector_values: + values = torch.stack( + [ + torch.sin(2.0 * torch.pi * points[:, 0]), + torch.cos(2.0 * torch.pi * points[:, 1]), + torch.sin(2.0 * torch.pi * points[:, -1]), + ], + dim=-1, + ).to(torch.float32) + else: + values = ( + torch.sin(2.0 * torch.pi * points[:, 0]) + + 0.5 * torch.cos(2.0 * torch.pi * points[:, -1]) + ).to(torch.float32) + + values = values.detach().clone().requires_grad_(True) + + ### Keep point coordinates fixed for backward parity on field gradients. 
+ yield ( + label, + ( + points.to(torch.float32), + values, + offsets, + indices, + ), + { + "weight_power": 2.0, + "min_neighbors": 0, + }, + ) + + @classmethod + def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare forward outputs across implementations.""" + ### Validate forward parity across backends. + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_ATOL, + rtol=cls._COMPARE_RTOL, + ) + + @classmethod + def compare_backward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare backward gradients across implementations.""" + ### Validate backward parity across backends. + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_BACKWARD_ATOL, + rtol=cls._COMPARE_BACKWARD_RTOL, + ) + + +mesh_lsq_gradient = MeshLSQGradient.make_function("mesh_lsq_gradient") + + +__all__ = ["MeshLSQGradient", "mesh_lsq_gradient"] diff --git a/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/utils.py b/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/utils.py new file mode 100644 index 0000000000..5e52977716 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/mesh_lsq_gradient/utils.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math + +import torch + + +def validate_inputs( + points: torch.Tensor, + values: torch.Tensor, + neighbor_offsets: torch.Tensor, + neighbor_indices: torch.Tensor, + *, + min_neighbors: int, +) -> None: + """Validate shared mesh LSQ input contracts for torch and warp backends.""" + ### Validate core tensor shapes and dimensions. + if points.ndim != 2: + raise ValueError( + f"points must have shape (n_entities, dims), got {points.shape=}" + ) + if points.shape[1] < 1 or points.shape[1] > 3: + raise ValueError(f"points must be 1D/2D/3D, got dims={points.shape[1]}") + if values.ndim < 1: + raise ValueError( + f"values must have shape (n_entities, ...), got {values.shape=}" + ) + if values.shape[0] != points.shape[0]: + raise ValueError( + f"values leading dimension must match points: {values.shape[0]} != {points.shape[0]}" + ) + if neighbor_offsets.ndim != 1: + raise ValueError("neighbor_offsets must be rank-1") + if neighbor_offsets.shape[0] != points.shape[0] + 1: + raise ValueError( + "neighbor_offsets must have shape (n_entities + 1,), " + f"got {neighbor_offsets.shape} for n_entities={points.shape[0]}" + ) + if neighbor_indices.ndim != 1: + raise ValueError("neighbor_indices must be rank-1") + if min_neighbors < 0: + raise ValueError("min_neighbors must be non-negative") + + ### Validate all inputs are co-located on the same device. 
+ if not ( + points.device == values.device + and points.device == neighbor_offsets.device + and points.device == neighbor_indices.device + ): + raise ValueError( + "points, values, neighbor_offsets, and neighbor_indices must be on the same device" + ) + + ### Validate floating-point and index dtypes. + if not torch.is_floating_point(points): + raise TypeError("points must be floating-point") + if not torch.is_floating_point(values): + raise TypeError("values must be floating-point") + if neighbor_offsets.dtype not in (torch.int32, torch.int64): + raise TypeError("neighbor_offsets must be int32 or int64") + if neighbor_indices.dtype not in (torch.int32, torch.int64): + raise TypeError("neighbor_indices must be int32 or int64") + + ### Validate CSR range invariants. + if int(neighbor_offsets[0].item()) != 0: + raise ValueError("neighbor_offsets must start at 0") + if int(neighbor_offsets[-1].item()) != neighbor_indices.shape[0]: + raise ValueError("neighbor_offsets[-1] must equal len(neighbor_indices)") + if torch.any(neighbor_offsets[1:] < neighbor_offsets[:-1]): + raise ValueError("neighbor_offsets must be non-decreasing") + + if neighbor_indices.numel() > 0: + idx_min = int(neighbor_indices.min().item()) + idx_max = int(neighbor_indices.max().item()) + if idx_min < 0 or idx_max >= points.shape[0]: + raise ValueError( + f"neighbor_indices must satisfy 0 <= index < n_entities ({points.shape[0]})" + ) + + +def resolve_safe_epsilon(*, safe_epsilon: float | None, dtype: torch.dtype) -> float: + """Resolve user-provided or dtype-derived distance floor epsilon.""" + if safe_epsilon is None: + return float(torch.finfo(dtype).tiny ** 0.25) + eps = float(safe_epsilon) + if not math.isfinite(eps) or eps <= 0.0: + raise ValueError("safe_epsilon must be a finite positive value") + return eps diff --git a/physicsnemo/nn/functional/derivatives/meshless_finite_difference/__init__.py b/physicsnemo/nn/functional/derivatives/meshless_finite_difference/__init__.py new file mode 100644 index 0000000000..d3d27c771f --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/meshless_finite_difference/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .meshless_finite_difference import ( + MeshlessFDDerivatives, + meshless_fd_derivatives, +) + +__all__ = [ + "MeshlessFDDerivatives", + "meshless_fd_derivatives", +] diff --git a/physicsnemo/nn/functional/derivatives/meshless_finite_difference/_torch_impl.py b/physicsnemo/nn/functional/derivatives/meshless_finite_difference/_torch_impl.py new file mode 100644 index 0000000000..3dc6fcba42 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/meshless_finite_difference/_torch_impl.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from itertools import combinations, product + +import torch + + +def _normalize_spacing(spacing: float | Sequence[float], dim: int) -> tuple[float, ...]: + """Normalize meshless finite-difference spacing by spatial dimensionality.""" + if isinstance(spacing, (float, int)): + spacing_tuple = tuple(float(spacing) for _ in range(dim)) + else: + spacing_tuple = tuple(float(v) for v in spacing) + if len(spacing_tuple) != dim: + raise ValueError( + f"spacing must have {dim} entries for {dim}D inputs, got {len(spacing_tuple)}" + ) + + for axis, step in enumerate(spacing_tuple): + if step <= 0.0: + raise ValueError(f"spacing[{axis}] must be strictly positive") + return spacing_tuple + + +def _offset_lattice(dim: int) -> tuple[tuple[int, ...], ...]: + """Return canonical ``{-1, 0, 1}`` stencil offsets in lexicographic order.""" + return tuple(product((-1, 0, 1), repeat=dim)) + + +def _infer_dim_from_stencil_size(stencil_size: int) -> int: + """Infer dimensionality from a central stencil size.""" + if stencil_size == 3: + return 1 + if stencil_size == 9: + return 2 + if stencil_size == 27: + return 3 + raise ValueError( + "stencil_values second dimension must be 3, 9, or 27 " + f"(for 1D/2D/3D stencils), got {stencil_size}" + ) + + +def meshless_fd_stencil_points_torch( + points: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + include_center: bool = True, +) -> torch.Tensor: + """Build local Cartesian stencil points for meshless finite differences. + + Parameters + ---------- + points : torch.Tensor + Query points with shape ``(num_points, dim)`` where ``dim`` is 1, 2, or 3. + spacing : float | Sequence[float], optional + Stencil spacing per axis. + include_center : bool, optional + Include the center point ``(0, ..., 0)`` in the stencil if ``True``. + + Returns + ------- + torch.Tensor + Stencil points with shape ``(num_points, stencil_size, dim)``. + ``stencil_size`` is ``3**dim`` when ``include_center=True`` and + ``3**dim - 1`` otherwise. 
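+
+ Examples
+ --------
+ A minimal shape check (the input values are illustrative only):
+
+ >>> import torch
+ >>> pts = torch.zeros(4, 2)
+ >>> meshless_fd_stencil_points_torch(pts, spacing=0.1).shape
+ torch.Size([4, 9, 2])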
+ """ + if points.ndim != 2: + raise ValueError( + f"points must have shape (num_points, dim), got {tuple(points.shape)}" + ) + if not torch.is_floating_point(points): + raise TypeError("points must be a floating-point tensor") + dim = points.shape[1] + if dim < 1 or dim > 3: + raise ValueError(f"only 1D/2D/3D points are supported, got dim={dim}") + + spacing_tuple = _normalize_spacing(spacing=spacing, dim=dim) + offsets = torch.tensor( + _offset_lattice(dim), + device=points.device, + dtype=points.dtype, + ) + if not include_center: + offsets = offsets[offsets.abs().sum(dim=-1) > 0] + + spacing_tensor = torch.tensor( + spacing_tuple, + device=points.device, + dtype=points.dtype, + ) + return points.unsqueeze(1) + offsets.unsqueeze(0) * spacing_tensor + + +def meshless_fd_derivatives_torch( + stencil_values: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 1, + return_mixed_derivs: bool = False, +) -> torch.Tensor: + """Compute central finite-difference derivatives from meshless stencil values. + + Parameters + ---------- + stencil_values : torch.Tensor + Values evaluated on canonical stencil points with shape + ``(num_points, stencil_size)`` or ``(num_points, stencil_size, channels)``. + Stencil ordering must match + :func:`meshless_fd_stencil_points_torch` with ``include_center=True``. + spacing : float | Sequence[float], optional + Spacing per spatial axis. + order : int, optional + Derivative order, either ``1`` or ``2``. + return_mixed_derivs : bool, optional + Include mixed second derivatives. Valid only with ``order=2`` and + dimensionality >= 2. + + Returns + ------- + torch.Tensor + Stacked derivatives with shape ``(num_derivatives, num_points)`` for + scalar input or ``(num_derivatives, num_points, channels)`` for vector input. 
+ """ + if stencil_values.ndim not in (2, 3): + raise ValueError( + "stencil_values must have shape (num_points, stencil_size) or " + "(num_points, stencil_size, channels)" + ) + if not torch.is_floating_point(stencil_values): + raise TypeError("stencil_values must be a floating-point tensor") + if order not in (1, 2): + raise ValueError(f"order must be 1 or 2, got {order}") + + dim = _infer_dim_from_stencil_size(stencil_values.shape[1]) + if return_mixed_derivs and order != 2: + raise ValueError("return_mixed_derivs=True requires order=2") + if return_mixed_derivs and dim == 1: + raise ValueError("mixed derivatives require at least 2D stencils") + + spacing_tuple = _normalize_spacing(spacing=spacing, dim=dim) + + squeeze_channels = stencil_values.ndim == 2 + values = stencil_values.unsqueeze(-1) if squeeze_channels else stencil_values + values_eval = ( + values.to(torch.float32) + if values.dtype in (torch.float16, torch.bfloat16) + else values + ) + + offsets = _offset_lattice(dim) + offset_to_index = {offset: idx for idx, offset in enumerate(offsets)} + center_index = offset_to_index[(0,) * dim] + + derivatives: list[torch.Tensor] = [] + for axis in range(dim): + plus = [0] * dim + minus = [0] * dim + plus[axis] = 1 + minus[axis] = -1 + + plus_index = offset_to_index[tuple(plus)] + minus_index = offset_to_index[tuple(minus)] + step = spacing_tuple[axis] + + if order == 1: + derivatives.append( + (values_eval[:, plus_index] - values_eval[:, minus_index]) + / (2.0 * step) + ) + else: + derivatives.append( + ( + values_eval[:, plus_index] + - 2.0 * values_eval[:, center_index] + + values_eval[:, minus_index] + ) + / (step * step) + ) + + if order == 2 and return_mixed_derivs: + for axis_i, axis_j in combinations(range(dim), 2): + pp = [0] * dim + pm = [0] * dim + mp = [0] * dim + mm = [0] * dim + pp[axis_i], pp[axis_j] = 1, 1 + pm[axis_i], pm[axis_j] = 1, -1 + mp[axis_i], mp[axis_j] = -1, 1 + mm[axis_i], mm[axis_j] = -1, -1 + + pp_index = offset_to_index[tuple(pp)] + pm_index = offset_to_index[tuple(pm)] + mp_index = offset_to_index[tuple(mp)] + mm_index = offset_to_index[tuple(mm)] + + denominator = 4.0 * spacing_tuple[axis_i] * spacing_tuple[axis_j] + derivatives.append( + ( + values_eval[:, pp_index] + - values_eval[:, pm_index] + - values_eval[:, mp_index] + + values_eval[:, mm_index] + ) + / denominator + ) + + output = torch.stack(derivatives, dim=0) + if values_eval.dtype != values.dtype: + output = output.to(values.dtype) + if squeeze_channels: + return output.squeeze(-1) + return output diff --git a/physicsnemo/nn/functional/derivatives/meshless_finite_difference/meshless_finite_difference.py b/physicsnemo/nn/functional/derivatives/meshless_finite_difference/meshless_finite_difference.py new file mode 100644 index 0000000000..3ad2d35fe5 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/meshless_finite_difference/meshless_finite_difference.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +from physicsnemo.core.function_spec import FunctionSpec + +from .._request_utils import ( + normalize_derivative_orders, + normalize_include_mixed, + validate_mixed_request, +) +from ._torch_impl import ( + meshless_fd_derivatives_torch, + meshless_fd_stencil_points_torch, +) + + +class MeshlessFDDerivatives(FunctionSpec): + """Compute meshless finite-difference derivatives from local stencil values. + + This functional expects values already sampled on a canonical + Cartesian ``{-1, 0, 1}`` stencil around each query point. + It does not build stencil coordinates internally; it only maps stencil + values to derivative estimates using central finite-difference formulas. + + Parameters + ---------- + stencil_values : torch.Tensor + Values sampled on a canonical ``{-1,0,1}`` stencil with shape + ``(num_points, stencil_size)`` or ``(num_points, stencil_size, channels)``. + Stencil sizes must be ``3``, ``9``, or ``27``. + spacing : float | Sequence[float], optional + Stencil spacing per axis. + derivative_orders : int | Sequence[int], optional + Derivative orders to compute. Supported values are ``1``, ``2``, or + ``(1, 2)``. + include_mixed : bool, optional + Include mixed second derivatives when requesting second derivatives. + implementation : {"torch"} or None + Implementation to use. When ``None``, dispatch selects the available + implementation. + + Returns + ------- + torch.Tensor + Stacked derivatives with shape ``(num_derivatives, num_points)`` for scalar + input or ``(num_derivatives, num_points, channels)`` for vector input. + + Notes + ----- + Derivative stack ordering is deterministic: first derivatives, then pure + second derivatives, then mixed second derivatives in axis-combination + order. + + The stencil size infers dimensionality: + ``3 -> 1D``, ``9 -> 2D``, ``27 -> 3D``. 
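+
+ Examples
+ --------
+ Shape-only sketch for a 2D scalar field (the random values are placeholders):
+
+ >>> import torch
+ >>> stencil_values = torch.rand(128, 9)
+ >>> derivs = meshless_fd_derivatives(
+ ...     stencil_values,
+ ...     spacing=(0.1, 0.2),
+ ...     derivative_orders=(1, 2),
+ ...     include_mixed=True,
+ ... )
+ >>> derivs.shape
+ torch.Size([5, 128])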
+ """ + + _BENCHMARK_CASES = ( + ("1d-scalar-n4096", 4096, 1, 0.01, 1, False, 1), + ("2d-scalar-n4096-o1", 4096, 2, (0.01, 0.02), 1, False, 1), + ("2d-vector-n4096-o2", 4096, 2, (0.01, 0.02), 2, True, 2), + ("3d-scalar-n2048-o2", 2048, 3, (0.01, 0.015, 0.02), 2, True, 1), + ) + + @FunctionSpec.register(name="torch", rank=0, baseline=True) + def torch_forward( + stencil_values: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + derivative_orders: int | Sequence[int] = 1, + include_mixed: bool = False, + ) -> torch.Tensor: + """Dispatch meshless finite-difference derivatives to the torch backend.""" + requested_orders = normalize_derivative_orders( + derivative_orders=derivative_orders, + function_name="meshless_fd_derivatives", + ) + mixed_terms = normalize_include_mixed( + include_mixed=include_mixed, + function_name="meshless_fd_derivatives", + ) + + ndim = _infer_dim_from_stencil_size(stencil_values) + if ndim is not None: + validate_mixed_request( + derivative_orders=requested_orders, + include_mixed=mixed_terms, + ndim=ndim, + function_name="meshless_fd_derivatives", + ) + + outputs: list[torch.Tensor] = [] + if 1 in requested_orders: + outputs.append( + meshless_fd_derivatives_torch( + stencil_values=stencil_values, + spacing=spacing, + order=1, + return_mixed_derivs=False, + ) + ) + if 2 in requested_orders: + outputs.append( + meshless_fd_derivatives_torch( + stencil_values=stencil_values, + spacing=spacing, + order=2, + return_mixed_derivs=mixed_terms, + ) + ) + return torch.cat(outputs, dim=0) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield representative forward benchmark and parity input cases.""" + device = torch.device(device) + for ( + label, + num_points, + dim, + spacing, + order, + return_mixed_derivs, + channels, + ) in cls._BENCHMARK_CASES: + points = torch.rand(num_points, dim, device=device, dtype=torch.float32) + stencil_points = meshless_fd_stencil_points_torch(points, spacing=spacing) + stencil_values = cls._evaluate_stencil(stencil_points, channels=channels) + yield ( + label, + (stencil_values,), + { + "spacing": spacing, + "derivative_orders": order, + "include_mixed": return_mixed_derivs, + }, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield representative backward benchmark and parity input cases.""" + device = torch.device(device) + backward_cases = ( + ("1d-grad-n2048", 2048, 1, 0.01, 1, False, 2), + ("2d-grad-n2048-o2", 2048, 2, (0.01, 0.02), 2, True, 2), + ("3d-grad-n1024-o2", 1024, 3, (0.01, 0.015, 0.02), 2, True, 1), + ) + for ( + label, + num_points, + dim, + spacing, + order, + return_mixed_derivs, + channels, + ) in backward_cases: + points = torch.rand(num_points, dim, device=device, dtype=torch.float32) + stencil_points = meshless_fd_stencil_points_torch(points, spacing=spacing) + stencil_values = ( + cls._evaluate_stencil(stencil_points, channels=channels) + .detach() + .clone() + .requires_grad_(True) + ) + yield ( + label, + (stencil_values,), + { + "spacing": spacing, + "derivative_orders": order, + "include_mixed": return_mixed_derivs, + }, + ) + + @staticmethod + def _evaluate_stencil( + stencil_points: torch.Tensor, + channels: int, + ) -> torch.Tensor: + """Generate smooth multi-channel stencil values for benchmark inputs.""" + x = stencil_points[..., 0] + if stencil_points.shape[-1] == 1: + values = [torch.sin(2.0 * x) + 0.3 * x.square()] + elif stencil_points.shape[-1] == 2: + y = stencil_points[..., 1] + values = [ + torch.sin(1.4 * 
x) * torch.cos(0.7 * y) + 0.2 * x * y, + x.square() + y.pow(3), + ] + else: + y = stencil_points[..., 1] + z = stencil_points[..., 2] + values = [ + torch.sin(1.2 * x) * torch.cos(0.8 * y) * torch.sin(0.6 * z) + + 0.1 * x * y * z, + x.square() + 0.5 * y.square() - z, + ] + + stacked = torch.stack(values[:channels], dim=-1) + if channels == 1: + return stacked[..., 0] + return stacked + + +meshless_fd_derivatives = MeshlessFDDerivatives.make_function("meshless_fd_derivatives") + + +__all__ = [ + "MeshlessFDDerivatives", + "meshless_fd_derivatives", +] + + +def _infer_dim_from_stencil_size(stencil_values: torch.Tensor) -> int | None: + """Infer dimensionality from stencil shape when it is structurally valid.""" + if stencil_values.ndim not in (2, 3): + return None + return {3: 1, 9: 2, 27: 3}.get(stencil_values.shape[1]) diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/__init__.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/__init__.py new file mode 100644 index 0000000000..29274e3cd3 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .rectilinear_grid_gradient import ( + RectilinearGridGradient, + rectilinear_grid_gradient, +) + +__all__ = ["RectilinearGridGradient", "rectilinear_grid_gradient"] diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_torch_impl.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_torch_impl.py new file mode 100644 index 0000000000..be48d4689b --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_torch_impl.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
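+
+# Implementation note: each axis derivative is formed as a periodic three-point
+# stencil, w_minus * roll(field, +1) + w_center * field + w_plus * roll(field, -1),
+# where the per-node weights come from the axis helpers in utils.py and account
+# for nonuniform spacing and the wrap-around interval implied by ``periods``.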
+ +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +from .utils import ( + axis_central_weights, + axis_second_derivative_weights, + validate_and_normalize_coordinates, + validate_derivative_request, + validate_field, +) + + +def rectilinear_grid_gradient_torch( + field: torch.Tensor, + coordinates: Sequence[torch.Tensor], + periods: float | Sequence[float] | None = None, + derivative_order: int = 1, + include_mixed: bool = False, +) -> torch.Tensor: + """Compute periodic first or pure second derivatives on rectilinear grids.""" + ### Validate field and coordinate inputs. + validate_field(field) + derivative_order = validate_derivative_request( + derivative_order=derivative_order, + include_mixed=include_mixed, + ) + + coords_tuple, period_tuple = validate_and_normalize_coordinates( + field=field, + coordinates=coordinates, + periods=periods, + coordinates_dtype=field.dtype, + requires_grad_error="coordinate gradients are not supported; pass detached coordinates", + ) + + ### Compute per-axis nonuniform periodic central-difference derivatives. + gradients: list[torch.Tensor] = [] + for axis in range(field.ndim): + if derivative_order == 1: + w_minus, w_center, w_plus = axis_central_weights( + coords_tuple[axis], + period_tuple[axis], + ) + else: + w_minus, w_center, w_plus = axis_second_derivative_weights( + coords_tuple[axis], + period_tuple[axis], + ) + + view_shape = [1] * field.ndim + view_shape[axis] = field.shape[axis] + w_minus = w_minus.view(view_shape) + w_center = w_center.view(view_shape) + w_plus = w_plus.view(view_shape) + + grad_axis = ( + w_minus * torch.roll(field, shifts=1, dims=axis) + + w_center * field + + w_plus * torch.roll(field, shifts=-1, dims=axis) + ) + gradients.append(grad_axis) + + ### Stack per-axis derivative terms into (dims, *field.shape). + return torch.stack(gradients, dim=0) diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/__init__.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/__init__.py new file mode 100644 index 0000000000..cf63766f87 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .op import rectilinear_grid_gradient_warp, rectilinear_grid_gradient_warp_multi + +__all__ = ["rectilinear_grid_gradient_warp", "rectilinear_grid_gradient_warp_multi"] diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/__init__.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/__init__.py new file mode 100644 index 0000000000..5116aa06be --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/__init__.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .backward import ( + _rectilinear_derivatives_1d_fused_no_mixed_backward_kernel, + _rectilinear_derivatives_2d_fused_no_mixed_backward_kernel, + _rectilinear_derivatives_3d_fused_no_mixed_backward_kernel, + _rectilinear_gradient_1d_backward_kernel, + _rectilinear_gradient_2d_backward_kernel, + _rectilinear_gradient_3d_backward_kernel, + _rectilinear_second_derivative_1d_backward_kernel, + _rectilinear_second_derivative_2d_backward_kernel, + _rectilinear_second_derivative_3d_backward_kernel, +) +from .forward import ( + _rectilinear_derivatives_1d_fused_no_mixed_kernel, + _rectilinear_derivatives_2d_fused_no_mixed_kernel, + _rectilinear_derivatives_3d_fused_no_mixed_kernel, + _rectilinear_gradient_1d_kernel, + _rectilinear_gradient_2d_kernel, + _rectilinear_gradient_3d_kernel, + _rectilinear_second_derivative_1d_kernel, + _rectilinear_second_derivative_2d_kernel, + _rectilinear_second_derivative_3d_kernel, +) + +__all__ = [ + "_rectilinear_gradient_1d_kernel", + "_rectilinear_gradient_2d_kernel", + "_rectilinear_gradient_3d_kernel", + "_rectilinear_second_derivative_1d_kernel", + "_rectilinear_second_derivative_2d_kernel", + "_rectilinear_second_derivative_3d_kernel", + "_rectilinear_derivatives_1d_fused_no_mixed_kernel", + "_rectilinear_derivatives_2d_fused_no_mixed_kernel", + "_rectilinear_derivatives_3d_fused_no_mixed_kernel", + "_rectilinear_derivatives_1d_fused_no_mixed_backward_kernel", + "_rectilinear_derivatives_2d_fused_no_mixed_backward_kernel", + "_rectilinear_derivatives_3d_fused_no_mixed_backward_kernel", + "_rectilinear_gradient_1d_backward_kernel", + "_rectilinear_gradient_2d_backward_kernel", + "_rectilinear_gradient_3d_backward_kernel", + "_rectilinear_second_derivative_1d_backward_kernel", + "_rectilinear_second_derivative_2d_backward_kernel", + "_rectilinear_second_derivative_3d_backward_kernel", +] diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/backward.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/backward.py new file mode 100644 index 0000000000..b8dcd1e6ba --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/backward.py @@ -0,0 +1,370 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warp as wp + +from .utils import _axis_coeff, _axis_second_coeff + +### ============================================================ +### Backward kernels (adjoint of rectilinear central differences) +### ============================================================ + + +@wp.kernel +def _rectilinear_gradient_1d_backward_kernel( + grad0: wp.array(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + period0: float, + grad_field: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = grad0.shape[0] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + + ci = _axis_coeff(x0, period0, i) + cip = _axis_coeff(x0, period0, ip) + cim = _axis_coeff(x0, period0, im) + grad_field[i] = ci[1] * grad0[i] + cip[0] * grad0[ip] + cim[2] * grad0[im] + + +@wp.kernel +def _rectilinear_gradient_2d_backward_kernel( + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + period0: float, + period1: float, + grad_field: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + + cxi = _axis_coeff(x0, period0, i) + cxip = _axis_coeff(x0, period0, ip) + cxim = _axis_coeff(x0, period0, im) + + cyi = _axis_coeff(x1, period1, j) + cyip = _axis_coeff(x1, period1, jp) + cyim = _axis_coeff(x1, period1, jm) + + gx = cxi[1] * grad0[i, j] + cxip[0] * grad0[ip, j] + cxim[2] * grad0[im, j] + gy = cyi[1] * grad1[i, j] + cyip[0] * grad1[i, jp] + cyim[2] * grad1[i, jm] + grad_field[i, j] = gx + gy + + +@wp.kernel +def _rectilinear_gradient_3d_backward_kernel( + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + x2: wp.array(dtype=wp.float32), + period0: float, + period1: float, + period2: float, + grad_field: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + n2 = grad0.shape[2] + + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + km = (k + n2 - 1) % n2 + kp = (k + 1) % n2 + + cxi = _axis_coeff(x0, period0, i) + cxip = _axis_coeff(x0, period0, ip) + cxim = _axis_coeff(x0, period0, im) + + cyi = _axis_coeff(x1, period1, j) + cyip = _axis_coeff(x1, period1, jp) + cyim = _axis_coeff(x1, period1, jm) + + czi = _axis_coeff(x2, period2, k) + czip = _axis_coeff(x2, period2, kp) + czim = _axis_coeff(x2, period2, km) + + gx = cxi[1] * grad0[i, j, k] + cxip[0] * grad0[ip, j, k] + cxim[2] * grad0[im, j, k] + gy = cyi[1] * grad1[i, j, k] + cyip[0] * grad1[i, jp, k] + cyim[2] * grad1[i, jm, k] + gz = czi[1] * grad2[i, j, k] + czip[0] * grad2[i, j, kp] + 
czim[2] * grad2[i, j, km] + grad_field[i, j, k] = gx + gy + gz + + +@wp.kernel +def _rectilinear_second_derivative_1d_backward_kernel( + grad0: wp.array(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + period0: float, + grad_field: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = grad0.shape[0] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + + ci = _axis_second_coeff(x0, period0, i) + cip = _axis_second_coeff(x0, period0, ip) + cim = _axis_second_coeff(x0, period0, im) + grad_field[i] = ci[1] * grad0[i] + cip[0] * grad0[ip] + cim[2] * grad0[im] + + +@wp.kernel +def _rectilinear_second_derivative_2d_backward_kernel( + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + period0: float, + period1: float, + grad_field: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + + cxi = _axis_second_coeff(x0, period0, i) + cxip = _axis_second_coeff(x0, period0, ip) + cxim = _axis_second_coeff(x0, period0, im) + + cyi = _axis_second_coeff(x1, period1, j) + cyip = _axis_second_coeff(x1, period1, jp) + cyim = _axis_second_coeff(x1, period1, jm) + + gx = cxi[1] * grad0[i, j] + cxip[0] * grad0[ip, j] + cxim[2] * grad0[im, j] + gy = cyi[1] * grad1[i, j] + cyip[0] * grad1[i, jp] + cyim[2] * grad1[i, jm] + grad_field[i, j] = gx + gy + + +@wp.kernel +def _rectilinear_second_derivative_3d_backward_kernel( + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + x2: wp.array(dtype=wp.float32), + period0: float, + period1: float, + period2: float, + grad_field: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + n2 = grad0.shape[2] + + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + km = (k + n2 - 1) % n2 + kp = (k + 1) % n2 + + cxi = _axis_second_coeff(x0, period0, i) + cxip = _axis_second_coeff(x0, period0, ip) + cxim = _axis_second_coeff(x0, period0, im) + + cyi = _axis_second_coeff(x1, period1, j) + cyip = _axis_second_coeff(x1, period1, jp) + cyim = _axis_second_coeff(x1, period1, jm) + + czi = _axis_second_coeff(x2, period2, k) + czip = _axis_second_coeff(x2, period2, kp) + czim = _axis_second_coeff(x2, period2, km) + + gx = cxi[1] * grad0[i, j, k] + cxip[0] * grad0[ip, j, k] + cxim[2] * grad0[im, j, k] + gy = cyi[1] * grad1[i, j, k] + cyip[0] * grad1[i, jp, k] + cyim[2] * grad1[i, jm, k] + gz = czi[1] * grad2[i, j, k] + czip[0] * grad2[i, j, kp] + czim[2] * grad2[i, j, km] + grad_field[i, j, k] = gx + gy + gz + + +### ============================================================ +### Fused backward kernels (adjoint of 1st+2nd, no mixed) +### ============================================================ + + +@wp.kernel +def _rectilinear_derivatives_1d_fused_no_mixed_backward_kernel( + grad0: wp.array(dtype=wp.float32), + grad00: wp.array(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + period0: float, + grad_field: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = grad0.shape[0] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + + c1_i = _axis_coeff(x0, period0, i) + c1_ip = _axis_coeff(x0, period0, ip) + c1_im = _axis_coeff(x0, period0, im) + + c2_i = _axis_second_coeff(x0, period0, i) + c2_ip = _axis_second_coeff(x0, period0, ip) + c2_im = 
_axis_second_coeff(x0, period0, im) + + g1 = c1_i[1] * grad0[i] + c1_ip[0] * grad0[ip] + c1_im[2] * grad0[im] + g2 = c2_i[1] * grad00[i] + c2_ip[0] * grad00[ip] + c2_im[2] * grad00[im] + grad_field[i] = g1 + g2 + + +@wp.kernel +def _rectilinear_derivatives_2d_fused_no_mixed_backward_kernel( + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + grad00: wp.array2d(dtype=wp.float32), + grad11: wp.array2d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + period0: float, + period1: float, + grad_field: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + + cx1_i = _axis_coeff(x0, period0, i) + cx1_ip = _axis_coeff(x0, period0, ip) + cx1_im = _axis_coeff(x0, period0, im) + cy1_i = _axis_coeff(x1, period1, j) + cy1_ip = _axis_coeff(x1, period1, jp) + cy1_im = _axis_coeff(x1, period1, jm) + + cx2_i = _axis_second_coeff(x0, period0, i) + cx2_ip = _axis_second_coeff(x0, period0, ip) + cx2_im = _axis_second_coeff(x0, period0, im) + cy2_i = _axis_second_coeff(x1, period1, j) + cy2_ip = _axis_second_coeff(x1, period1, jp) + cy2_im = _axis_second_coeff(x1, period1, jm) + + g1x = cx1_i[1] * grad0[i, j] + cx1_ip[0] * grad0[ip, j] + cx1_im[2] * grad0[im, j] + g1y = cy1_i[1] * grad1[i, j] + cy1_ip[0] * grad1[i, jp] + cy1_im[2] * grad1[i, jm] + g2x = ( + cx2_i[1] * grad00[i, j] + cx2_ip[0] * grad00[ip, j] + cx2_im[2] * grad00[im, j] + ) + g2y = ( + cy2_i[1] * grad11[i, j] + cy2_ip[0] * grad11[i, jp] + cy2_im[2] * grad11[i, jm] + ) + grad_field[i, j] = g1x + g1y + g2x + g2y + + +@wp.kernel +def _rectilinear_derivatives_3d_fused_no_mixed_backward_kernel( + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + grad00: wp.array3d(dtype=wp.float32), + grad11: wp.array3d(dtype=wp.float32), + grad22: wp.array3d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + x2: wp.array(dtype=wp.float32), + period0: float, + period1: float, + period2: float, + grad_field: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + n2 = grad0.shape[2] + + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + km = (k + n2 - 1) % n2 + kp = (k + 1) % n2 + + cx1_i = _axis_coeff(x0, period0, i) + cx1_ip = _axis_coeff(x0, period0, ip) + cx1_im = _axis_coeff(x0, period0, im) + cy1_i = _axis_coeff(x1, period1, j) + cy1_ip = _axis_coeff(x1, period1, jp) + cy1_im = _axis_coeff(x1, period1, jm) + cz1_i = _axis_coeff(x2, period2, k) + cz1_ip = _axis_coeff(x2, period2, kp) + cz1_im = _axis_coeff(x2, period2, km) + + cx2_i = _axis_second_coeff(x0, period0, i) + cx2_ip = _axis_second_coeff(x0, period0, ip) + cx2_im = _axis_second_coeff(x0, period0, im) + cy2_i = _axis_second_coeff(x1, period1, j) + cy2_ip = _axis_second_coeff(x1, period1, jp) + cy2_im = _axis_second_coeff(x1, period1, jm) + cz2_i = _axis_second_coeff(x2, period2, k) + cz2_ip = _axis_second_coeff(x2, period2, kp) + cz2_im = _axis_second_coeff(x2, period2, km) + + g1x = ( + cx1_i[1] * grad0[i, j, k] + + cx1_ip[0] * grad0[ip, j, k] + + cx1_im[2] * grad0[im, j, k] + ) + g1y = ( + cy1_i[1] * grad1[i, j, k] + + cy1_ip[0] * grad1[i, jp, k] + + cy1_im[2] * grad1[i, jm, k] + ) + g1z = ( + cz1_i[1] * grad2[i, j, k] + + cz1_ip[0] * grad2[i, j, kp] + + cz1_im[2] * grad2[i, j, km] + ) + + g2x = ( + cx2_i[1] * 
grad00[i, j, k] + + cx2_ip[0] * grad00[ip, j, k] + + cx2_im[2] * grad00[im, j, k] + ) + g2y = ( + cy2_i[1] * grad11[i, j, k] + + cy2_ip[0] * grad11[i, jp, k] + + cy2_im[2] * grad11[i, jm, k] + ) + g2z = ( + cz2_i[1] * grad22[i, j, k] + + cz2_ip[0] * grad22[i, j, kp] + + cz2_im[2] * grad22[i, j, km] + ) + grad_field[i, j, k] = g1x + g1y + g1z + g2x + g2y + g2z diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/forward.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/forward.py new file mode 100644 index 0000000000..ca6f7fd5be --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/forward.py @@ -0,0 +1,294 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warp as wp + +from .utils import _axis_coeff, _axis_second_coeff + +### ============================================================ +### Forward kernels (rectilinear periodic central differences) +### ============================================================ + + +@wp.kernel +def _rectilinear_gradient_1d_kernel( + field: wp.array(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + period0: float, + grad0: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = field.shape[0] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + + coeff = _axis_coeff(x0, period0, i) + grad0[i] = coeff[0] * field[im] + coeff[1] * field[i] + coeff[2] * field[ip] + + +@wp.kernel +def _rectilinear_gradient_2d_kernel( + field: wp.array2d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + period0: float, + period1: float, + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + + cx = _axis_coeff(x0, period0, i) + cy = _axis_coeff(x1, period1, j) + + grad0[i, j] = cx[0] * field[im, j] + cx[1] * field[i, j] + cx[2] * field[ip, j] + grad1[i, j] = cy[0] * field[i, jm] + cy[1] * field[i, j] + cy[2] * field[i, jp] + + +@wp.kernel +def _rectilinear_gradient_3d_kernel( + field: wp.array3d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + x2: wp.array(dtype=wp.float32), + period0: float, + period1: float, + period2: float, + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + km = (k + n2 - 1) % n2 + kp = (k + 1) % n2 + + cx = _axis_coeff(x0, period0, i) + cy = _axis_coeff(x1, period1, j) + cz = _axis_coeff(x2, period2, k) + + grad0[i, 
j, k] = ( + cx[0] * field[im, j, k] + cx[1] * field[i, j, k] + cx[2] * field[ip, j, k] + ) + grad1[i, j, k] = ( + cy[0] * field[i, jm, k] + cy[1] * field[i, j, k] + cy[2] * field[i, jp, k] + ) + grad2[i, j, k] = ( + cz[0] * field[i, j, km] + cz[1] * field[i, j, k] + cz[2] * field[i, j, kp] + ) + + +@wp.kernel +def _rectilinear_second_derivative_1d_kernel( + field: wp.array(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + period0: float, + grad0: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = field.shape[0] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + + coeff = _axis_second_coeff(x0, period0, i) + grad0[i] = coeff[0] * field[im] + coeff[1] * field[i] + coeff[2] * field[ip] + + +@wp.kernel +def _rectilinear_second_derivative_2d_kernel( + field: wp.array2d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + period0: float, + period1: float, + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + + cx = _axis_second_coeff(x0, period0, i) + cy = _axis_second_coeff(x1, period1, j) + + grad0[i, j] = cx[0] * field[im, j] + cx[1] * field[i, j] + cx[2] * field[ip, j] + grad1[i, j] = cy[0] * field[i, jm] + cy[1] * field[i, j] + cy[2] * field[i, jp] + + +@wp.kernel +def _rectilinear_second_derivative_3d_kernel( + field: wp.array3d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + x2: wp.array(dtype=wp.float32), + period0: float, + period1: float, + period2: float, + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + km = (k + n2 - 1) % n2 + kp = (k + 1) % n2 + + cx = _axis_second_coeff(x0, period0, i) + cy = _axis_second_coeff(x1, period1, j) + cz = _axis_second_coeff(x2, period2, k) + + grad0[i, j, k] = ( + cx[0] * field[im, j, k] + cx[1] * field[i, j, k] + cx[2] * field[ip, j, k] + ) + grad1[i, j, k] = ( + cy[0] * field[i, jm, k] + cy[1] * field[i, j, k] + cy[2] * field[i, jp, k] + ) + grad2[i, j, k] = ( + cz[0] * field[i, j, km] + cz[1] * field[i, j, k] + cz[2] * field[i, j, kp] + ) + + +### ============================================================ +### Fused forward kernels (single launch for 1st+2nd, no mixed) +### ============================================================ + + +@wp.kernel +def _rectilinear_derivatives_1d_fused_no_mixed_kernel( + field: wp.array(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + period0: float, + grad0: wp.array(dtype=wp.float32), + grad00: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = field.shape[0] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + + coeff1 = _axis_coeff(x0, period0, i) + coeff2 = _axis_second_coeff(x0, period0, i) + grad0[i] = coeff1[0] * field[im] + coeff1[1] * field[i] + coeff1[2] * field[ip] + grad00[i] = coeff2[0] * field[im] + coeff2[1] * field[i] + coeff2[2] * field[ip] + + +@wp.kernel +def _rectilinear_derivatives_2d_fused_no_mixed_kernel( + field: wp.array2d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + period0: float, + period1: float, + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + grad00: wp.array2d(dtype=wp.float32), + grad11: 
wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + + cx1 = _axis_coeff(x0, period0, i) + cy1 = _axis_coeff(x1, period1, j) + cx2 = _axis_second_coeff(x0, period0, i) + cy2 = _axis_second_coeff(x1, period1, j) + + grad0[i, j] = cx1[0] * field[im, j] + cx1[1] * field[i, j] + cx1[2] * field[ip, j] + grad1[i, j] = cy1[0] * field[i, jm] + cy1[1] * field[i, j] + cy1[2] * field[i, jp] + grad00[i, j] = cx2[0] * field[im, j] + cx2[1] * field[i, j] + cx2[2] * field[ip, j] + grad11[i, j] = cy2[0] * field[i, jm] + cy2[1] * field[i, j] + cy2[2] * field[i, jp] + + +@wp.kernel +def _rectilinear_derivatives_3d_fused_no_mixed_kernel( + field: wp.array3d(dtype=wp.float32), + x0: wp.array(dtype=wp.float32), + x1: wp.array(dtype=wp.float32), + x2: wp.array(dtype=wp.float32), + period0: float, + period1: float, + period2: float, + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + grad00: wp.array3d(dtype=wp.float32), + grad11: wp.array3d(dtype=wp.float32), + grad22: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + im = (i + n0 - 1) % n0 + ip = (i + 1) % n0 + jm = (j + n1 - 1) % n1 + jp = (j + 1) % n1 + km = (k + n2 - 1) % n2 + kp = (k + 1) % n2 + + cx1 = _axis_coeff(x0, period0, i) + cy1 = _axis_coeff(x1, period1, j) + cz1 = _axis_coeff(x2, period2, k) + cx2 = _axis_second_coeff(x0, period0, i) + cy2 = _axis_second_coeff(x1, period1, j) + cz2 = _axis_second_coeff(x2, period2, k) + + grad0[i, j, k] = ( + cx1[0] * field[im, j, k] + cx1[1] * field[i, j, k] + cx1[2] * field[ip, j, k] + ) + grad1[i, j, k] = ( + cy1[0] * field[i, jm, k] + cy1[1] * field[i, j, k] + cy1[2] * field[i, jp, k] + ) + grad2[i, j, k] = ( + cz1[0] * field[i, j, km] + cz1[1] * field[i, j, k] + cz1[2] * field[i, j, kp] + ) + grad00[i, j, k] = ( + cx2[0] * field[im, j, k] + cx2[1] * field[i, j, k] + cx2[2] * field[ip, j, k] + ) + grad11[i, j, k] = ( + cy2[0] * field[i, jm, k] + cy2[1] * field[i, j, k] + cy2[2] * field[i, jp, k] + ) + grad22[i, j, k] = ( + cz2[0] * field[i, j, km] + cz2[1] * field[i, j, k] + cz2[2] * field[i, j, kp] + ) diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/utils.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/utils.py new file mode 100644 index 0000000000..5ab294478a --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/_kernels/utils.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
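+### The helpers below compute per-index stencil weights for periodic, possibly
+### nonuniform spacing. On a uniform grid (h_minus == h_plus == h) the
+### first-derivative weights reduce to the familiar central differences
+### (-1/(2h), 0, +1/(2h)) and the second-derivative weights to
+### (1/h^2, -2/h^2, 1/h^2).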
+ +from __future__ import annotations + +import warp as wp + + +@wp.func +def _axis_coeff( + coords: wp.array(dtype=wp.float32), + period: float, + idx: int, +) -> wp.vec3f: + ### Compute nonuniform periodic central-difference weights at one index. + n = coords.shape[0] + im = (idx + n - 1) % n + ip = (idx + 1) % n + + xi = coords[idx] + xim = coords[im] + xip = coords[ip] + + h_minus = xi - xim + if idx == 0: + h_minus = xi + period - xim + + h_plus = xip - xi + if idx == (n - 1): + h_plus = xip + period - xi + + denom = h_minus + h_plus + w_minus = -h_plus / (h_minus * denom) + w_center = (h_plus - h_minus) / (h_minus * h_plus) + w_plus = h_minus / (h_plus * denom) + return wp.vec3f(w_minus, w_center, w_plus) + + +@wp.func +def _axis_second_coeff( + coords: wp.array(dtype=wp.float32), + period: float, + idx: int, +) -> wp.vec3f: + ### Compute nonuniform periodic second-derivative weights at one index. + n = coords.shape[0] + im = (idx + n - 1) % n + ip = (idx + 1) % n + + xi = coords[idx] + xim = coords[im] + xip = coords[ip] + + h_minus = xi - xim + if idx == 0: + h_minus = xi + period - xim + + h_plus = xip - xi + if idx == (n - 1): + h_plus = xip + period - xi + + denom = h_minus + h_plus + w_minus = 2.0 / (h_minus * denom) + w_center = -2.0 / (h_minus * h_plus) + w_plus = 2.0 / (h_plus * denom) + return wp.vec3f(w_minus, w_center, w_plus) diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/launch_backward.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/launch_backward.py new file mode 100644 index 0000000000..a94e2aecaf --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/launch_backward.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
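+### Backward kernels are selected by (spatial rank, derivative order) for the
+### single-order path and by spatial rank alone for the fused first+second
+### (no-mixed) path.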
+ +from __future__ import annotations + +import torch +import warp as wp + +from ._kernels import ( + _rectilinear_derivatives_1d_fused_no_mixed_backward_kernel, + _rectilinear_derivatives_2d_fused_no_mixed_backward_kernel, + _rectilinear_derivatives_3d_fused_no_mixed_backward_kernel, + _rectilinear_gradient_1d_backward_kernel, + _rectilinear_gradient_2d_backward_kernel, + _rectilinear_gradient_3d_backward_kernel, + _rectilinear_second_derivative_1d_backward_kernel, + _rectilinear_second_derivative_2d_backward_kernel, + _rectilinear_second_derivative_3d_backward_kernel, +) + +_BACKWARD_KERNELS = { + (1, 1): _rectilinear_gradient_1d_backward_kernel, + (1, 2): _rectilinear_second_derivative_1d_backward_kernel, + (2, 1): _rectilinear_gradient_2d_backward_kernel, + (2, 2): _rectilinear_second_derivative_2d_backward_kernel, + (3, 1): _rectilinear_gradient_3d_backward_kernel, + (3, 2): _rectilinear_second_derivative_3d_backward_kernel, +} + +_FUSED_BACKWARD_NO_MIXED_KERNELS = { + 1: _rectilinear_derivatives_1d_fused_no_mixed_backward_kernel, + 2: _rectilinear_derivatives_2d_fused_no_mixed_backward_kernel, + 3: _rectilinear_derivatives_3d_fused_no_mixed_backward_kernel, +} + + +def _launch_dim(shape: torch.Size) -> int | tuple[int, ...]: + """Return Warp launch dimensions for 1D vs ND kernels.""" + return shape[0] if len(shape) == 1 else tuple(shape) + + +def _to_wp_components(components: list[torch.Tensor], count: int) -> list[wp.array]: + """Convert the leading tensor components to Warp arrays.""" + return [wp.from_torch(components[i], dtype=wp.float32) for i in range(count)] + + +def _to_wp_coords(coords_tuple: tuple[torch.Tensor, ...], ndim: int) -> list[wp.array]: + """Convert coordinate axes to Warp arrays.""" + return [wp.from_torch(coords_tuple[i], dtype=wp.float32) for i in range(ndim)] + + +def _period_values(period_tuple: tuple[float, ...], ndim: int) -> list[float]: + """Convert axis periods to float values.""" + return [float(period_tuple[i]) for i in range(ndim)] + + +def _launch_backward( + *, + grad_output_fp32: torch.Tensor, + coords_tuple: tuple[torch.Tensor, ...], + period_tuple: tuple[float, ...], + derivative_order: int, + grad_field: torch.Tensor, + wp_device, + wp_stream, +) -> None: + ### Launch dimensionality-specific backward kernels. 
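+    ### Each backward kernel applies the adjoint (transposed) stencil: the
+    ### incoming gradient at index i picks up the centre weight at i, the
+    ### "minus" weight of the i+1 stencil, and the "plus" weight of the i-1
+    ### stencil; per-axis contributions are summed into ``grad_field``.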
+ ndim = grad_field.ndim + kernel = _BACKWARD_KERNELS[(ndim, derivative_order)] + inputs = [ + *_to_wp_components(grad_output_fp32, ndim), + *_to_wp_coords(coords_tuple, ndim), + *_period_values(period_tuple, ndim), + wp.from_torch(grad_field, dtype=wp.float32), + ] + + with wp.ScopedStream(wp_stream): + wp.launch( + kernel=kernel, + dim=_launch_dim(grad_field.shape), + inputs=inputs, + device=wp_device, + stream=wp_stream, + ) + + +def _launch_backward_fused_no_mixed( + *, + grad_first_components: list[torch.Tensor], + grad_second_components: list[torch.Tensor], + coords_tuple: tuple[torch.Tensor, ...], + period_tuple: tuple[float, ...], + grad_field: torch.Tensor, + wp_device, + wp_stream, +) -> None: + """Launch dimensionality-specific fused first+second backward kernels.""" + ndim = grad_field.ndim + kernel = _FUSED_BACKWARD_NO_MIXED_KERNELS[ndim] + inputs = [ + *_to_wp_components(grad_first_components, ndim), + *_to_wp_components(grad_second_components, ndim), + *_to_wp_coords(coords_tuple, ndim), + *_period_values(period_tuple, ndim), + wp.from_torch(grad_field, dtype=wp.float32), + ] + + with wp.ScopedStream(wp_stream): + wp.launch( + kernel=kernel, + dim=_launch_dim(grad_field.shape), + inputs=inputs, + device=wp_device, + stream=wp_stream, + ) diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/launch_forward.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/launch_forward.py new file mode 100644 index 0000000000..36131694f4 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/launch_forward.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import torch +import warp as wp + +from ._kernels import ( + _rectilinear_derivatives_1d_fused_no_mixed_kernel, + _rectilinear_derivatives_2d_fused_no_mixed_kernel, + _rectilinear_derivatives_3d_fused_no_mixed_kernel, + _rectilinear_gradient_1d_kernel, + _rectilinear_gradient_2d_kernel, + _rectilinear_gradient_3d_kernel, + _rectilinear_second_derivative_1d_kernel, + _rectilinear_second_derivative_2d_kernel, + _rectilinear_second_derivative_3d_kernel, +) + +_FORWARD_KERNELS = { + (1, 1): _rectilinear_gradient_1d_kernel, + (1, 2): _rectilinear_second_derivative_1d_kernel, + (2, 1): _rectilinear_gradient_2d_kernel, + (2, 2): _rectilinear_second_derivative_2d_kernel, + (3, 1): _rectilinear_gradient_3d_kernel, + (3, 2): _rectilinear_second_derivative_3d_kernel, +} + +_FUSED_FORWARD_NO_MIXED_KERNELS = { + 1: _rectilinear_derivatives_1d_fused_no_mixed_kernel, + 2: _rectilinear_derivatives_2d_fused_no_mixed_kernel, + 3: _rectilinear_derivatives_3d_fused_no_mixed_kernel, +} + + +def _launch_dim(shape: torch.Size) -> int | tuple[int, ...]: + """Return Warp launch dimensions for 1D vs ND kernels.""" + return shape[0] if len(shape) == 1 else tuple(shape) + + +def _to_wp_components(components: list[torch.Tensor], count: int) -> list[wp.array]: + """Convert the leading tensor components to Warp arrays.""" + return [wp.from_torch(components[i], dtype=wp.float32) for i in range(count)] + + +def _to_wp_coords(coords_tuple: tuple[torch.Tensor, ...], ndim: int) -> list[wp.array]: + """Convert coordinate axes to Warp arrays.""" + return [wp.from_torch(coords_tuple[i], dtype=wp.float32) for i in range(ndim)] + + +def _period_values(period_tuple: tuple[float, ...], ndim: int) -> list[float]: + """Convert axis periods to float values.""" + return [float(period_tuple[i]) for i in range(ndim)] + + +def _launch_forward( + *, + field_fp32: torch.Tensor, + coords_tuple: tuple[torch.Tensor, ...], + period_tuple: tuple[float, ...], + derivative_order: int, + grad_components: list[torch.Tensor], + wp_device, + wp_stream, +) -> None: + ### Launch dimensionality-specific forward kernels. 
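+    ### Output components are preallocated torch tensors; ``wp.from_torch``
+    ### wraps them without copying, so the kernel writes results directly into
+    ### the buffers that the caller later stacks.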
+ ndim = field_fp32.ndim + kernel = _FORWARD_KERNELS[(ndim, derivative_order)] + inputs = [ + wp.from_torch(field_fp32, dtype=wp.float32), + *_to_wp_coords(coords_tuple, ndim), + *_period_values(period_tuple, ndim), + *_to_wp_components(grad_components, ndim), + ] + + with wp.ScopedStream(wp_stream): + wp.launch( + kernel=kernel, + dim=_launch_dim(field_fp32.shape), + inputs=inputs, + device=wp_device, + stream=wp_stream, + ) + + +def _launch_forward_fused_no_mixed( + *, + field_fp32: torch.Tensor, + coords_tuple: tuple[torch.Tensor, ...], + period_tuple: tuple[float, ...], + first_components: list[torch.Tensor], + second_components: list[torch.Tensor], + wp_device, + wp_stream, +) -> None: + """Launch dimensionality-specific fused first+second derivative kernels.""" + ndim = field_fp32.ndim + kernel = _FUSED_FORWARD_NO_MIXED_KERNELS[ndim] + inputs = [ + wp.from_torch(field_fp32, dtype=wp.float32), + *_to_wp_coords(coords_tuple, ndim), + *_period_values(period_tuple, ndim), + *_to_wp_components(first_components, ndim), + *_to_wp_components(second_components, ndim), + ] + + with wp.ScopedStream(wp_stream): + wp.launch( + kernel=kernel, + dim=_launch_dim(field_fp32.shape), + inputs=inputs, + device=wp_device, + stream=wp_stream, + ) diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/op.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/op.py new file mode 100644 index 0000000000..a522e129a6 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/_warp_impl/op.py @@ -0,0 +1,860 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import warp as wp + +from physicsnemo.core.function_spec import FunctionSpec + +from .._torch_impl import rectilinear_grid_gradient_torch +from ..utils import ( + validate_and_normalize_coordinates, + validate_derivative_request, + validate_field, +) +from .launch_backward import _launch_backward, _launch_backward_fused_no_mixed +from .launch_forward import _launch_forward, _launch_forward_fused_no_mixed + +### Warp runtime initialization for custom kernels. 
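+### ``wp.init()`` is a no-op once the runtime is initialized, so importing
+### several Warp-backed functionals initializes Warp only once; quiet mode keeps
+### Warp's informational messages out of application logs.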
+wp.init() +wp.config.quiet = True + + +def _rectilinear_forward_common( + field: torch.Tensor, + coords_tuple: tuple[torch.Tensor, ...], + period_tuple: tuple[float, ...], + derivative_order: int, + include_mixed: bool, +) -> torch.Tensor: + """Run rectilinear forward kernels and restore the caller dtype.""" + validate_field(field) + derivative_order = validate_derivative_request( + derivative_order=derivative_order, + include_mixed=include_mixed, + ) + coords_tuple, period_tuple = validate_and_normalize_coordinates( + field=field, + coordinates=coords_tuple, + periods=period_tuple, + coordinates_dtype=torch.float32, + requires_grad_error="coordinate gradients are not supported in warp backend", + ) + + orig_dtype = field.dtype + field_fp32 = field.to(dtype=torch.float32).contiguous() + grad_components = [torch.empty_like(field_fp32) for _ in range(field_fp32.ndim)] + + wp_device, wp_stream = FunctionSpec.warp_launch_context(field_fp32) + _launch_forward( + field_fp32=field_fp32, + coords_tuple=coords_tuple, + period_tuple=period_tuple, + derivative_order=derivative_order, + grad_components=grad_components, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + output = torch.stack(grad_components, dim=0) + if output.dtype != orig_dtype: + output = output.to(dtype=orig_dtype) + return output + + +def _rectilinear_setup_common( + ctx: torch.autograd.function.FunctionCtx, + field: torch.Tensor, + coords_tuple: tuple[torch.Tensor, ...], + period_tuple: tuple[float, ...], + derivative_order: int, + include_mixed: bool, +) -> None: + """Store normalized geometry metadata for rectilinear custom-op backward.""" + derivative_order = validate_derivative_request( + derivative_order=derivative_order, + include_mixed=include_mixed, + ) + _, period_tuple = validate_and_normalize_coordinates( + field=field, + coordinates=coords_tuple, + periods=period_tuple, + coordinates_dtype=torch.float32, + requires_grad_error="coordinate gradients are not supported in warp backend", + ) + ctx.save_for_backward( + *[coord.to(dtype=torch.float32).contiguous() for coord in coords_tuple] + ) + ctx.period_tuple = period_tuple + ctx.derivative_order = derivative_order + ctx.orig_dtype = field.dtype + + +def _rectilinear_backward_common( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> torch.Tensor | None: + """Evaluate the rectilinear backward kernels for one custom-op invocation.""" + if grad_output is None or not ctx.needs_input_grad[0]: + return None + + coords_tuple = tuple(ctx.saved_tensors) + period_tuple = tuple(float(v) for v in ctx.period_tuple) + grad_output_fp32 = grad_output.to(dtype=torch.float32).contiguous() + derivative_order = int(ctx.derivative_order) + + ### CUDA 1D second-derivative VJP is routed through torch autograd for numerical stability. 
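+    ### The stencil is linear in the field, so the VJP is independent of the
+    ### point at which it is evaluated: differentiating the output of a
+    ### zero-valued probe field against ``grad_output`` recovers the exact
+    ### adjoint.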
+ if ( + derivative_order == 2 + and grad_output_fp32.device.type == "cuda" + and grad_output_fp32.shape[0] == 1 + ): + with torch.enable_grad(): + probe = torch.zeros_like(grad_output_fp32[0], requires_grad=True) + probe_out = rectilinear_grid_gradient_torch( + field=probe, + coordinates=coords_tuple, + periods=period_tuple, + derivative_order=2, + include_mixed=False, + ) + grad_field = torch.autograd.grad( + outputs=probe_out, + inputs=probe, + grad_outputs=grad_output_fp32, + create_graph=False, + retain_graph=False, + allow_unused=False, + )[0] + if grad_field.dtype != ctx.orig_dtype: + grad_field = grad_field.to(dtype=ctx.orig_dtype) + return grad_field + + grad_field = torch.empty_like(grad_output_fp32[0]) + wp_device, wp_stream = FunctionSpec.warp_launch_context(grad_output_fp32) + _launch_backward( + grad_output_fp32=grad_output_fp32, + coords_tuple=coords_tuple, + period_tuple=period_tuple, + derivative_order=derivative_order, + grad_field=grad_field, + wp_device=wp_device, + wp_stream=wp_stream, + ) + if grad_field.dtype != ctx.orig_dtype: + grad_field = grad_field.to(dtype=ctx.orig_dtype) + return grad_field + + +def _rectilinear_forward_fused_no_mixed_common( + field: torch.Tensor, + coords_tuple: tuple[torch.Tensor, ...], + period_tuple: tuple[float, ...], +) -> torch.Tensor: + """Run fused first+second forward kernels and return stacked outputs.""" + validate_field(field) + coords_tuple, period_tuple = validate_and_normalize_coordinates( + field=field, + coordinates=coords_tuple, + periods=period_tuple, + coordinates_dtype=torch.float32, + requires_grad_error="coordinate gradients are not supported in warp backend", + ) + + orig_dtype = field.dtype + field_fp32 = field.to(dtype=torch.float32).contiguous() + first_components = [torch.empty_like(field_fp32) for _ in range(field_fp32.ndim)] + second_components = [torch.empty_like(field_fp32) for _ in range(field_fp32.ndim)] + + wp_device, wp_stream = FunctionSpec.warp_launch_context(field_fp32) + _launch_forward_fused_no_mixed( + field_fp32=field_fp32, + coords_tuple=coords_tuple, + period_tuple=period_tuple, + first_components=first_components, + second_components=second_components, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + output = torch.stack([*first_components, *second_components], dim=0) + if output.dtype != orig_dtype: + output = output.to(dtype=orig_dtype) + return output + + +def _rectilinear_setup_fused_no_mixed_common( + ctx: torch.autograd.function.FunctionCtx, + field: torch.Tensor, + coords_tuple: tuple[torch.Tensor, ...], + period_tuple: tuple[float, ...], +) -> None: + """Store metadata for fused first+second rectilinear custom-op backward.""" + _, period_tuple = validate_and_normalize_coordinates( + field=field, + coordinates=coords_tuple, + periods=period_tuple, + coordinates_dtype=torch.float32, + requires_grad_error="coordinate gradients are not supported in warp backend", + ) + ctx.save_for_backward( + *[coord.to(dtype=torch.float32).contiguous() for coord in coords_tuple] + ) + ctx.period_tuple = period_tuple + ctx.orig_dtype = field.dtype + ctx.n_dims = field.ndim + + +def _rectilinear_backward_fused_no_mixed_common( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> torch.Tensor | None: + """Backward for fused first+second custom ops (field gradients only).""" + if grad_output is None or not ctx.needs_input_grad[0]: + return None + + coords_tuple = tuple(ctx.saved_tensors) + period_tuple = tuple(float(v) for v in ctx.period_tuple) + n_dims = int(ctx.n_dims) + 
grad_output_fp32 = grad_output.to(dtype=torch.float32).contiguous() + + ### CUDA 1D fused VJP reuses torch autograd for the 2nd-derivative contribution. + if n_dims == 1 and grad_output_fp32.device.type == "cuda": + with torch.enable_grad(): + probe = torch.zeros_like(grad_output_fp32[0], requires_grad=True) + probe_first = rectilinear_grid_gradient_torch( + field=probe, + coordinates=coords_tuple, + periods=period_tuple, + derivative_order=1, + include_mixed=False, + ) + probe_second = rectilinear_grid_gradient_torch( + field=probe, + coordinates=coords_tuple, + periods=period_tuple, + derivative_order=2, + include_mixed=False, + ) + probe_out = torch.cat((probe_first, probe_second), dim=0) + grad_field = torch.autograd.grad( + outputs=probe_out, + inputs=probe, + grad_outputs=grad_output_fp32, + create_graph=False, + retain_graph=False, + allow_unused=False, + )[0] + if grad_field.dtype != ctx.orig_dtype: + grad_field = grad_field.to(dtype=ctx.orig_dtype) + return grad_field + + grad_first_components = [grad_output_fp32[i] for i in range(n_dims)] + grad_second_components = [grad_output_fp32[n_dims + i] for i in range(n_dims)] + + grad_field = torch.empty_like(grad_output_fp32[0]) + wp_device, wp_stream = FunctionSpec.warp_launch_context(grad_output_fp32) + _launch_backward_fused_no_mixed( + grad_first_components=grad_first_components, + grad_second_components=grad_second_components, + coords_tuple=coords_tuple, + period_tuple=period_tuple, + grad_field=grad_field, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + if grad_field.dtype != ctx.orig_dtype: + grad_field = grad_field.to(dtype=ctx.orig_dtype) + return grad_field + + +@torch.library.custom_op( + "physicsnemo::rectilinear_grid_gradient_1d_warp_impl", mutates_args=() +) +def rectilinear_grid_gradient_1d_impl( + field: torch.Tensor, + coord0: torch.Tensor, + period0: float, + derivative_order: int, + include_mixed: bool, +) -> torch.Tensor: + """Compute periodic 1D first or pure second derivatives with Warp kernels.""" + return _rectilinear_forward_common( + field=field, + coords_tuple=(coord0,), + period_tuple=(float(period0),), + derivative_order=int(derivative_order), + include_mixed=bool(include_mixed), + ) + + +@rectilinear_grid_gradient_1d_impl.register_fake +def _rectilinear_grid_gradient_1d_impl_fake( + field: torch.Tensor, + coord0: torch.Tensor, + period0: float, + derivative_order: int, + include_mixed: bool, +) -> torch.Tensor: + """Fake tensor propagation for 1D rectilinear custom op.""" + _ = (coord0, period0, derivative_order, include_mixed) + return torch.empty((1, *field.shape), device=field.device, dtype=field.dtype) + + +def setup_rectilinear_grid_gradient_1d_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: torch.Tensor +) -> None: + """Store backward context for 1D rectilinear custom op.""" + field, coord0, period0, derivative_order, include_mixed = inputs + _ = output + _rectilinear_setup_common( + ctx=ctx, + field=field, + coords_tuple=(coord0,), + period_tuple=(float(period0),), + derivative_order=int(derivative_order), + include_mixed=bool(include_mixed), + ) + + +def backward_rectilinear_grid_gradient_1d( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None, None, None]: + """Backward pass for 1D rectilinear custom op (field gradients only).""" + grad_field = _rectilinear_backward_common(ctx, grad_output) + return grad_field, None, None, None, None + + +rectilinear_grid_gradient_1d_impl.register_autograd( + 
backward_rectilinear_grid_gradient_1d, + setup_context=setup_rectilinear_grid_gradient_1d_context, +) + + +@torch.library.custom_op( + "physicsnemo::rectilinear_grid_gradient_2d_warp_impl", mutates_args=() +) +def rectilinear_grid_gradient_2d_impl( + field: torch.Tensor, + coord0: torch.Tensor, + coord1: torch.Tensor, + period0: float, + period1: float, + derivative_order: int, + include_mixed: bool, +) -> torch.Tensor: + """Compute periodic 2D first or pure second derivatives with Warp kernels.""" + return _rectilinear_forward_common( + field=field, + coords_tuple=(coord0, coord1), + period_tuple=(float(period0), float(period1)), + derivative_order=int(derivative_order), + include_mixed=bool(include_mixed), + ) + + +@rectilinear_grid_gradient_2d_impl.register_fake +def _rectilinear_grid_gradient_2d_impl_fake( + field: torch.Tensor, + coord0: torch.Tensor, + coord1: torch.Tensor, + period0: float, + period1: float, + derivative_order: int, + include_mixed: bool, +) -> torch.Tensor: + """Fake tensor propagation for 2D rectilinear custom op.""" + _ = (coord0, coord1, period0, period1, derivative_order, include_mixed) + return torch.empty((2, *field.shape), device=field.device, dtype=field.dtype) + + +def setup_rectilinear_grid_gradient_2d_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: torch.Tensor +) -> None: + """Store backward context for 2D rectilinear custom op.""" + field, coord0, coord1, period0, period1, derivative_order, include_mixed = inputs + _ = output + _rectilinear_setup_common( + ctx=ctx, + field=field, + coords_tuple=(coord0, coord1), + period_tuple=(float(period0), float(period1)), + derivative_order=int(derivative_order), + include_mixed=bool(include_mixed), + ) + + +def backward_rectilinear_grid_gradient_2d( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None, None, None, None, None]: + """Backward pass for 2D rectilinear custom op (field gradients only).""" + grad_field = _rectilinear_backward_common(ctx, grad_output) + return grad_field, None, None, None, None, None, None + + +rectilinear_grid_gradient_2d_impl.register_autograd( + backward_rectilinear_grid_gradient_2d, + setup_context=setup_rectilinear_grid_gradient_2d_context, +) + + +@torch.library.custom_op( + "physicsnemo::rectilinear_grid_gradient_3d_warp_impl", mutates_args=() +) +def rectilinear_grid_gradient_3d_impl( + field: torch.Tensor, + coord0: torch.Tensor, + coord1: torch.Tensor, + coord2: torch.Tensor, + period0: float, + period1: float, + period2: float, + derivative_order: int, + include_mixed: bool, +) -> torch.Tensor: + """Compute periodic 3D first or pure second derivatives with Warp kernels.""" + return _rectilinear_forward_common( + field=field, + coords_tuple=(coord0, coord1, coord2), + period_tuple=(float(period0), float(period1), float(period2)), + derivative_order=int(derivative_order), + include_mixed=bool(include_mixed), + ) + + +@rectilinear_grid_gradient_3d_impl.register_fake +def _rectilinear_grid_gradient_3d_impl_fake( + field: torch.Tensor, + coord0: torch.Tensor, + coord1: torch.Tensor, + coord2: torch.Tensor, + period0: float, + period1: float, + period2: float, + derivative_order: int, + include_mixed: bool, +) -> torch.Tensor: + """Fake tensor propagation for 3D rectilinear custom op.""" + _ = ( + coord0, + coord1, + coord2, + period0, + period1, + period2, + derivative_order, + include_mixed, + ) + return torch.empty((3, *field.shape), device=field.device, dtype=field.dtype) + + +def 
setup_rectilinear_grid_gradient_3d_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: torch.Tensor +) -> None: + """Store backward context for 3D rectilinear custom op.""" + ( + field, + coord0, + coord1, + coord2, + period0, + period1, + period2, + derivative_order, + include_mixed, + ) = inputs + _ = output + _rectilinear_setup_common( + ctx=ctx, + field=field, + coords_tuple=(coord0, coord1, coord2), + period_tuple=(float(period0), float(period1), float(period2)), + derivative_order=int(derivative_order), + include_mixed=bool(include_mixed), + ) + + +def backward_rectilinear_grid_gradient_3d( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None, None, None, None, None, None, None]: + """Backward pass for 3D rectilinear custom op (field gradients only).""" + grad_field = _rectilinear_backward_common(ctx, grad_output) + return grad_field, None, None, None, None, None, None, None, None + + +rectilinear_grid_gradient_3d_impl.register_autograd( + backward_rectilinear_grid_gradient_3d, + setup_context=setup_rectilinear_grid_gradient_3d_context, +) + + +@torch.library.custom_op( + "physicsnemo::rectilinear_derivatives_1d_fused_no_mixed_warp_impl", + mutates_args=(), +) +def rectilinear_derivatives_1d_fused_no_mixed_impl( + field: torch.Tensor, + coord0: torch.Tensor, + period0: float, +) -> torch.Tensor: + """Compute fused 1D first+second derivatives with one Warp launch.""" + return _rectilinear_forward_fused_no_mixed_common( + field=field, + coords_tuple=(coord0,), + period_tuple=(float(period0),), + ) + + +@rectilinear_derivatives_1d_fused_no_mixed_impl.register_fake +def _rectilinear_derivatives_1d_fused_no_mixed_impl_fake( + field: torch.Tensor, + coord0: torch.Tensor, + period0: float, +) -> torch.Tensor: + """Fake tensor propagation for fused 1D rectilinear custom op.""" + _ = (coord0, period0) + return torch.empty((2, *field.shape), device=field.device, dtype=field.dtype) + + +def setup_rectilinear_derivatives_1d_fused_no_mixed_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: torch.Tensor +) -> None: + """Store backward context for fused 1D rectilinear custom op.""" + field, coord0, period0 = inputs + _ = output + _rectilinear_setup_fused_no_mixed_common( + ctx=ctx, + field=field, + coords_tuple=(coord0,), + period_tuple=(float(period0),), + ) + + +def backward_rectilinear_derivatives_1d_fused_no_mixed( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None]: + """Backward pass for fused 1D rectilinear custom op.""" + grad_field = _rectilinear_backward_fused_no_mixed_common(ctx, grad_output) + return grad_field, None, None + + +rectilinear_derivatives_1d_fused_no_mixed_impl.register_autograd( + backward_rectilinear_derivatives_1d_fused_no_mixed, + setup_context=setup_rectilinear_derivatives_1d_fused_no_mixed_context, +) + + +@torch.library.custom_op( + "physicsnemo::rectilinear_derivatives_2d_fused_no_mixed_warp_impl", + mutates_args=(), +) +def rectilinear_derivatives_2d_fused_no_mixed_impl( + field: torch.Tensor, + coord0: torch.Tensor, + coord1: torch.Tensor, + period0: float, + period1: float, +) -> torch.Tensor: + """Compute fused 2D first+second derivatives with one Warp launch.""" + return _rectilinear_forward_fused_no_mixed_common( + field=field, + coords_tuple=(coord0, coord1), + period_tuple=(float(period0), float(period1)), + ) + + +@rectilinear_derivatives_2d_fused_no_mixed_impl.register_fake +def 
_rectilinear_derivatives_2d_fused_no_mixed_impl_fake( + field: torch.Tensor, + coord0: torch.Tensor, + coord1: torch.Tensor, + period0: float, + period1: float, +) -> torch.Tensor: + """Fake tensor propagation for fused 2D rectilinear custom op.""" + _ = (coord0, coord1, period0, period1) + return torch.empty((4, *field.shape), device=field.device, dtype=field.dtype) + + +def setup_rectilinear_derivatives_2d_fused_no_mixed_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: torch.Tensor +) -> None: + """Store backward context for fused 2D rectilinear custom op.""" + field, coord0, coord1, period0, period1 = inputs + _ = output + _rectilinear_setup_fused_no_mixed_common( + ctx=ctx, + field=field, + coords_tuple=(coord0, coord1), + period_tuple=(float(period0), float(period1)), + ) + + +def backward_rectilinear_derivatives_2d_fused_no_mixed( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None, None, None]: + """Backward pass for fused 2D rectilinear custom op.""" + grad_field = _rectilinear_backward_fused_no_mixed_common(ctx, grad_output) + return grad_field, None, None, None, None + + +rectilinear_derivatives_2d_fused_no_mixed_impl.register_autograd( + backward_rectilinear_derivatives_2d_fused_no_mixed, + setup_context=setup_rectilinear_derivatives_2d_fused_no_mixed_context, +) + + +@torch.library.custom_op( + "physicsnemo::rectilinear_derivatives_3d_fused_no_mixed_warp_impl", + mutates_args=(), +) +def rectilinear_derivatives_3d_fused_no_mixed_impl( + field: torch.Tensor, + coord0: torch.Tensor, + coord1: torch.Tensor, + coord2: torch.Tensor, + period0: float, + period1: float, + period2: float, +) -> torch.Tensor: + """Compute fused 3D first+second derivatives with one Warp launch.""" + return _rectilinear_forward_fused_no_mixed_common( + field=field, + coords_tuple=(coord0, coord1, coord2), + period_tuple=(float(period0), float(period1), float(period2)), + ) + + +@rectilinear_derivatives_3d_fused_no_mixed_impl.register_fake +def _rectilinear_derivatives_3d_fused_no_mixed_impl_fake( + field: torch.Tensor, + coord0: torch.Tensor, + coord1: torch.Tensor, + coord2: torch.Tensor, + period0: float, + period1: float, + period2: float, +) -> torch.Tensor: + """Fake tensor propagation for fused 3D rectilinear custom op.""" + _ = (coord0, coord1, coord2, period0, period1, period2) + return torch.empty((6, *field.shape), device=field.device, dtype=field.dtype) + + +def setup_rectilinear_derivatives_3d_fused_no_mixed_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: torch.Tensor +) -> None: + """Store backward context for fused 3D rectilinear custom op.""" + field, coord0, coord1, coord2, period0, period1, period2 = inputs + _ = output + _rectilinear_setup_fused_no_mixed_common( + ctx=ctx, + field=field, + coords_tuple=(coord0, coord1, coord2), + period_tuple=(float(period0), float(period1), float(period2)), + ) + + +def backward_rectilinear_derivatives_3d_fused_no_mixed( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None, None, None, None, None]: + """Backward pass for fused 3D rectilinear custom op.""" + grad_field = _rectilinear_backward_fused_no_mixed_common(ctx, grad_output) + return grad_field, None, None, None, None, None, None + + +rectilinear_derivatives_3d_fused_no_mixed_impl.register_autograd( + backward_rectilinear_derivatives_3d_fused_no_mixed, + 
setup_context=setup_rectilinear_derivatives_3d_fused_no_mixed_context, +) + + +def rectilinear_grid_gradient_warp( + field: torch.Tensor, + coordinates: Sequence[torch.Tensor], + periods: float | Sequence[float] | None = None, + derivative_order: int = 1, + include_mixed: bool = False, +) -> torch.Tensor: + """Compute periodic first or pure second derivatives on rectilinear grids.""" + ### Validate field shape/dtype and normalize coordinates. + validate_field(field) + derivative_order = validate_derivative_request( + derivative_order=derivative_order, + include_mixed=include_mixed, + ) + + coords_tuple, period_tuple = validate_and_normalize_coordinates( + field=field, + coordinates=coordinates, + periods=periods, + coordinates_dtype=torch.float32, + requires_grad_error="coordinate gradients are not supported in warp backend", + ) + + if field.ndim == 1: + return rectilinear_grid_gradient_1d_impl( + field, + coords_tuple[0], + float(period_tuple[0]), + int(derivative_order), + bool(include_mixed), + ) + if field.ndim == 2: + return rectilinear_grid_gradient_2d_impl( + field, + coords_tuple[0], + coords_tuple[1], + float(period_tuple[0]), + float(period_tuple[1]), + int(derivative_order), + bool(include_mixed), + ) + return rectilinear_grid_gradient_3d_impl( + field, + coords_tuple[0], + coords_tuple[1], + coords_tuple[2], + float(period_tuple[0]), + float(period_tuple[1]), + float(period_tuple[2]), + int(derivative_order), + bool(include_mixed), + ) + + +def rectilinear_grid_gradient_warp_multi( + field: torch.Tensor, + coordinates: Sequence[torch.Tensor], + periods: float | Sequence[float] | None, + derivative_orders: tuple[int, ...], + include_mixed: bool, +) -> torch.Tensor: + """Compute multiple derivative families, fusing first+second when possible. + + For ``derivative_orders=(1, 2)`` with ``include_mixed=False``, this uses a + fused custom-op path (single fused forward launch + fused backward kernels). + Mixed requests are composed from single-order custom ops to preserve output + ordering and autograd behavior. + """ + validate_field(field) + coords_tuple, period_tuple = validate_and_normalize_coordinates( + field=field, + coordinates=coordinates, + periods=periods, + coordinates_dtype=torch.float32, + requires_grad_error="coordinate gradients are not supported in warp backend", + ) + + if include_mixed and 2 not in derivative_orders: + raise ValueError("include_mixed requires requesting 2nd derivatives") + if include_mixed and field.ndim < 2: + raise ValueError("mixed derivatives require at least 2D inputs") + + ### Mixed requests are composed from single-order custom ops. 
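+    ### Mixed second derivatives are obtained by re-differentiating first
+    ### derivatives: d2f/(dx_i dx_j) is approximated by applying the axis-j
+    ### first-derivative stencil to D_i f, with terms appended in axis-pair
+    ### order (i < j).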
+ if include_mixed: + outputs: list[torch.Tensor] = [] + first_terms = rectilinear_grid_gradient_warp( + field=field, + coordinates=coords_tuple, + periods=period_tuple, + derivative_order=1, + include_mixed=False, + ) + + if 1 in derivative_orders: + outputs.extend(first_terms.unbind(0)) + + if 2 in derivative_orders: + pure_second_terms = rectilinear_grid_gradient_warp( + field=field, + coordinates=coords_tuple, + periods=period_tuple, + derivative_order=2, + include_mixed=False, + ) + outputs.extend(pure_second_terms.unbind(0)) + + for axis_i in range(field.ndim): + for axis_j in range(axis_i + 1, field.ndim): + mixed_ij = rectilinear_grid_gradient_warp( + field=first_terms[axis_i], + coordinates=coords_tuple, + periods=period_tuple, + derivative_order=1, + include_mixed=False, + )[axis_j] + outputs.append(mixed_ij) + return torch.stack(outputs, dim=0) + + ### Single-order requests should use the direct single-order custom ops. + if len(derivative_orders) == 1: + return rectilinear_grid_gradient_warp( + field=field, + coordinates=coords_tuple, + periods=period_tuple, + derivative_order=int(derivative_orders[0]), + include_mixed=False, + ) + + ### Fused no-mixed path with custom-op backward for combined first+second. + if 1 in derivative_orders and 2 in derivative_orders: + if field.ndim == 1: + fused = rectilinear_derivatives_1d_fused_no_mixed_impl( + field, + coords_tuple[0], + float(period_tuple[0]), + ) + elif field.ndim == 2: + fused = rectilinear_derivatives_2d_fused_no_mixed_impl( + field, + coords_tuple[0], + coords_tuple[1], + float(period_tuple[0]), + float(period_tuple[1]), + ) + else: + fused = rectilinear_derivatives_3d_fused_no_mixed_impl( + field, + coords_tuple[0], + coords_tuple[1], + coords_tuple[2], + float(period_tuple[0]), + float(period_tuple[1]), + float(period_tuple[2]), + ) + + outputs: list[torch.Tensor] = [] + n_dims = field.ndim + if 1 in derivative_orders: + outputs.extend(fused[:n_dims].unbind(0)) + if 2 in derivative_orders: + outputs.extend(fused[n_dims:].unbind(0)) + return torch.stack(outputs, dim=0) + + raise RuntimeError( + "Unhandled derivative request in rectilinear_grid_gradient_warp_multi" + ) diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/rectilinear_grid_gradient.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/rectilinear_grid_gradient.py new file mode 100644 index 0000000000..a6800c8ffb --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/rectilinear_grid_gradient.py @@ -0,0 +1,394 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
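+### Illustrative sketch (names and sizes chosen for illustration only): first
+### derivatives of a periodic field on a nonuniform 1D grid with the eager
+### PyTorch backend and a unit period:
+###
+###     s = torch.linspace(0.0, 1.0, 65)[:-1]
+###     x = s + 0.03 * torch.sin(2.0 * torch.pi * s)  # strictly increasing in [0, 1)
+###     f = torch.sin(2.0 * torch.pi * x)
+###     grads = rectilinear_grid_gradient_torch(f, (x,), periods=1.0)
+###     # grads has shape (1, 64) and approximates 2*pi*cos(2*pi*x)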
+ +from __future__ import annotations + +from collections.abc import Sequence +from itertools import combinations + +import torch + +from physicsnemo.core.function_spec import FunctionSpec + +from .._request_utils import ( + normalize_derivative_orders, + normalize_include_mixed, + validate_mixed_request, +) +from ._torch_impl import rectilinear_grid_gradient_torch +from ._warp_impl import ( + rectilinear_grid_gradient_warp, + rectilinear_grid_gradient_warp_multi, +) + + +class RectilinearGridGradient(FunctionSpec): + r"""Compute periodic gradients on rectilinear grids with nonuniform spacing. + + This functional computes first-order and/or second-order + derivatives of a scalar field on a 1D/2D/3D rectilinear grid where each + axis has independent, potentially nonuniform coordinate spacing. + + For each axis :math:`k`, first-order nonuniform central differencing is: + + .. math:: + + \partial_k f_i \approx + a_i\,f_{i-1} + b_i\,f_i + c_i\,f_{i+1} + + with + + .. math:: + + a_i = -\frac{h_i^+}{h_i^-(h_i^-+h_i^+)}, \quad + b_i = \frac{h_i^+ - h_i^-}{h_i^- h_i^+}, \quad + c_i = \frac{h_i^-}{h_i^+(h_i^-+h_i^+)} + + and pure second derivatives are: + + .. math:: + + \partial_{kk} f_i \approx + \tilde{a}_i\,f_{i-1} + \tilde{b}_i\,f_i + \tilde{c}_i\,f_{i+1} + + with + + .. math:: + + \tilde{a}_i = \frac{2}{h_i^-(h_i^-+h_i^+)}, \quad + \tilde{b}_i = -\frac{2}{h_i^- h_i^+}, \quad + \tilde{c}_i = \frac{2}{h_i^+(h_i^-+h_i^+)} + + where :math:`h_i^-` and :math:`h_i^+` are left/right periodic distances + along that axis. + + Parameters + ---------- + field : torch.Tensor + Scalar grid field with shape ``(n0,)``, ``(n0,n1)``, or ``(n0,n1,n2)``. + coordinates : Sequence[torch.Tensor] + Per-axis coordinate tensors ``(x0, x1, x2)`` matching field dimensions. + Each axis tensor must be rank-1, strictly increasing, and length + compatible with ``field.shape[axis]``. + periods : float | Sequence[float] | None, optional + Period length per axis. If ``None``, each axis is inferred as + ``coords[-1] - coords[0] + (coords[1] - coords[0])``. + derivative_orders : int | Sequence[int], optional + Derivative orders to compute. Supported values are ``1``, ``2``, or + ``(1, 2)``. + include_mixed : bool, optional + Include mixed second derivatives when requesting second derivatives. + Mixed terms are appended in axis-pair order ``(x,y)``, ``(x,z)``, + ``(y,z)``. + implementation : {"warp", "torch"} or None + Explicit backend selection. When ``None``, dispatch selects by rank. + + Returns + ------- + torch.Tensor + Gradient tensor of shape ``(num_derivatives, *field.shape)``. + """ + + ### Benchmark input presets (small -> large workload). 
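+    # Each entry is (label, field shape, derivative order); make_inputs_forward
+    # pairs these with unit periods and include_mixed=False.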
+ _BENCHMARK_CASES = ( + ("1d-n8192-d1", (8192,), 1), + ("1d-n512-d2", (512,), 2), + ("2d-384x384-d1", (384, 384), 1), + ("2d-256x256-d2", (256, 256), 2), + ("3d-96x96x96-d1", (96, 96, 96), 1), + ("3d-64x64x64-d2", (64, 64, 64), 2), + ) + + _COMPARE_ATOL = 5e-2 + _COMPARE_RTOL = 5e-2 + _COMPARE_BACKWARD_ATOL = 5e-2 + _COMPARE_BACKWARD_RTOL = 5e-2 + + @FunctionSpec.register(name="warp", required_imports=("warp>=0.6.0",), rank=0) + def warp_forward( + field: torch.Tensor, + coordinates: Sequence[torch.Tensor], + periods: float | Sequence[float] | None = None, + derivative_orders: int | Sequence[int] = 1, + include_mixed: bool = False, + ) -> torch.Tensor: + """Dispatch rectilinear gradients to the Warp backend.""" + return _dispatch_rectilinear_grid_requests( + backend_fn=rectilinear_grid_gradient_warp, + field=field, + coordinates=coordinates, + periods=periods, + derivative_orders=derivative_orders, + include_mixed=include_mixed, + ) + + @FunctionSpec.register(name="torch", rank=1, baseline=True) + def torch_forward( + field: torch.Tensor, + coordinates: Sequence[torch.Tensor], + periods: float | Sequence[float] | None = None, + derivative_orders: int | Sequence[int] = 1, + include_mixed: bool = False, + ) -> torch.Tensor: + """Dispatch rectilinear gradients to eager PyTorch.""" + return _dispatch_rectilinear_grid_requests( + backend_fn=rectilinear_grid_gradient_torch, + field=field, + coordinates=coordinates, + periods=periods, + derivative_orders=derivative_orders, + include_mixed=include_mixed, + ) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield representative forward benchmark and parity input cases.""" + device = torch.device(device) + + ### Build periodic nonuniform rectilinear coordinates and analytic fields. + for label, shape, derivative_order in cls._BENCHMARK_CASES: + if len(shape) == 1: + n0 = shape[0] + s0 = torch.linspace(0.0, 1.0, n0 + 1, device=device)[:-1] + x0 = s0 + 0.04 * torch.sin(2.0 * torch.pi * s0) + field = torch.sin(2.0 * torch.pi * x0) + coordinates = (x0.to(torch.float32),) + periods = 1.0 + elif len(shape) == 2: + n0, n1 = shape + s0 = torch.linspace(0.0, 1.0, n0 + 1, device=device)[:-1] + s1 = torch.linspace(0.0, 1.0, n1 + 1, device=device)[:-1] + x0 = s0 + 0.04 * torch.sin(2.0 * torch.pi * s0) + x1 = s1 + 0.03 * torch.sin(2.0 * torch.pi * s1) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + field = torch.sin(2.0 * torch.pi * xx) + 0.5 * torch.cos( + 2.0 * torch.pi * yy + ) + coordinates = (x0.to(torch.float32), x1.to(torch.float32)) + periods = (1.0, 1.0) + else: + n0, n1, n2 = shape + s0 = torch.linspace(0.0, 1.0, n0 + 1, device=device)[:-1] + s1 = torch.linspace(0.0, 1.0, n1 + 1, device=device)[:-1] + s2 = torch.linspace(0.0, 1.0, n2 + 1, device=device)[:-1] + x0 = s0 + 0.04 * torch.sin(2.0 * torch.pi * s0) + x1 = s1 + 0.03 * torch.sin(2.0 * torch.pi * s1) + x2 = s2 + 0.02 * torch.sin(2.0 * torch.pi * s2) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + field = ( + torch.sin(2.0 * torch.pi * xx) + + 0.5 * torch.cos(2.0 * torch.pi * yy) + + 0.25 * torch.sin(2.0 * torch.pi * zz) + ) + coordinates = ( + x0.to(torch.float32), + x1.to(torch.float32), + x2.to(torch.float32), + ) + periods = (1.0, 1.0, 1.0) + + ### Yield each labeled benchmark/parity input. 
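+            # Cases are yielded as (label, positional args, keyword args), the
+            # layout expected by the FunctionSpec benchmarking hooks.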
+ yield ( + label, + (field.to(torch.float32), coordinates), + { + "periods": periods, + "derivative_orders": derivative_order, + "include_mixed": False, + }, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield representative backward benchmark and parity input cases.""" + device = torch.device(device) + + ### Build differentiable field inputs for backward parity checks. + backward_cases = ( + ("1d-grad-n4096-d1", (4096,), 1), + ("1d-grad-n512-d2", (512,), 2), + ("2d-grad-256x256-d1", (256, 256), 1), + ("2d-grad-192x192-d2", (192, 192), 2), + ("3d-grad-80x80x80-d1", (80, 80, 80), 1), + ("3d-grad-56x56x56-d2", (56, 56, 56), 2), + ) + + for label, shape, derivative_order in backward_cases: + if len(shape) == 1: + n0 = shape[0] + s0 = torch.linspace(0.0, 1.0, n0 + 1, device=device)[:-1] + x0 = s0 + 0.04 * torch.sin(2.0 * torch.pi * s0) + field = torch.sin(2.0 * torch.pi * x0) + coordinates = (x0.to(torch.float32),) + periods = 1.0 + elif len(shape) == 2: + n0, n1 = shape + s0 = torch.linspace(0.0, 1.0, n0 + 1, device=device)[:-1] + s1 = torch.linspace(0.0, 1.0, n1 + 1, device=device)[:-1] + x0 = s0 + 0.04 * torch.sin(2.0 * torch.pi * s0) + x1 = s1 + 0.03 * torch.sin(2.0 * torch.pi * s1) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + field = torch.sin(2.0 * torch.pi * xx) + 0.5 * torch.cos( + 2.0 * torch.pi * yy + ) + coordinates = (x0.to(torch.float32), x1.to(torch.float32)) + periods = (1.0, 1.0) + else: + n0, n1, n2 = shape + s0 = torch.linspace(0.0, 1.0, n0 + 1, device=device)[:-1] + s1 = torch.linspace(0.0, 1.0, n1 + 1, device=device)[:-1] + s2 = torch.linspace(0.0, 1.0, n2 + 1, device=device)[:-1] + x0 = s0 + 0.04 * torch.sin(2.0 * torch.pi * s0) + x1 = s1 + 0.03 * torch.sin(2.0 * torch.pi * s1) + x2 = s2 + 0.02 * torch.sin(2.0 * torch.pi * s2) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + field = ( + torch.sin(2.0 * torch.pi * xx) + + 0.5 * torch.cos(2.0 * torch.pi * yy) + + 0.25 * torch.sin(2.0 * torch.pi * zz) + ) + coordinates = ( + x0.to(torch.float32), + x1.to(torch.float32), + x2.to(torch.float32), + ) + periods = (1.0, 1.0, 1.0) + + yield ( + label, + ( + field.to(torch.float32).detach().clone().requires_grad_(True), + coordinates, + ), + { + "periods": periods, + "derivative_orders": derivative_order, + "include_mixed": False, + }, + ) + + @classmethod + def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare forward outputs across implementations.""" + ### Validate forward parity across backends. + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_ATOL, + rtol=cls._COMPARE_RTOL, + ) + + @classmethod + def compare_backward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare backward gradients across implementations.""" + ### Validate backward parity across backends. 
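+        # Gradient parity uses the same tolerances as the forward comparison.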
+ torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_BACKWARD_ATOL, + rtol=cls._COMPARE_BACKWARD_RTOL, + ) + + +rectilinear_grid_gradient = RectilinearGridGradient.make_function( + "rectilinear_grid_gradient" +) + + +__all__ = ["RectilinearGridGradient", "rectilinear_grid_gradient"] + + +def _dispatch_rectilinear_grid_requests( + *, + backend_fn, + field: torch.Tensor, + coordinates: Sequence[torch.Tensor], + periods: float | Sequence[float] | None, + derivative_orders: int | Sequence[int], + include_mixed: bool, +) -> torch.Tensor: + """Resolve unified derivative requests and dispatch to backend kernels.""" + requested_orders = normalize_derivative_orders( + derivative_orders=derivative_orders, + function_name="rectilinear_grid_gradient", + ) + mixed_terms = normalize_include_mixed( + include_mixed=include_mixed, + function_name="rectilinear_grid_gradient", + ) + validate_mixed_request( + derivative_orders=requested_orders, + include_mixed=mixed_terms, + ndim=field.ndim, + function_name="rectilinear_grid_gradient", + ) + + if backend_fn is rectilinear_grid_gradient_warp and ( + len(requested_orders) > 1 or mixed_terms + ): + return rectilinear_grid_gradient_warp_multi( + field=field, + coordinates=coordinates, + periods=periods, + derivative_orders=requested_orders, + include_mixed=mixed_terms, + ) + + outputs: list[torch.Tensor] = [] + first_terms: torch.Tensor | None = None + + if 1 in requested_orders: + first_terms = backend_fn( + field=field, + coordinates=coordinates, + periods=periods, + derivative_order=1, + include_mixed=False, + ) + outputs.extend(first_terms.unbind(0)) + + if 2 in requested_orders: + pure_second_terms = backend_fn( + field=field, + coordinates=coordinates, + periods=periods, + derivative_order=2, + include_mixed=False, + ) + outputs.extend(pure_second_terms.unbind(0)) + + if mixed_terms: + if first_terms is None: + first_terms = backend_fn( + field=field, + coordinates=coordinates, + periods=periods, + derivative_order=1, + include_mixed=False, + ) + + for axis_i, axis_j in combinations(range(field.ndim), 2): + axis_i_first = first_terms[axis_i] + mixed_ij = backend_fn( + field=axis_i_first, + coordinates=coordinates, + periods=periods, + derivative_order=1, + include_mixed=False, + )[axis_j] + outputs.append(mixed_ij) + + return torch.stack(outputs, dim=0) diff --git a/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/utils.py b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/utils.py new file mode 100644 index 0000000000..ae95492376 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/rectilinear_grid_gradient/utils.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
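+
+# Shared validation and periodic nonuniform stencil-weight helpers for the
+# rectilinear grid gradient backends.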
+ +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +_SUPPORTED_DERIVATIVE_ORDERS = (1, 2) + + +def validate_field(field: torch.Tensor) -> None: + """Validate shared rectilinear field input constraints.""" + if field.ndim < 1 or field.ndim > 3: + raise ValueError( + f"rectilinear_grid_gradient supports 1D-3D fields, got {field.shape=}" + ) + if not torch.is_floating_point(field): + raise TypeError("field must be a floating-point tensor") + + +def normalize_periods( + periods: float | Sequence[float] | None, + coordinates: tuple[torch.Tensor, ...], +) -> tuple[float, ...]: + """Normalize explicit/inferred periodic lengths to one value per axis.""" + ndim = len(coordinates) + if periods is None: + inferred: list[float] = [] + for coords in coordinates: + if coords.numel() < 2: + raise ValueError( + "rectilinear_grid_gradient requires at least two coordinates per axis" + ) + inferred.append( + float((coords[-1] - coords[0] + (coords[1] - coords[0])).item()) + ) + return tuple(inferred) + if isinstance(periods, (float, int)): + return tuple(float(periods) for _ in range(ndim)) + periods_tuple = tuple(float(v) for v in periods) + if len(periods_tuple) != ndim: + raise ValueError( + f"periods must have {ndim} entries for a {ndim}D field, got {len(periods_tuple)}" + ) + return periods_tuple + + +def validate_and_normalize_coordinates( + field: torch.Tensor, + coordinates: Sequence[torch.Tensor], + periods: float | Sequence[float] | None, + *, + coordinates_dtype: torch.dtype, + requires_grad_error: str, +) -> tuple[tuple[torch.Tensor, ...], tuple[float, ...]]: + """Validate rectilinear coordinates and return normalized coordinates/periods.""" + if len(coordinates) != field.ndim: + raise ValueError( + f"coordinates must contain one axis tensor per field dimension ({field.ndim}), " + f"got {len(coordinates)}" + ) + + normalized_coords: list[torch.Tensor] = [] + for axis, coords in enumerate(coordinates): + if not isinstance(coords, torch.Tensor): + raise TypeError(f"coordinates[{axis}] must be a tensor") + if coords.ndim != 1: + raise ValueError( + f"coordinates[{axis}] must be rank-1, got shape={tuple(coords.shape)}" + ) + if coords.shape[0] != field.shape[axis]: + raise ValueError( + f"coordinates[{axis}] length must equal field.shape[{axis}] " + f"({field.shape[axis]}), got {coords.shape[0]}" + ) + if coords.requires_grad: + raise ValueError(requires_grad_error) + if not torch.is_floating_point(coords): + raise TypeError(f"coordinates[{axis}] must be floating-point") + if coords.device != field.device: + raise ValueError("field and coordinates must be on the same device") + if coords.numel() < 3: + raise ValueError( + "each coordinate axis must contain at least 3 points for central differencing" + ) + + coords_norm = coords.to(dtype=coordinates_dtype).contiguous() + diffs = coords_norm[1:] - coords_norm[:-1] + if torch.any(diffs <= 0): + raise ValueError(f"coordinates[{axis}] must be strictly increasing") + normalized_coords.append(coords_norm) + + period_tuple = normalize_periods( + periods=periods, coordinates=tuple(normalized_coords) + ) + for axis, period in enumerate(period_tuple): + if period <= 0.0: + raise ValueError("all periodic lengths must be strictly positive") + min_period = float( + (normalized_coords[axis][-1] - normalized_coords[axis][0]).item() + ) + if period <= min_period: + raise ValueError( + f"periods[{axis}] must be larger than coordinate span ({min_period}), got {period}" + ) + + return tuple(normalized_coords), period_tuple + 
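+# Worked example for the inference rule in ``normalize_periods`` above: with
+# ``periods=None`` and a uniformly spaced axis ``coords = [0.0, 0.25, 0.5, 0.75]``,
+# the period is inferred as coords[-1] - coords[0] + (coords[1] - coords[0])
+# = 0.75 + 0.25 = 1.0, i.e. the axis wraps back to the first point one spacing
+# after the last point.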
+ +def validate_derivative_request( + *, + derivative_order: int, + include_mixed: bool, +) -> int: + """Validate derivative-order/mixed-term request for phase-1 behavior.""" + if not isinstance(derivative_order, int): + raise TypeError( + f"derivative_order must be an integer, got {type(derivative_order)}" + ) + if derivative_order not in _SUPPORTED_DERIVATIVE_ORDERS: + raise ValueError( + "rectilinear_grid_gradient supports derivative_order in [1, 2], " + f"got derivative_order={derivative_order}" + ) + if not isinstance(include_mixed, bool): + raise TypeError(f"include_mixed must be a bool, got {type(include_mixed)}") + if include_mixed and derivative_order != 2: + raise ValueError("include_mixed is only valid when derivative_order=2") + if include_mixed: + raise NotImplementedError( + "include_mixed=True is not yet supported; phase-1 supports pure axis-wise " + "second derivatives only" + ) + return derivative_order + + +def axis_central_weights( + coords: torch.Tensor, + period: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build periodic nonuniform second-order central-difference weights.""" + diffs = coords[1:] - coords[:-1] + + h_plus = torch.empty_like(coords) + h_plus[:-1] = diffs + h_plus[-1] = period - (coords[-1] - coords[0]) + + h_minus = torch.empty_like(coords) + h_minus[1:] = diffs + h_minus[0] = h_plus[-1] + + if torch.any(h_minus <= 0.0) or torch.any(h_plus <= 0.0): + raise ValueError( + "rectilinear coordinates/period produce non-positive periodic spacing" + ) + + denom = h_minus + h_plus + w_minus = -h_plus / (h_minus * denom) + w_center = (h_plus - h_minus) / (h_minus * h_plus) + w_plus = h_minus / (h_plus * denom) + return w_minus, w_center, w_plus + + +def axis_second_derivative_weights( + coords: torch.Tensor, + period: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build periodic nonuniform second-derivative weights.""" + diffs = coords[1:] - coords[:-1] + + h_plus = torch.empty_like(coords) + h_plus[:-1] = diffs + h_plus[-1] = period - (coords[-1] - coords[0]) + + h_minus = torch.empty_like(coords) + h_minus[1:] = diffs + h_minus[0] = h_plus[-1] + + if torch.any(h_minus <= 0.0) or torch.any(h_plus <= 0.0): + raise ValueError( + "rectilinear coordinates/period produce non-positive periodic spacing" + ) + + denom = h_minus + h_plus + w_minus = 2.0 / (h_minus * denom) + w_center = -2.0 / (h_minus * h_plus) + w_plus = 2.0 / (h_plus * denom) + return w_minus, w_center, w_plus diff --git a/physicsnemo/nn/functional/derivatives/spectral_grid_gradient/__init__.py b/physicsnemo/nn/functional/derivatives/spectral_grid_gradient/__init__.py new file mode 100644 index 0000000000..b8dcf9c0a7 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/spectral_grid_gradient/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .spectral_grid_gradient import SpectralGridGradient, spectral_grid_gradient + +__all__ = ["SpectralGridGradient", "spectral_grid_gradient"] diff --git a/physicsnemo/nn/functional/derivatives/spectral_grid_gradient/_torch_impl.py b/physicsnemo/nn/functional/derivatives/spectral_grid_gradient/_torch_impl.py new file mode 100644 index 0000000000..e4e1e01aaa --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/spectral_grid_gradient/_torch_impl.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from itertools import combinations + +import torch + + +def _normalize_lengths( + lengths: float | Sequence[float], ndim: int +) -> tuple[float, ...]: + """Normalize periodic lengths into one positive entry per axis.""" + if isinstance(lengths, (float, int)): + lengths_tuple = tuple(float(lengths) for _ in range(ndim)) + else: + lengths_tuple = tuple(float(v) for v in lengths) + if len(lengths_tuple) != ndim: + raise ValueError( + f"lengths must have {ndim} entries for a {ndim}D field, got {len(lengths_tuple)}" + ) + + for axis, length in enumerate(lengths_tuple): + if length <= 0.0: + raise ValueError(f"lengths[{axis}] must be strictly positive") + return lengths_tuple + + +def _validate_inputs( + field: torch.Tensor, + lengths: float | Sequence[float], + order: int, + return_mixed_derivs: bool, +) -> tuple[tuple[float, ...], torch.Tensor]: + """Validate spectral-gradient inputs and return normalized parameters.""" + if field.ndim < 1 or field.ndim > 3: + raise ValueError( + f"spectral_grid_gradient supports 1D-3D fields, got field.shape={tuple(field.shape)}" + ) + if not torch.is_floating_point(field): + raise TypeError("field must be a floating-point tensor") + if order not in (1, 2): + raise ValueError(f"order must be 1 or 2, got {order}") + if return_mixed_derivs and order != 2: + raise ValueError("return_mixed_derivs=True requires order=2") + if return_mixed_derivs and field.ndim == 1: + raise ValueError("mixed derivatives require at least 2D fields") + + lengths_tuple = _normalize_lengths(lengths=lengths, ndim=field.ndim) + + if field.dtype in (torch.float16, torch.bfloat16): + field_eval = field.to(torch.float32) + else: + field_eval = field + return lengths_tuple, field_eval + + +def _wavenumbers( + shape: Sequence[int], + lengths: Sequence[float], + *, + device: torch.device, + dtype: torch.dtype, +) -> list[torch.Tensor]: + """Build broadcastable angular wavenumber tensors for each axis.""" + ks: list[torch.Tensor] = [] + for axis, (n_axis, length_axis) in enumerate(zip(shape, lengths)): + freq_axis = torch.fft.fftfreq( + n_axis, + d=length_axis / float(n_axis), + device=device, + dtype=dtype, + ) + k_axis = 2.0 * torch.pi * freq_axis + view_shape = [1] * len(shape) + view_shape[axis] = n_axis + 
ks.append(k_axis.reshape(view_shape)) + return ks + + +def spectral_grid_gradient_torch( + field: torch.Tensor, + lengths: float | Sequence[float] = 1.0, + order: int = 1, + return_mixed_derivs: bool = False, +) -> torch.Tensor: + """Compute periodic grid derivatives with spectral differentiation. + + This implementation assumes periodic boundaries along each axis and computes + derivatives in Fourier space, then transforms back with inverse FFT. + """ + lengths_tuple, field_eval = _validate_inputs( + field=field, + lengths=lengths, + order=order, + return_mixed_derivs=return_mixed_derivs, + ) + ndim = field_eval.ndim + + u_hat = torch.fft.fftn(field_eval, dim=tuple(range(ndim))) + k_axes = _wavenumbers( + field_eval.shape, + lengths_tuple, + device=field_eval.device, + dtype=field_eval.dtype, + ) + + derivatives: list[torch.Tensor] = [] + if order == 1: + for axis in range(ndim): + deriv_hat = (1j * k_axes[axis]) * u_hat + derivatives.append(torch.fft.ifftn(deriv_hat, dim=tuple(range(ndim))).real) + else: + for axis in range(ndim): + deriv_hat = -(k_axes[axis] * k_axes[axis]) * u_hat + derivatives.append(torch.fft.ifftn(deriv_hat, dim=tuple(range(ndim))).real) + + if return_mixed_derivs: + for axis_i, axis_j in combinations(range(ndim), 2): + deriv_hat = -(k_axes[axis_i] * k_axes[axis_j]) * u_hat + derivatives.append( + torch.fft.ifftn(deriv_hat, dim=tuple(range(ndim))).real + ) + + output = torch.stack(derivatives, dim=0) + if output.dtype != field.dtype: + return output.to(dtype=field.dtype) + return output diff --git a/physicsnemo/nn/functional/derivatives/spectral_grid_gradient/spectral_grid_gradient.py b/physicsnemo/nn/functional/derivatives/spectral_grid_gradient/spectral_grid_gradient.py new file mode 100644 index 0000000000..3efa75aeb1 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/spectral_grid_gradient/spectral_grid_gradient.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +from physicsnemo.core.function_spec import FunctionSpec + +from .._request_utils import ( + normalize_derivative_orders, + normalize_include_mixed, + validate_mixed_request, +) +from ._torch_impl import spectral_grid_gradient_torch + + +class SpectralGridGradient(FunctionSpec): + r"""Compute periodic derivatives with Fourier spectral differentiation. + + This functional computes first-order and/or second-order derivatives on + 1D/2D/3D + periodic scalar fields by transforming to Fourier space, applying exact + derivative multipliers, and transforming back. + + Parameters + ---------- + field : torch.Tensor + Scalar field on a periodic uniform grid with shape + ``(n0,)``, ``(n0, n1)``, or ``(n0, n1, n2)``. + lengths : float | Sequence[float], optional + Physical domain lengths per axis. 
A scalar applies the same length to + every axis. + derivative_orders : int | Sequence[int], optional + Derivative orders to compute. Supported values are ``1``, ``2``, or + ``(1, 2)``. + include_mixed : bool, optional + Include mixed second derivatives when requesting second derivatives. + implementation : {"torch"} or None + Implementation to use. When ``None``, dispatch selects the available + implementation. + + Returns + ------- + torch.Tensor + Stacked derivative tensor with shape ``(num_derivatives, *field.shape)``. + Derivative ordering is deterministic: + first derivatives, then pure second derivatives, then mixed second + derivatives in axis-pair order ``(x,y), (x,z), (y,z)``. + """ + + _BENCHMARK_CASES = ( + ("1d-n4096-o1", (4096,), 2.0, 1, False), + ("2d-512x512-o1", (512, 512), (2.0, 1.5), 1, False), + ("2d-512x512-o2-mixed", (512, 512), (2.0, 1.5), 2, True), + ("3d-128x128x128-o2", (128, 128, 128), (2.0, 1.5, 1.25), 2, False), + ) + + @FunctionSpec.register(name="torch", rank=0, baseline=True) + def torch_forward( + field: torch.Tensor, + lengths: float | Sequence[float] = 1.0, + derivative_orders: int | Sequence[int] = 1, + include_mixed: bool = False, + ) -> torch.Tensor: + """Dispatch spectral derivatives to the PyTorch backend.""" + requested_orders = normalize_derivative_orders( + derivative_orders=derivative_orders, + function_name="spectral_grid_gradient", + ) + mixed_terms = normalize_include_mixed( + include_mixed=include_mixed, + function_name="spectral_grid_gradient", + ) + validate_mixed_request( + derivative_orders=requested_orders, + include_mixed=mixed_terms, + ndim=field.ndim, + function_name="spectral_grid_gradient", + ) + + outputs: list[torch.Tensor] = [] + if 1 in requested_orders: + outputs.append( + spectral_grid_gradient_torch( + field=field, + lengths=lengths, + order=1, + return_mixed_derivs=False, + ) + ) + if 2 in requested_orders: + outputs.append( + spectral_grid_gradient_torch( + field=field, + lengths=lengths, + order=2, + return_mixed_derivs=mixed_terms, + ) + ) + + return torch.cat(outputs, dim=0) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield representative forward benchmark and parity input cases.""" + device = torch.device(device) + for label, shape, lengths, order, return_mixed_derivs in cls._BENCHMARK_CASES: + field = cls._make_periodic_field( + shape=shape, + lengths=lengths, + device=device, + ) + yield ( + label, + (field,), + { + "lengths": lengths, + "derivative_orders": order, + "include_mixed": return_mixed_derivs, + }, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield representative backward benchmark and parity input cases.""" + device = torch.device(device) + backward_cases = ( + ("1d-grad-n1024-o1", (1024,), 2.0, 1, False), + ("2d-grad-256x256-o2", (256, 256), (2.0, 1.5), 2, False), + ("3d-grad-64x64x64-o2-mixed", (64, 64, 64), (2.0, 1.5, 1.25), 2, True), + ) + for label, shape, lengths, order, return_mixed_derivs in backward_cases: + field = ( + cls._make_periodic_field( + shape=shape, + lengths=lengths, + device=device, + ) + .detach() + .clone() + .requires_grad_(True) + ) + yield ( + label, + (field,), + { + "lengths": lengths, + "derivative_orders": order, + "include_mixed": return_mixed_derivs, + }, + ) + + @staticmethod + def _make_periodic_field( + shape: tuple[int, ...], + lengths: float | Sequence[float], + device: torch.device, + ) -> torch.Tensor: + """Construct smooth periodic fields for benchmark and test cases.""" 
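+        # The constructed fields contain only a few low-order Fourier modes, so
+        # they are exactly periodic on the grid and band-limited, which keeps
+        # spectral derivatives accurate for parity checks.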
+ dim = len(shape) + if isinstance(lengths, (float, int)): + lengths_tuple = tuple(float(lengths) for _ in range(dim)) + else: + lengths_tuple = tuple(float(v) for v in lengths) + + if dim == 1: + n0 = shape[0] + l0 = lengths_tuple[0] + x0 = torch.arange(n0, device=device, dtype=torch.float32) * (l0 / n0) + k0 = 2.0 * torch.pi / l0 + return torch.sin(k0 * x0) + 0.25 * torch.cos(2.0 * k0 * x0) + + if dim == 2: + n0, n1 = shape + l0, l1 = lengths_tuple + x0 = torch.arange(n0, device=device, dtype=torch.float32) * (l0 / n0) + x1 = torch.arange(n1, device=device, dtype=torch.float32) * (l1 / n1) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + k0 = 2.0 * torch.pi / l0 + k1 = 2.0 * torch.pi / l1 + return torch.sin(k0 * xx + 0.3) * torch.cos(k1 * yy - 0.2) + + n0, n1, n2 = shape + l0, l1, l2 = lengths_tuple + x0 = torch.arange(n0, device=device, dtype=torch.float32) * (l0 / n0) + x1 = torch.arange(n1, device=device, dtype=torch.float32) * (l1 / n1) + x2 = torch.arange(n2, device=device, dtype=torch.float32) * (l2 / n2) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + k0 = 2.0 * torch.pi / l0 + k1 = 2.0 * torch.pi / l1 + k2 = 2.0 * torch.pi / l2 + return ( + torch.sin(k0 * xx + 0.2) + * torch.cos(k1 * yy - 0.4) + * torch.sin(k2 * zz + 0.1) + ) + + +spectral_grid_gradient = SpectralGridGradient.make_function("spectral_grid_gradient") + + +__all__ = ["SpectralGridGradient", "spectral_grid_gradient"] diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/__init__.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/__init__.py new file mode 100644 index 0000000000..0186f126e8 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .uniform_grid_gradient import UniformGridGradient, uniform_grid_gradient + +__all__ = ["UniformGridGradient", "uniform_grid_gradient"] diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_torch_impl.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_torch_impl.py new file mode 100644 index 0000000000..3e59d6e51f --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_torch_impl.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +_SUPPORTED_ORDERS = (2, 4) +_SUPPORTED_DERIVATIVE_ORDERS = (1, 2) + + +def _normalize_spacing( + spacing: float | Sequence[float], ndim: int +) -> tuple[float, ...]: + ### Normalize scalar/list spacing into one value per axis. + if isinstance(spacing, (float, int)): + return tuple(float(spacing) for _ in range(ndim)) + spacing_tuple = tuple(float(x) for x in spacing) + if len(spacing_tuple) != ndim: + raise ValueError( + f"spacing must have {ndim} entries for a {ndim}D field, got {len(spacing_tuple)}" + ) + return spacing_tuple + + +def _validate_order(order: int) -> int: + ### Validate finite-difference order selection. + if not isinstance(order, int): + raise TypeError(f"order must be an integer, got {type(order)}") + if order not in _SUPPORTED_ORDERS: + raise ValueError( + f"uniform_grid_gradient supports {list(_SUPPORTED_ORDERS)} central orders, got order={order}" + ) + return order + + +def _validate_derivative_order(derivative_order: int) -> int: + ### Validate derivative-order selection (first vs pure second derivative). + if not isinstance(derivative_order, int): + raise TypeError( + f"derivative_order must be an integer, got {type(derivative_order)}" + ) + if derivative_order not in _SUPPORTED_DERIVATIVE_ORDERS: + raise ValueError( + "uniform_grid_gradient supports derivative_order in [1, 2], " + f"got derivative_order={derivative_order}" + ) + return derivative_order + + +def _validate_include_mixed( + *, + derivative_order: int, + include_mixed: bool, +) -> None: + ### Phase-1 guard: mixed second derivatives are intentionally not yet exposed. + if not isinstance(include_mixed, bool): + raise TypeError(f"include_mixed must be a bool, got {type(include_mixed)}") + if include_mixed and derivative_order != 2: + raise ValueError("include_mixed is only valid when derivative_order=2") + if include_mixed: + raise NotImplementedError( + "include_mixed=True is not yet supported; phase-1 supports pure axis-wise " + "second derivatives only" + ) + + +def _central_derivative_order2( + field: torch.Tensor, axis: int, dx: float +) -> torch.Tensor: + ### Second-order periodic central difference. + return ( + torch.roll(field, shifts=-1, dims=axis) - torch.roll(field, shifts=1, dims=axis) + ) / (2.0 * dx) + + +def _central_derivative_order4( + field: torch.Tensor, axis: int, dx: float +) -> torch.Tensor: + ### Fourth-order periodic central difference. + # d/dx f_i ≈ (-f_{i+2} + 8 f_{i+1} - 8 f_{i-1} + f_{i-2}) / (12 dx) + return ( + -torch.roll(field, shifts=-2, dims=axis) + + 8.0 * torch.roll(field, shifts=-1, dims=axis) + - 8.0 * torch.roll(field, shifts=1, dims=axis) + + torch.roll(field, shifts=2, dims=axis) + ) / (12.0 * dx) + + +def _second_derivative_order2( + field: torch.Tensor, axis: int, dx: float +) -> torch.Tensor: + ### Second-order periodic second derivative. 
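+    # d2/dx2 f_i ≈ (f_{i+1} - 2 f_i + f_{i-1}) / dx^2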
+ return ( + torch.roll(field, shifts=-1, dims=axis) + - 2.0 * field + + torch.roll(field, shifts=1, dims=axis) + ) / (dx * dx) + + +def _second_derivative_order4( + field: torch.Tensor, axis: int, dx: float +) -> torch.Tensor: + ### Fourth-order periodic second derivative. + # d2/dx2 f_i ≈ (-f_{i+2} + 16 f_{i+1} - 30 f_i + 16 f_{i-1} - f_{i-2}) / (12 dx^2) + return ( + -torch.roll(field, shifts=-2, dims=axis) + + 16.0 * torch.roll(field, shifts=-1, dims=axis) + - 30.0 * field + + 16.0 * torch.roll(field, shifts=1, dims=axis) + - torch.roll(field, shifts=2, dims=axis) + ) / (12.0 * dx * dx) + + +def uniform_grid_gradient_torch( + field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + derivative_order: int = 1, + include_mixed: bool = False, +) -> torch.Tensor: + """Compute periodic first or pure second derivatives on a uniform grid.""" + ### Validate field shape and dtype. + if field.ndim < 1 or field.ndim > 3: + raise ValueError( + f"uniform_grid_gradient supports 1D-3D fields, got {field.shape=}" + ) + if not torch.is_floating_point(field): + raise TypeError("field must be a floating-point tensor") + order = _validate_order(order) + derivative_order = _validate_derivative_order(derivative_order) + _validate_include_mixed( + derivative_order=derivative_order, + include_mixed=include_mixed, + ) + + ### Expand spacing to one entry per field axis. + spacing_tuple = _normalize_spacing(spacing, field.ndim) + + ### Compute periodic central differences independently per axis. + gradients: list[torch.Tensor] = [] + for axis, dx in enumerate(spacing_tuple): + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + ### Periodic axis-wise derivative with configurable derivative/stencil order. + if derivative_order == 1 and order == 2: + grad_axis = _central_derivative_order2(field, axis=axis, dx=dx) + elif derivative_order == 1 and order == 4: + grad_axis = _central_derivative_order4(field, axis=axis, dx=dx) + elif derivative_order == 2 and order == 2: + grad_axis = _second_derivative_order2(field, axis=axis, dx=dx) + else: + grad_axis = _second_derivative_order4(field, axis=axis, dx=dx) + gradients.append(grad_axis) + + ### Stack per-axis derivative terms into (dims, *field.shape). + return torch.stack(gradients, dim=0) diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/__init__.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/__init__.py new file mode 100644 index 0000000000..c4abef9a9f --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .op import uniform_grid_gradient_warp, uniform_grid_gradient_warp_multi + +__all__ = ["uniform_grid_gradient_warp", "uniform_grid_gradient_warp_multi"] diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/_kernels/__init__.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/_kernels/__init__.py new file mode 100644 index 0000000000..ee2e1a5fa2 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/_kernels/__init__.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .backward import ( + _uniform_grid_derivatives_1d_order2_fused_backward_kernel, + _uniform_grid_derivatives_2d_order2_fused_backward_kernel, + _uniform_grid_derivatives_2d_order2_fused_no_mixed_backward_kernel, + _uniform_grid_derivatives_3d_order2_fused_backward_kernel, + _uniform_grid_derivatives_3d_order2_fused_no_mixed_backward_kernel, + _uniform_grid_gradient_1d_backward_kernel, + _uniform_grid_gradient_1d_order4_backward_kernel, + _uniform_grid_gradient_2d_backward_kernel, + _uniform_grid_gradient_2d_order4_backward_kernel, + _uniform_grid_gradient_3d_backward_kernel, + _uniform_grid_gradient_3d_order4_backward_kernel, + _uniform_grid_second_derivative_1d_backward_kernel, + _uniform_grid_second_derivative_1d_order4_backward_kernel, + _uniform_grid_second_derivative_2d_backward_kernel, + _uniform_grid_second_derivative_2d_order4_backward_kernel, + _uniform_grid_second_derivative_3d_backward_kernel, + _uniform_grid_second_derivative_3d_order4_backward_kernel, +) +from .forward import ( + _uniform_grid_derivatives_1d_order2_fused_kernel, + _uniform_grid_derivatives_2d_order2_fused_kernel, + _uniform_grid_derivatives_2d_order2_fused_no_mixed_kernel, + _uniform_grid_derivatives_3d_order2_fused_kernel, + _uniform_grid_derivatives_3d_order2_fused_no_mixed_kernel, + _uniform_grid_gradient_1d_kernel, + _uniform_grid_gradient_1d_order4_kernel, + _uniform_grid_gradient_2d_kernel, + _uniform_grid_gradient_2d_order4_kernel, + _uniform_grid_gradient_3d_kernel, + _uniform_grid_gradient_3d_order4_kernel, + _uniform_grid_second_derivative_1d_kernel, + _uniform_grid_second_derivative_1d_order4_kernel, + _uniform_grid_second_derivative_2d_kernel, + _uniform_grid_second_derivative_2d_order4_kernel, + _uniform_grid_second_derivative_3d_kernel, + _uniform_grid_second_derivative_3d_order4_kernel, +) + +__all__ = [ + "_uniform_grid_derivatives_1d_order2_fused_kernel", + "_uniform_grid_derivatives_2d_order2_fused_kernel", + "_uniform_grid_derivatives_2d_order2_fused_no_mixed_kernel", + "_uniform_grid_derivatives_3d_order2_fused_kernel", + "_uniform_grid_derivatives_3d_order2_fused_no_mixed_kernel", + "_uniform_grid_derivatives_1d_order2_fused_backward_kernel", + "_uniform_grid_derivatives_2d_order2_fused_backward_kernel", + 
"_uniform_grid_derivatives_2d_order2_fused_no_mixed_backward_kernel", + "_uniform_grid_derivatives_3d_order2_fused_backward_kernel", + "_uniform_grid_derivatives_3d_order2_fused_no_mixed_backward_kernel", + "_uniform_grid_gradient_1d_kernel", + "_uniform_grid_gradient_1d_order4_kernel", + "_uniform_grid_gradient_2d_kernel", + "_uniform_grid_gradient_2d_order4_kernel", + "_uniform_grid_gradient_3d_kernel", + "_uniform_grid_gradient_3d_order4_kernel", + "_uniform_grid_second_derivative_1d_kernel", + "_uniform_grid_second_derivative_1d_order4_kernel", + "_uniform_grid_second_derivative_2d_kernel", + "_uniform_grid_second_derivative_2d_order4_kernel", + "_uniform_grid_second_derivative_3d_kernel", + "_uniform_grid_second_derivative_3d_order4_kernel", + "_uniform_grid_gradient_1d_backward_kernel", + "_uniform_grid_gradient_1d_order4_backward_kernel", + "_uniform_grid_gradient_2d_backward_kernel", + "_uniform_grid_gradient_2d_order4_backward_kernel", + "_uniform_grid_gradient_3d_backward_kernel", + "_uniform_grid_gradient_3d_order4_backward_kernel", + "_uniform_grid_second_derivative_1d_backward_kernel", + "_uniform_grid_second_derivative_1d_order4_backward_kernel", + "_uniform_grid_second_derivative_2d_backward_kernel", + "_uniform_grid_second_derivative_2d_order4_backward_kernel", + "_uniform_grid_second_derivative_3d_backward_kernel", + "_uniform_grid_second_derivative_3d_order4_backward_kernel", +] diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/_kernels/backward.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/_kernels/backward.py new file mode 100644 index 0000000000..f12ba908f3 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/_kernels/backward.py @@ -0,0 +1,565 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import warp as wp + +### ============================================================ +### Index wrapping helpers (periodic boundaries without modulo) +### ============================================================ + + +@wp.func +def _wrap_plus1(i: int, n: int) -> int: + return (i + 1) % n + + +@wp.func +def _wrap_minus1(i: int, n: int) -> int: + return (i + n - 1) % n + + +@wp.func +def _wrap_plus2(i: int, n: int) -> int: + return (i + 2) % n + + +@wp.func +def _wrap_minus2(i: int, n: int) -> int: + return (i + n - 2) % n + + +### ============================================================ +### Backward kernels (adjoint central differences) +### ============================================================ + + +@wp.kernel +def _uniform_grid_gradient_1d_backward_kernel( + grad0: wp.array(dtype=wp.float32), + inv_dx: float, + grad_field: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = grad0.shape[0] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + grad_field[i] = (grad0[im] - grad0[ip]) * (0.5 * inv_dx) + + +@wp.kernel +def _uniform_grid_gradient_1d_order4_backward_kernel( + grad0: wp.array(dtype=wp.float32), + inv_dx: float, + grad_field: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = grad0.shape[0] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + grad_field[i] = (grad0[ip2] - 8.0 * grad0[ip1] + 8.0 * grad0[im1] - grad0[im2]) * ( + inv_dx / 12.0 + ) + + +@wp.kernel +def _uniform_grid_gradient_2d_backward_kernel( + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + grad_field: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + + grad_field[i, j] = (grad0[im, j] - grad0[ip, j]) * (0.5 * inv_dx0) + ( + grad1[i, jm] - grad1[i, jp] + ) * (0.5 * inv_dx1) + + +@wp.kernel +def _uniform_grid_gradient_2d_order4_backward_kernel( + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + grad_field: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + + gx = (grad0[ip2, j] - 8.0 * grad0[ip1, j] + 8.0 * grad0[im1, j] - grad0[im2, j]) * ( + inv_dx0 / 12.0 + ) + gy = (grad1[i, jp2] - 8.0 * grad1[i, jp1] + 8.0 * grad1[i, jm1] - grad1[i, jm2]) * ( + inv_dx1 / 12.0 + ) + grad_field[i, j] = gx + gy + + +@wp.kernel +def _uniform_grid_gradient_3d_backward_kernel( + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + grad_field: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + n2 = grad0.shape[2] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + + grad_field[i, j, k] = ( + (grad0[im, j, k] - grad0[ip, j, k]) * (0.5 * inv_dx0) + + (grad1[i, jm, k] - grad1[i, jp, k]) * (0.5 * inv_dx1) + + (grad2[i, j, km] - grad2[i, j, kp]) * (0.5 * inv_dx2) + ) + + +@wp.kernel +def 
_uniform_grid_gradient_3d_order4_backward_kernel( + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + grad_field: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + n2 = grad0.shape[2] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + + km1 = _wrap_minus1(k, n2) + kp1 = _wrap_plus1(k, n2) + km2 = _wrap_minus2(k, n2) + kp2 = _wrap_plus2(k, n2) + + gx = ( + grad0[ip2, j, k] + - 8.0 * grad0[ip1, j, k] + + 8.0 * grad0[im1, j, k] + - grad0[im2, j, k] + ) * (inv_dx0 / 12.0) + gy = ( + grad1[i, jp2, k] + - 8.0 * grad1[i, jp1, k] + + 8.0 * grad1[i, jm1, k] + - grad1[i, jm2, k] + ) * (inv_dx1 / 12.0) + gz = ( + grad2[i, j, kp2] + - 8.0 * grad2[i, j, kp1] + + 8.0 * grad2[i, j, km1] + - grad2[i, j, km2] + ) * (inv_dx2 / 12.0) + grad_field[i, j, k] = gx + gy + gz + + +@wp.kernel +def _uniform_grid_second_derivative_1d_backward_kernel( + grad0: wp.array(dtype=wp.float32), + inv_dx20: float, + grad_field: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = grad0.shape[0] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + grad_field[i] = (grad0[ip] - 2.0 * grad0[i] + grad0[im]) * inv_dx20 + + +@wp.kernel +def _uniform_grid_second_derivative_1d_order4_backward_kernel( + grad0: wp.array(dtype=wp.float32), + inv_dx20: float, + grad_field: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = grad0.shape[0] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + grad_field[i] = ( + -grad0[ip2] + + 16.0 * grad0[ip1] + - 30.0 * grad0[i] + + 16.0 * grad0[im1] + - grad0[im2] + ) * (inv_dx20 / 12.0) + + +@wp.kernel +def _uniform_grid_second_derivative_2d_backward_kernel( + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + inv_dx20: float, + inv_dx21: float, + grad_field: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + + gxx = (grad0[ip, j] - 2.0 * grad0[i, j] + grad0[im, j]) * inv_dx20 + gyy = (grad1[i, jp] - 2.0 * grad1[i, j] + grad1[i, jm]) * inv_dx21 + grad_field[i, j] = gxx + gyy + + +@wp.kernel +def _uniform_grid_second_derivative_2d_order4_backward_kernel( + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + inv_dx20: float, + inv_dx21: float, + grad_field: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + + gxx = ( + -grad0[ip2, j] + + 16.0 * grad0[ip1, j] + - 30.0 * grad0[i, j] + + 16.0 * grad0[im1, j] + - grad0[im2, j] + ) * (inv_dx20 / 12.0) + gyy = ( + -grad1[i, jp2] + + 16.0 * grad1[i, jp1] + - 30.0 * grad1[i, j] + + 16.0 * grad1[i, jm1] + - grad1[i, jm2] + ) * (inv_dx21 / 12.0) + grad_field[i, j] = gxx + gyy + + +@wp.kernel +def _uniform_grid_second_derivative_3d_backward_kernel( + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + inv_dx20: float, + 
inv_dx21: float, + inv_dx22: float, + grad_field: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + n2 = grad0.shape[2] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + + gxx = (grad0[ip, j, k] - 2.0 * grad0[i, j, k] + grad0[im, j, k]) * inv_dx20 + gyy = (grad1[i, jp, k] - 2.0 * grad1[i, j, k] + grad1[i, jm, k]) * inv_dx21 + gzz = (grad2[i, j, kp] - 2.0 * grad2[i, j, k] + grad2[i, j, km]) * inv_dx22 + grad_field[i, j, k] = gxx + gyy + gzz + + +@wp.kernel +def _uniform_grid_second_derivative_3d_order4_backward_kernel( + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + inv_dx20: float, + inv_dx21: float, + inv_dx22: float, + grad_field: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + n2 = grad0.shape[2] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + + km1 = _wrap_minus1(k, n2) + kp1 = _wrap_plus1(k, n2) + km2 = _wrap_minus2(k, n2) + kp2 = _wrap_plus2(k, n2) + + gxx = ( + -grad0[ip2, j, k] + + 16.0 * grad0[ip1, j, k] + - 30.0 * grad0[i, j, k] + + 16.0 * grad0[im1, j, k] + - grad0[im2, j, k] + ) * (inv_dx20 / 12.0) + gyy = ( + -grad1[i, jp2, k] + + 16.0 * grad1[i, jp1, k] + - 30.0 * grad1[i, j, k] + + 16.0 * grad1[i, jm1, k] + - grad1[i, jm2, k] + ) * (inv_dx21 / 12.0) + gzz = ( + -grad2[i, j, kp2] + + 16.0 * grad2[i, j, kp1] + - 30.0 * grad2[i, j, k] + + 16.0 * grad2[i, j, km1] + - grad2[i, j, km2] + ) * (inv_dx22 / 12.0) + grad_field[i, j, k] = gxx + gyy + gzz + + +### ============================================================ +### Fused backward kernels for order=2 (first+second, no mixed) +### ============================================================ + + +@wp.kernel +def _uniform_grid_derivatives_1d_order2_fused_backward_kernel( + grad0: wp.array(dtype=wp.float32), + grad00: wp.array(dtype=wp.float32), + inv_dx: float, + inv_dx2: float, + grad_field: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = grad0.shape[0] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + + g1 = (grad0[im] - grad0[ip]) * (0.5 * inv_dx) + g2 = (grad00[ip] - 2.0 * grad00[i] + grad00[im]) * inv_dx2 + grad_field[i] = g1 + g2 + + +@wp.kernel +def _uniform_grid_derivatives_2d_order2_fused_no_mixed_backward_kernel( + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + grad00: wp.array2d(dtype=wp.float32), + grad11: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx20: float, + inv_dx21: float, + grad_field: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + + g1x = (grad0[im, j] - grad0[ip, j]) * (0.5 * inv_dx0) + g1y = (grad1[i, jm] - grad1[i, jp]) * (0.5 * inv_dx1) + g2x = (grad00[ip, j] - 2.0 * grad00[i, j] + grad00[im, j]) * inv_dx20 + g2y = (grad11[i, jp] - 2.0 * grad11[i, j] + grad11[i, jm]) * inv_dx21 + grad_field[i, j] = g1x + g1y + g2x + g2y + + +@wp.kernel +def _uniform_grid_derivatives_3d_order2_fused_no_mixed_backward_kernel( + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), 
+ grad00: wp.array3d(dtype=wp.float32), + grad11: wp.array3d(dtype=wp.float32), + grad22: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + inv_dx20: float, + inv_dx21: float, + inv_dx22: float, + grad_field: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + n2 = grad0.shape[2] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + + g1x = (grad0[im, j, k] - grad0[ip, j, k]) * (0.5 * inv_dx0) + g1y = (grad1[i, jm, k] - grad1[i, jp, k]) * (0.5 * inv_dx1) + g1z = (grad2[i, j, km] - grad2[i, j, kp]) * (0.5 * inv_dx2) + g2x = (grad00[ip, j, k] - 2.0 * grad00[i, j, k] + grad00[im, j, k]) * inv_dx20 + g2y = (grad11[i, jp, k] - 2.0 * grad11[i, j, k] + grad11[i, jm, k]) * inv_dx21 + g2z = (grad22[i, j, kp] - 2.0 * grad22[i, j, k] + grad22[i, j, km]) * inv_dx22 + grad_field[i, j, k] = g1x + g1y + g1z + g2x + g2y + g2z + + +@wp.kernel +def _uniform_grid_derivatives_2d_order2_fused_backward_kernel( + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + grad00: wp.array2d(dtype=wp.float32), + grad11: wp.array2d(dtype=wp.float32), + grad01: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx20: float, + inv_dx21: float, + inv_dx01: float, + grad_field: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + + g1x = (grad0[im, j] - grad0[ip, j]) * (0.5 * inv_dx0) + g1y = (grad1[i, jm] - grad1[i, jp]) * (0.5 * inv_dx1) + g2x = (grad00[ip, j] - 2.0 * grad00[i, j] + grad00[im, j]) * inv_dx20 + g2y = (grad11[i, jp] - 2.0 * grad11[i, j] + grad11[i, jm]) * inv_dx21 + gm = (grad01[im, jm] - grad01[im, jp] - grad01[ip, jm] + grad01[ip, jp]) * ( + 0.25 * inv_dx01 + ) + grad_field[i, j] = g1x + g1y + g2x + g2y + gm + + +@wp.kernel +def _uniform_grid_derivatives_3d_order2_fused_backward_kernel( + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + grad00: wp.array3d(dtype=wp.float32), + grad11: wp.array3d(dtype=wp.float32), + grad22: wp.array3d(dtype=wp.float32), + grad01: wp.array3d(dtype=wp.float32), + grad02: wp.array3d(dtype=wp.float32), + grad12: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + inv_dx20: float, + inv_dx21: float, + inv_dx22: float, + inv_dx01: float, + inv_dx02: float, + inv_dx12: float, + grad_field: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad0.shape[0] + n1 = grad0.shape[1] + n2 = grad0.shape[2] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + + g1x = (grad0[im, j, k] - grad0[ip, j, k]) * (0.5 * inv_dx0) + g1y = (grad1[i, jm, k] - grad1[i, jp, k]) * (0.5 * inv_dx1) + g1z = (grad2[i, j, km] - grad2[i, j, kp]) * (0.5 * inv_dx2) + + g2x = (grad00[ip, j, k] - 2.0 * grad00[i, j, k] + grad00[im, j, k]) * inv_dx20 + g2y = (grad11[i, jp, k] - 2.0 * grad11[i, j, k] + grad11[i, jm, k]) * inv_dx21 + g2z = (grad22[i, j, kp] - 2.0 * grad22[i, j, k] + grad22[i, j, km]) * inv_dx22 + + gm01 = ( + grad01[im, jm, k] - grad01[im, jp, k] - grad01[ip, jm, k] + grad01[ip, jp, k] + ) * (0.25 * inv_dx01) + gm02 = ( + grad02[im, j, km] - grad02[im, j, kp] - grad02[ip, j, km] + grad02[ip, j, kp] + ) * 
(0.25 * inv_dx02) + gm12 = ( + grad12[i, jm, km] - grad12[i, jm, kp] - grad12[i, jp, km] + grad12[i, jp, kp] + ) * (0.25 * inv_dx12) + grad_field[i, j, k] = g1x + g1y + g1z + g2x + g2y + g2z + gm01 + gm02 + gm12 diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/_kernels/forward.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/_kernels/forward.py new file mode 100644 index 0000000000..e8b2febfca --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/_kernels/forward.py @@ -0,0 +1,578 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warp as wp + +### ============================================================ +### Index wrapping helpers (periodic boundaries without modulo) +### ============================================================ + + +@wp.func +def _wrap_plus1(i: int, n: int) -> int: + return (i + 1) % n + + +@wp.func +def _wrap_minus1(i: int, n: int) -> int: + return (i + n - 1) % n + + +@wp.func +def _wrap_plus2(i: int, n: int) -> int: + return (i + 2) % n + + +@wp.func +def _wrap_minus2(i: int, n: int) -> int: + return (i + n - 2) % n + + +### ============================================================ +### Forward kernels (periodic central differences) +### ============================================================ + + +@wp.kernel +def _uniform_grid_gradient_1d_kernel( + field: wp.array(dtype=wp.float32), + inv_dx: float, + grad0: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = field.shape[0] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + + grad0[i] = (field[ip] - field[im]) * (0.5 * inv_dx) + + +@wp.kernel +def _uniform_grid_gradient_1d_order4_kernel( + field: wp.array(dtype=wp.float32), + inv_dx: float, + grad0: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = field.shape[0] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + grad0[i] = (-field[ip2] + 8.0 * field[ip1] - 8.0 * field[im1] + field[im2]) * ( + inv_dx / 12.0 + ) + + +@wp.kernel +def _uniform_grid_gradient_2d_kernel( + field: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + + grad0[i, j] = (field[ip, j] - field[im, j]) * (0.5 * inv_dx0) + grad1[i, j] = (field[i, jp] - field[i, jm]) * (0.5 * inv_dx1) + + +@wp.kernel +def _uniform_grid_gradient_2d_order4_kernel( + field: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] 
+ n1 = field.shape[1] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + + grad0[i, j] = ( + -field[ip2, j] + 8.0 * field[ip1, j] - 8.0 * field[im1, j] + field[im2, j] + ) * (inv_dx0 / 12.0) + grad1[i, j] = ( + -field[i, jp2] + 8.0 * field[i, jp1] - 8.0 * field[i, jm1] + field[i, jm2] + ) * (inv_dx1 / 12.0) + + +@wp.kernel +def _uniform_grid_gradient_3d_kernel( + field: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + + grad0[i, j, k] = (field[ip, j, k] - field[im, j, k]) * (0.5 * inv_dx0) + grad1[i, j, k] = (field[i, jp, k] - field[i, jm, k]) * (0.5 * inv_dx1) + grad2[i, j, k] = (field[i, j, kp] - field[i, j, km]) * (0.5 * inv_dx2) + + +@wp.kernel +def _uniform_grid_gradient_3d_order4_kernel( + field: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + + km1 = _wrap_minus1(k, n2) + kp1 = _wrap_plus1(k, n2) + km2 = _wrap_minus2(k, n2) + kp2 = _wrap_plus2(k, n2) + + grad0[i, j, k] = ( + -field[ip2, j, k] + + 8.0 * field[ip1, j, k] + - 8.0 * field[im1, j, k] + + field[im2, j, k] + ) * (inv_dx0 / 12.0) + grad1[i, j, k] = ( + -field[i, jp2, k] + + 8.0 * field[i, jp1, k] + - 8.0 * field[i, jm1, k] + + field[i, jm2, k] + ) * (inv_dx1 / 12.0) + grad2[i, j, k] = ( + -field[i, j, kp2] + + 8.0 * field[i, j, kp1] + - 8.0 * field[i, j, km1] + + field[i, j, km2] + ) * (inv_dx2 / 12.0) + + +@wp.kernel +def _uniform_grid_second_derivative_1d_kernel( + field: wp.array(dtype=wp.float32), + inv_dx2: float, + grad0: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = field.shape[0] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + grad0[i] = (field[ip] - 2.0 * field[i] + field[im]) * inv_dx2 + + +@wp.kernel +def _uniform_grid_second_derivative_1d_order4_kernel( + field: wp.array(dtype=wp.float32), + inv_dx2: float, + grad0: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = field.shape[0] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + grad0[i] = ( + -field[ip2] + + 16.0 * field[ip1] + - 30.0 * field[i] + + 16.0 * field[im1] + - field[im2] + ) * (inv_dx2 / 12.0) + + +@wp.kernel +def _uniform_grid_second_derivative_2d_kernel( + field: wp.array2d(dtype=wp.float32), + inv_dx20: float, + inv_dx21: float, + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + + grad0[i, j] = (field[ip, j] - 2.0 * field[i, j] + 
field[im, j]) * inv_dx20 + grad1[i, j] = (field[i, jp] - 2.0 * field[i, j] + field[i, jm]) * inv_dx21 + + +@wp.kernel +def _uniform_grid_second_derivative_2d_order4_kernel( + field: wp.array2d(dtype=wp.float32), + inv_dx20: float, + inv_dx21: float, + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + + grad0[i, j] = ( + -field[ip2, j] + + 16.0 * field[ip1, j] + - 30.0 * field[i, j] + + 16.0 * field[im1, j] + - field[im2, j] + ) * (inv_dx20 / 12.0) + grad1[i, j] = ( + -field[i, jp2] + + 16.0 * field[i, jp1] + - 30.0 * field[i, j] + + 16.0 * field[i, jm1] + - field[i, jm2] + ) * (inv_dx21 / 12.0) + + +@wp.kernel +def _uniform_grid_second_derivative_3d_kernel( + field: wp.array3d(dtype=wp.float32), + inv_dx20: float, + inv_dx21: float, + inv_dx22: float, + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + + grad0[i, j, k] = ( + field[ip, j, k] - 2.0 * field[i, j, k] + field[im, j, k] + ) * inv_dx20 + grad1[i, j, k] = ( + field[i, jp, k] - 2.0 * field[i, j, k] + field[i, jm, k] + ) * inv_dx21 + grad2[i, j, k] = ( + field[i, j, kp] - 2.0 * field[i, j, k] + field[i, j, km] + ) * inv_dx22 + + +@wp.kernel +def _uniform_grid_second_derivative_3d_order4_kernel( + field: wp.array3d(dtype=wp.float32), + inv_dx20: float, + inv_dx21: float, + inv_dx22: float, + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + + km1 = _wrap_minus1(k, n2) + kp1 = _wrap_plus1(k, n2) + km2 = _wrap_minus2(k, n2) + kp2 = _wrap_plus2(k, n2) + + grad0[i, j, k] = ( + -field[ip2, j, k] + + 16.0 * field[ip1, j, k] + - 30.0 * field[i, j, k] + + 16.0 * field[im1, j, k] + - field[im2, j, k] + ) * (inv_dx20 / 12.0) + grad1[i, j, k] = ( + -field[i, jp2, k] + + 16.0 * field[i, jp1, k] + - 30.0 * field[i, j, k] + + 16.0 * field[i, jm1, k] + - field[i, jm2, k] + ) * (inv_dx21 / 12.0) + grad2[i, j, k] = ( + -field[i, j, kp2] + + 16.0 * field[i, j, kp1] + - 30.0 * field[i, j, k] + + 16.0 * field[i, j, km1] + - field[i, j, km2] + ) * (inv_dx22 / 12.0) + + +### ============================================================ +### Fused forward kernels for order=2 (single launch for 1st+2nd+mixed) +### ============================================================ + + +@wp.kernel +def _uniform_grid_derivatives_1d_order2_fused_kernel( + field: wp.array(dtype=wp.float32), + inv_dx: float, + inv_dx2: float, + grad0: wp.array(dtype=wp.float32), + grad00: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = field.shape[0] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + + grad0[i] = (field[ip] - field[im]) * (0.5 * inv_dx) + grad00[i] = 
(field[ip] - 2.0 * field[i] + field[im]) * inv_dx2 + + +@wp.kernel +def _uniform_grid_derivatives_2d_order2_fused_kernel( + field: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx20: float, + inv_dx21: float, + inv_dx01: float, + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + grad00: wp.array2d(dtype=wp.float32), + grad11: wp.array2d(dtype=wp.float32), + grad01: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + + grad0[i, j] = (field[ip, j] - field[im, j]) * (0.5 * inv_dx0) + grad1[i, j] = (field[i, jp] - field[i, jm]) * (0.5 * inv_dx1) + + grad00[i, j] = (field[ip, j] - 2.0 * field[i, j] + field[im, j]) * inv_dx20 + grad11[i, j] = (field[i, jp] - 2.0 * field[i, j] + field[i, jm]) * inv_dx21 + + grad01[i, j] = (field[ip, jp] - field[ip, jm] - field[im, jp] + field[im, jm]) * ( + 0.25 * inv_dx01 + ) + + +@wp.kernel +def _uniform_grid_derivatives_3d_order2_fused_kernel( + field: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + inv_dx20: float, + inv_dx21: float, + inv_dx22: float, + inv_dx01: float, + inv_dx02: float, + inv_dx12: float, + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + grad00: wp.array3d(dtype=wp.float32), + grad11: wp.array3d(dtype=wp.float32), + grad22: wp.array3d(dtype=wp.float32), + grad01: wp.array3d(dtype=wp.float32), + grad02: wp.array3d(dtype=wp.float32), + grad12: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + + grad0[i, j, k] = (field[ip, j, k] - field[im, j, k]) * (0.5 * inv_dx0) + grad1[i, j, k] = (field[i, jp, k] - field[i, jm, k]) * (0.5 * inv_dx1) + grad2[i, j, k] = (field[i, j, kp] - field[i, j, km]) * (0.5 * inv_dx2) + + grad00[i, j, k] = ( + field[ip, j, k] - 2.0 * field[i, j, k] + field[im, j, k] + ) * inv_dx20 + grad11[i, j, k] = ( + field[i, jp, k] - 2.0 * field[i, j, k] + field[i, jm, k] + ) * inv_dx21 + grad22[i, j, k] = ( + field[i, j, kp] - 2.0 * field[i, j, k] + field[i, j, km] + ) * inv_dx22 + + grad01[i, j, k] = ( + field[ip, jp, k] - field[ip, jm, k] - field[im, jp, k] + field[im, jm, k] + ) * (0.25 * inv_dx01) + grad02[i, j, k] = ( + field[ip, j, kp] - field[ip, j, km] - field[im, j, kp] + field[im, j, km] + ) * (0.25 * inv_dx02) + grad12[i, j, k] = ( + field[i, jp, kp] - field[i, jp, km] - field[i, jm, kp] + field[i, jm, km] + ) * (0.25 * inv_dx12) + + +@wp.kernel +def _uniform_grid_derivatives_2d_order2_fused_no_mixed_kernel( + field: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx20: float, + inv_dx21: float, + grad0: wp.array2d(dtype=wp.float32), + grad1: wp.array2d(dtype=wp.float32), + grad00: wp.array2d(dtype=wp.float32), + grad11: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + + grad0[i, j] = (field[ip, j] - field[im, j]) * (0.5 * inv_dx0) + grad1[i, j] = (field[i, jp] - field[i, jm]) * (0.5 * inv_dx1) + + grad00[i, j] = (field[ip, j] - 2.0 * field[i, j] + field[im, j]) * inv_dx20 + grad11[i, j] = (field[i, jp] 
- 2.0 * field[i, j] + field[i, jm]) * inv_dx21 + + +@wp.kernel +def _uniform_grid_derivatives_3d_order2_fused_no_mixed_kernel( + field: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + inv_dx20: float, + inv_dx21: float, + inv_dx22: float, + grad0: wp.array3d(dtype=wp.float32), + grad1: wp.array3d(dtype=wp.float32), + grad2: wp.array3d(dtype=wp.float32), + grad00: wp.array3d(dtype=wp.float32), + grad11: wp.array3d(dtype=wp.float32), + grad22: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + + grad0[i, j, k] = (field[ip, j, k] - field[im, j, k]) * (0.5 * inv_dx0) + grad1[i, j, k] = (field[i, jp, k] - field[i, jm, k]) * (0.5 * inv_dx1) + grad2[i, j, k] = (field[i, j, kp] - field[i, j, km]) * (0.5 * inv_dx2) + + grad00[i, j, k] = ( + field[ip, j, k] - 2.0 * field[i, j, k] + field[im, j, k] + ) * inv_dx20 + grad11[i, j, k] = ( + field[i, jp, k] - 2.0 * field[i, j, k] + field[i, jm, k] + ) * inv_dx21 + grad22[i, j, k] = ( + field[i, j, kp] - 2.0 * field[i, j, k] + field[i, j, km] + ) * inv_dx22 diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/launch_backward.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/launch_backward.py new file mode 100644 index 0000000000..f109a011ba --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/launch_backward.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
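The forward kernels above all apply periodic central-difference stencils with wrap-around indexing. As a quick sanity sketch that is not part of the patch, the order-2 first-derivative stencil can be reproduced with `torch.roll` on a periodic sine field and compared against the analytic derivative; all names below are illustrative only.

```python
import math
import torch

# Periodic grid on [0, 1): endpoint excluded so wrap-around is exact.
n = 256
dx = 1.0 / n
x = torch.arange(n, dtype=torch.float64) * dx
f = torch.sin(2.0 * math.pi * x)

# Order-2 periodic central difference, same stencil as the 1D kernel:
# grad0[i] = (f[i+1] - f[i-1]) / (2 * dx) with periodic indexing.
grad_fd = (torch.roll(f, -1) - torch.roll(f, 1)) * (0.5 / dx)

# Analytic derivative of sin(2*pi*x).
grad_exact = 2.0 * math.pi * torch.cos(2.0 * math.pi * x)

# Order-2 accuracy: the max error shrinks roughly 4x when n doubles.
print((grad_fd - grad_exact).abs().max())
```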
+ +from __future__ import annotations + +import torch +import warp as wp + +from ._kernels import ( + _uniform_grid_derivatives_1d_order2_fused_backward_kernel, + _uniform_grid_derivatives_2d_order2_fused_backward_kernel, + _uniform_grid_derivatives_2d_order2_fused_no_mixed_backward_kernel, + _uniform_grid_derivatives_3d_order2_fused_backward_kernel, + _uniform_grid_derivatives_3d_order2_fused_no_mixed_backward_kernel, + _uniform_grid_gradient_1d_backward_kernel, + _uniform_grid_gradient_1d_order4_backward_kernel, + _uniform_grid_gradient_2d_backward_kernel, + _uniform_grid_gradient_2d_order4_backward_kernel, + _uniform_grid_gradient_3d_backward_kernel, + _uniform_grid_gradient_3d_order4_backward_kernel, + _uniform_grid_second_derivative_1d_backward_kernel, + _uniform_grid_second_derivative_1d_order4_backward_kernel, + _uniform_grid_second_derivative_2d_backward_kernel, + _uniform_grid_second_derivative_2d_order4_backward_kernel, + _uniform_grid_second_derivative_3d_backward_kernel, + _uniform_grid_second_derivative_3d_order4_backward_kernel, +) +from .utils import _wp_launch + +_BACKWARD_KERNELS = { + (1, 1, 2): _uniform_grid_gradient_1d_backward_kernel, + (1, 1, 4): _uniform_grid_gradient_1d_order4_backward_kernel, + (1, 2, 2): _uniform_grid_second_derivative_1d_backward_kernel, + (1, 2, 4): _uniform_grid_second_derivative_1d_order4_backward_kernel, + (2, 1, 2): _uniform_grid_gradient_2d_backward_kernel, + (2, 1, 4): _uniform_grid_gradient_2d_order4_backward_kernel, + (2, 2, 2): _uniform_grid_second_derivative_2d_backward_kernel, + (2, 2, 4): _uniform_grid_second_derivative_2d_order4_backward_kernel, + (3, 1, 2): _uniform_grid_gradient_3d_backward_kernel, + (3, 1, 4): _uniform_grid_gradient_3d_order4_backward_kernel, + (3, 2, 2): _uniform_grid_second_derivative_3d_backward_kernel, + (3, 2, 4): _uniform_grid_second_derivative_3d_order4_backward_kernel, +} + +_FUSED_BACKWARD_KERNELS = { + (1, False): _uniform_grid_derivatives_1d_order2_fused_backward_kernel, + (1, True): _uniform_grid_derivatives_1d_order2_fused_backward_kernel, + (2, False): _uniform_grid_derivatives_2d_order2_fused_no_mixed_backward_kernel, + (2, True): _uniform_grid_derivatives_2d_order2_fused_backward_kernel, + (3, False): _uniform_grid_derivatives_3d_order2_fused_no_mixed_backward_kernel, + (3, True): _uniform_grid_derivatives_3d_order2_fused_backward_kernel, +} + + +def _launch_dim(shape: torch.Size) -> int | tuple[int, ...]: + """Return Warp launch dimensions for 1D vs ND kernels.""" + return shape[0] if len(shape) == 1 else tuple(shape) + + +def _inverse_spacings(spacing_tuple: tuple[float, ...], power: int) -> list[float]: + """Compute inverse spacing terms with optional square for second derivatives.""" + if power == 1: + return [1.0 / float(dx) for dx in spacing_tuple] + return [1.0 / float(dx * dx) for dx in spacing_tuple] + + +def _mixed_inverse_spacings(spacing_tuple: tuple[float, ...]) -> list[float]: + """Compute inverse mixed spacing terms in axis-pair order.""" + return [ + 1.0 / float(spacing_tuple[i] * spacing_tuple[j]) + for i in range(len(spacing_tuple)) + for j in range(i + 1, len(spacing_tuple)) + ] + + +def _to_wp_components(components: list[torch.Tensor], count: int) -> list[wp.array]: + """Convert the leading tensor components to Warp arrays.""" + return [wp.from_torch(components[i], dtype=wp.float32) for i in range(count)] + + +def _launch_backward( + *, + grad_output_fp32: torch.Tensor, + spacing_tuple: tuple[float, ...], + order: int, + derivative_order: int, + grad_field: torch.Tensor, + 
wp_device, + wp_stream, +) -> None: + ### Launch dimensionality/order-specific backward kernels. + ndim = grad_field.ndim + kernel = _BACKWARD_KERNELS[(ndim, derivative_order, order)] + local_spacing = spacing_tuple[:ndim] + inv_terms = _inverse_spacings( + local_spacing, + power=1 if derivative_order == 1 else 2, + ) + + _wp_launch( + kernel=kernel, + dim=_launch_dim(grad_field.shape), + inputs=[ + *_to_wp_components(grad_output_fp32, ndim), + *inv_terms, + wp.from_torch(grad_field, dtype=wp.float32), + ], + device=wp_device, + stream=wp_stream, + ) + + +def _launch_backward_fused_order2_no_mixed( + *, + grad_first_components: list[torch.Tensor], + grad_second_components: list[torch.Tensor], + grad_mixed_components: list[torch.Tensor], + spacing_tuple: tuple[float, ...], + include_mixed: bool, + grad_field: torch.Tensor, + wp_device, + wp_stream, +) -> None: + """Launch fused order-2 backward kernels for first/second/(optional mixed).""" + ndim = grad_field.ndim + kernel = _FUSED_BACKWARD_KERNELS[(ndim, include_mixed)] + local_spacing = spacing_tuple[:ndim] + inv_first = _inverse_spacings(local_spacing, power=1) + inv_second = _inverse_spacings(local_spacing, power=2) + + inputs: list = [ + *_to_wp_components(grad_first_components, ndim), + *_to_wp_components(grad_second_components, ndim), + ] + + if include_mixed and ndim > 1: + mixed_count = ndim * (ndim - 1) // 2 + inputs.extend(_to_wp_components(grad_mixed_components, mixed_count)) + + inputs.extend(inv_first) + inputs.extend(inv_second) + + if include_mixed and ndim > 1: + inputs.extend(_mixed_inverse_spacings(local_spacing)) + + inputs.append(wp.from_torch(grad_field, dtype=wp.float32)) + + _wp_launch( + kernel=kernel, + dim=_launch_dim(grad_field.shape), + inputs=inputs, + device=wp_device, + stream=wp_stream, + ) diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/launch_forward.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/launch_forward.py new file mode 100644 index 0000000000..503be31875 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/launch_forward.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
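Because every stencil in this module is linear in the field, the backward launch reduces to applying the adjoint stencil to the incoming gradient: the antisymmetric first-derivative stencil is skew-adjoint, which is why the backward kernels flip the sign, while the symmetric second-derivative stencil is self-adjoint and reuses the same three-point stencil. A minimal sketch of those identities, using `torch.roll` as a stand-in for the periodic Warp kernels (illustrative only, not part of the patch):

```python
import torch

torch.manual_seed(0)
n, dx = 128, 0.01
f = torch.randn(n, dtype=torch.float64)
g = torch.randn(n, dtype=torch.float64)

def d1(v):
    # Periodic order-2 first derivative.
    return (torch.roll(v, -1) - torch.roll(v, 1)) * (0.5 / dx)

def d2(v):
    # Periodic order-2 second derivative.
    return (torch.roll(v, -1) - 2.0 * v + torch.roll(v, 1)) / (dx * dx)

# Skew-adjoint: <d1(f), g> == <f, -d1(g)>, matching the backward use of
# (g[i-1] - g[i+1]) * 0.5 / dx.
print(torch.dot(d1(f), g).item(), torch.dot(f, -d1(g)).item())

# Self-adjoint: <d2(f), g> == <f, d2(g)>, so the backward applies the
# same second-derivative stencil to the incoming gradient.
print(torch.dot(d2(f), g).item(), torch.dot(f, d2(g)).item())
```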
+ +from __future__ import annotations + +import torch +import warp as wp + +from ._kernels import ( + _uniform_grid_derivatives_1d_order2_fused_kernel, + _uniform_grid_derivatives_2d_order2_fused_kernel, + _uniform_grid_derivatives_2d_order2_fused_no_mixed_kernel, + _uniform_grid_derivatives_3d_order2_fused_kernel, + _uniform_grid_derivatives_3d_order2_fused_no_mixed_kernel, + _uniform_grid_gradient_1d_kernel, + _uniform_grid_gradient_1d_order4_kernel, + _uniform_grid_gradient_2d_kernel, + _uniform_grid_gradient_2d_order4_kernel, + _uniform_grid_gradient_3d_kernel, + _uniform_grid_gradient_3d_order4_kernel, + _uniform_grid_second_derivative_1d_kernel, + _uniform_grid_second_derivative_1d_order4_kernel, + _uniform_grid_second_derivative_2d_kernel, + _uniform_grid_second_derivative_2d_order4_kernel, + _uniform_grid_second_derivative_3d_kernel, + _uniform_grid_second_derivative_3d_order4_kernel, +) +from .utils import _wp_launch + +_FORWARD_KERNELS = { + (1, 1, 2): _uniform_grid_gradient_1d_kernel, + (1, 1, 4): _uniform_grid_gradient_1d_order4_kernel, + (1, 2, 2): _uniform_grid_second_derivative_1d_kernel, + (1, 2, 4): _uniform_grid_second_derivative_1d_order4_kernel, + (2, 1, 2): _uniform_grid_gradient_2d_kernel, + (2, 1, 4): _uniform_grid_gradient_2d_order4_kernel, + (2, 2, 2): _uniform_grid_second_derivative_2d_kernel, + (2, 2, 4): _uniform_grid_second_derivative_2d_order4_kernel, + (3, 1, 2): _uniform_grid_gradient_3d_kernel, + (3, 1, 4): _uniform_grid_gradient_3d_order4_kernel, + (3, 2, 2): _uniform_grid_second_derivative_3d_kernel, + (3, 2, 4): _uniform_grid_second_derivative_3d_order4_kernel, +} + +_FUSED_FORWARD_KERNELS = { + (1, False): _uniform_grid_derivatives_1d_order2_fused_kernel, + (1, True): _uniform_grid_derivatives_1d_order2_fused_kernel, + (2, False): _uniform_grid_derivatives_2d_order2_fused_no_mixed_kernel, + (2, True): _uniform_grid_derivatives_2d_order2_fused_kernel, + (3, False): _uniform_grid_derivatives_3d_order2_fused_no_mixed_kernel, + (3, True): _uniform_grid_derivatives_3d_order2_fused_kernel, +} + + +def _launch_dim(shape: torch.Size) -> int | tuple[int, ...]: + """Return Warp launch dimensions for 1D vs ND kernels.""" + return shape[0] if len(shape) == 1 else tuple(shape) + + +def _inverse_spacings(spacing_tuple: tuple[float, ...], power: int) -> list[float]: + """Compute inverse spacing terms with optional square for second derivatives.""" + if power == 1: + return [1.0 / float(dx) for dx in spacing_tuple] + return [1.0 / float(dx * dx) for dx in spacing_tuple] + + +def _mixed_inverse_spacings(spacing_tuple: tuple[float, ...]) -> list[float]: + """Compute inverse mixed spacing terms in axis-pair order.""" + return [ + 1.0 / float(spacing_tuple[i] * spacing_tuple[j]) + for i in range(len(spacing_tuple)) + for j in range(i + 1, len(spacing_tuple)) + ] + + +def _to_wp_components(components: list[torch.Tensor], count: int) -> list[wp.array]: + """Convert the leading tensor components to Warp arrays.""" + return [wp.from_torch(components[i], dtype=wp.float32) for i in range(count)] + + +def _launch_forward( + *, + field_fp32: torch.Tensor, + spacing_tuple: tuple[float, ...], + order: int, + derivative_order: int, + grad_components: list[torch.Tensor], + wp_device, + wp_stream, +) -> None: + ### Launch dimensionality/order-specific forward kernels. 
+ ndim = field_fp32.ndim + kernel = _FORWARD_KERNELS[(ndim, derivative_order, order)] + local_spacing = spacing_tuple[:ndim] + inv_terms = _inverse_spacings( + local_spacing, + power=1 if derivative_order == 1 else 2, + ) + + _wp_launch( + kernel=kernel, + dim=_launch_dim(field_fp32.shape), + inputs=[ + wp.from_torch(field_fp32, dtype=wp.float32), + *inv_terms, + *_to_wp_components(grad_components, ndim), + ], + device=wp_device, + stream=wp_stream, + ) + + +def _launch_forward_fused_order2( + *, + field_fp32: torch.Tensor, + spacing_tuple: tuple[float, ...], + first_components: list[torch.Tensor], + second_components: list[torch.Tensor], + mixed_components: list[torch.Tensor], + include_mixed: bool, + wp_device, + wp_stream, +) -> None: + """Launch fused first/second/mixed derivative kernels (order=2 only).""" + ndim = field_fp32.ndim + kernel = _FUSED_FORWARD_KERNELS[(ndim, include_mixed)] + local_spacing = spacing_tuple[:ndim] + inv_first = _inverse_spacings(local_spacing, power=1) + inv_second = _inverse_spacings(local_spacing, power=2) + + inputs: list = [ + wp.from_torch(field_fp32, dtype=wp.float32), + *inv_first, + *inv_second, + ] + if include_mixed and ndim > 1: + inputs.extend(_mixed_inverse_spacings(local_spacing)) + + inputs.extend(_to_wp_components(first_components, ndim)) + inputs.extend(_to_wp_components(second_components, ndim)) + + if include_mixed and ndim > 1: + mixed_count = ndim * (ndim - 1) // 2 + inputs.extend(_to_wp_components(mixed_components, mixed_count)) + + _wp_launch( + kernel=kernel, + dim=_launch_dim(field_fp32.shape), + inputs=inputs, + device=wp_device, + stream=wp_stream, + ) diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/op.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/op.py new file mode 100644 index 0000000000..bd904122a6 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/op.py @@ -0,0 +1,489 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
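The fused launchers build the kernel input list in a fixed order (field, first-derivative spacings, second-derivative spacings, optional mixed spacings, then the output slices), so the mixed spacing terms must follow the same axis-pair order as the `grad01`, `grad02`, `grad12` kernel arguments. A small sketch of that pair ordering, equivalent to `_mixed_inverse_spacings` (values are illustrative only):

```python
from itertools import combinations

spacing = (0.01, 0.02, 0.04)  # example (dx0, dx1, dx2)

# Axis pairs in (0,1), (0,2), (1,2) order, matching grad01, grad02, grad12.
mixed = [1.0 / (spacing[i] * spacing[j]) for i, j in combinations(range(3), 2)]
print(mixed)  # [1/(dx0*dx1), 1/(dx0*dx2), 1/(dx1*dx2)]
```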
+ +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +from .launch_backward import _launch_backward, _launch_backward_fused_order2_no_mixed +from .launch_forward import _launch_forward, _launch_forward_fused_order2 +from .utils import ( + _normalize_spacing, + _validate_derivative_order, + _validate_field, + _validate_include_mixed, + _validate_order, + _warp_launch_context, +) + + +@torch.library.custom_op( + "physicsnemo::uniform_grid_gradient_warp_impl", mutates_args=() +) +def uniform_grid_gradient_impl( + field: torch.Tensor, + spacing_meta: torch.Tensor, + order: int, + derivative_order: int, + include_mixed: bool, +) -> torch.Tensor: + """Compute periodic first or pure second derivatives on a uniform grid.""" + _validate_field(field) + spacing_tuple = tuple(float(v) for v in spacing_meta.tolist()) + for dx in spacing_tuple: + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + order = _validate_order(int(order)) + derivative_order = _validate_derivative_order(int(derivative_order)) + _validate_include_mixed( + derivative_order=derivative_order, + include_mixed=bool(include_mixed), + ) + + orig_dtype = field.dtype + field_fp32 = ( + field + if field.dtype == torch.float32 and field.is_contiguous() + else field.to(dtype=torch.float32).contiguous() + ) + + ### Write gradients directly into preallocated output slices to avoid stack copy. + output_fp32 = torch.empty( + (field_fp32.ndim, *field_fp32.shape), + device=field_fp32.device, + dtype=torch.float32, + ) + grad_components = [output_fp32[axis] for axis in range(field_fp32.ndim)] + + wp_device, wp_stream = _warp_launch_context(field_fp32) + _launch_forward( + field_fp32=field_fp32, + spacing_tuple=spacing_tuple, + order=order, + derivative_order=derivative_order, + grad_components=grad_components, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + if output_fp32.dtype != orig_dtype: + return output_fp32.to(dtype=orig_dtype) + return output_fp32 + + +@uniform_grid_gradient_impl.register_fake +def _uniform_grid_gradient_impl_fake( + field: torch.Tensor, + spacing_meta: torch.Tensor, + order: int, + derivative_order: int, + include_mixed: bool, +) -> torch.Tensor: + """Fake tensor propagation for uniform-grid custom op.""" + _ = (spacing_meta, order, derivative_order, include_mixed) + return torch.empty( + (field.ndim, *field.shape), + device=field.device, + dtype=field.dtype, + ) + + +def setup_uniform_grid_gradient_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: torch.Tensor +) -> None: + """Store backward context for uniform-grid custom-op autograd.""" + field, spacing_meta, order, derivative_order, include_mixed = inputs + _ = output + ctx.spacing_tuple = tuple(float(v) for v in spacing_meta.tolist()) + ctx.order = int(order) + ctx.derivative_order = int(derivative_order) + ctx.include_mixed = bool(include_mixed) + ctx.orig_dtype = field.dtype + + +def backward_uniform_grid_gradient( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None, None, None]: + """Backward pass for uniform-grid custom op (gradients wrt field only).""" + if grad_output is None or not ctx.needs_input_grad[0]: + return None, None, None, None, None + + grad_output_fp32 = ( + grad_output + if grad_output.dtype == torch.float32 and grad_output.is_contiguous() + else grad_output.to(dtype=torch.float32).contiguous() + ) + grad_field = torch.empty_like(grad_output_fp32[0]) + + wp_device, wp_stream = 
_warp_launch_context(grad_output_fp32) + _launch_backward( + grad_output_fp32=grad_output_fp32, + spacing_tuple=ctx.spacing_tuple, + order=ctx.order, + derivative_order=ctx.derivative_order, + grad_field=grad_field, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + if grad_field.dtype != ctx.orig_dtype: + grad_field = grad_field.to(dtype=ctx.orig_dtype) + return grad_field, None, None, None, None + + +uniform_grid_gradient_impl.register_autograd( + backward_uniform_grid_gradient, + setup_context=setup_uniform_grid_gradient_context, +) + + +@torch.library.custom_op( + "physicsnemo::uniform_grid_derivatives_order2_fused_warp_impl", + mutates_args=(), +) +def uniform_grid_derivatives_order2_fused_impl( + field: torch.Tensor, + spacing_meta: torch.Tensor, + include_mixed: bool, +) -> torch.Tensor: + """Compute fused order-2 derivatives (first + second + optional mixed).""" + _validate_field(field) + spacing_tuple = tuple(float(v) for v in spacing_meta.tolist()) + for dx in spacing_tuple: + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + if not isinstance(include_mixed, bool): + raise TypeError(f"include_mixed must be a bool, got {type(include_mixed)}") + if include_mixed and field.ndim < 2: + raise ValueError("mixed derivatives require at least 2D inputs") + + orig_dtype = field.dtype + field_fp32 = ( + field + if field.dtype == torch.float32 and field.is_contiguous() + else field.to(dtype=torch.float32).contiguous() + ) + n_dims = field_fp32.ndim + n_mixed = 0 + if include_mixed and n_dims == 2: + n_mixed = 1 + elif include_mixed and n_dims == 3: + n_mixed = 3 + output_fp32 = torch.empty( + (2 * n_dims + n_mixed, *field_fp32.shape), + device=field_fp32.device, + dtype=torch.float32, + ) + first_terms = [output_fp32[axis] for axis in range(n_dims)] + second_terms = [output_fp32[n_dims + axis] for axis in range(n_dims)] + mixed_terms = [output_fp32[2 * n_dims + axis] for axis in range(n_mixed)] + + wp_device, wp_stream = _warp_launch_context(field_fp32) + _launch_forward_fused_order2( + field_fp32=field_fp32, + spacing_tuple=spacing_tuple, + first_components=first_terms, + second_components=second_terms, + mixed_components=mixed_terms, + include_mixed=include_mixed, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + if output_fp32.dtype != orig_dtype: + return output_fp32.to(dtype=orig_dtype) + return output_fp32 + + +@uniform_grid_derivatives_order2_fused_impl.register_fake +def _uniform_grid_derivatives_order2_fused_impl_fake( + field: torch.Tensor, + spacing_meta: torch.Tensor, + include_mixed: bool, +) -> torch.Tensor: + """Fake tensor propagation for fused uniform order-2 custom op.""" + _ = spacing_meta + n_mixed = 0 + if include_mixed and field.ndim == 2: + n_mixed = 1 + elif include_mixed and field.ndim == 3: + n_mixed = 3 + return torch.empty( + (2 * field.ndim + n_mixed, *field.shape), + device=field.device, + dtype=field.dtype, + ) + + +def setup_uniform_grid_derivatives_order2_fused_context( + ctx: torch.autograd.function.FunctionCtx, inputs: tuple, output: torch.Tensor +) -> None: + """Store backward context for fused uniform order-2 custom op.""" + field, spacing_meta, include_mixed = inputs + _ = output + ctx.spacing_tuple = tuple(float(v) for v in spacing_meta.tolist()) + ctx.orig_dtype = field.dtype + ctx.n_dims = field.ndim + ctx.include_mixed = bool(include_mixed) + + +def backward_uniform_grid_derivatives_order2_fused( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, 
None]: + """Backward pass for fused uniform order-2 custom op.""" + if grad_output is None or not ctx.needs_input_grad[0]: + return None, None, None + + grad_output_fp32 = ( + grad_output + if grad_output.dtype == torch.float32 and grad_output.is_contiguous() + else grad_output.to(dtype=torch.float32).contiguous() + ) + n_dims = int(ctx.n_dims) + include_mixed = bool(ctx.include_mixed) + grad_first_components = [grad_output_fp32[axis] for axis in range(n_dims)] + grad_second_components = [grad_output_fp32[n_dims + axis] for axis in range(n_dims)] + grad_mixed_components: list[torch.Tensor] = [] + if include_mixed: + n_mixed = 1 if n_dims == 2 else 3 + grad_mixed_components = [ + grad_output_fp32[2 * n_dims + axis] for axis in range(n_mixed) + ] + grad_field = torch.empty_like(grad_output_fp32[0]) + + wp_device, wp_stream = _warp_launch_context(grad_output_fp32) + _launch_backward_fused_order2_no_mixed( + grad_first_components=grad_first_components, + grad_second_components=grad_second_components, + grad_mixed_components=grad_mixed_components, + spacing_tuple=ctx.spacing_tuple, + include_mixed=include_mixed, + grad_field=grad_field, + wp_device=wp_device, + wp_stream=wp_stream, + ) + if grad_field.dtype != ctx.orig_dtype: + grad_field = grad_field.to(dtype=ctx.orig_dtype) + return grad_field, None, None + + +uniform_grid_derivatives_order2_fused_impl.register_autograd( + backward_uniform_grid_derivatives_order2_fused, + setup_context=setup_uniform_grid_derivatives_order2_fused_context, +) + + +def uniform_grid_gradient_warp( + field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + derivative_order: int = 1, + include_mixed: bool = False, +) -> torch.Tensor: + """Compute periodic first or pure second derivatives on a uniform grid.""" + _validate_field(field) + spacing_tuple = _normalize_spacing(spacing, field.ndim) + for dx in spacing_tuple: + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + order = _validate_order(order) + derivative_order = _validate_derivative_order(derivative_order) + _validate_include_mixed( + derivative_order=derivative_order, + include_mixed=include_mixed, + ) + spacing_meta = torch.tensor(spacing_tuple, dtype=torch.float32, device="cpu") + return uniform_grid_gradient_impl( + field, + spacing_meta, + int(order), + int(derivative_order), + bool(include_mixed), + ) + + +def uniform_grid_gradient_warp_multi( + field: torch.Tensor, + spacing: float | Sequence[float], + order: int, + derivative_orders: tuple[int, ...], + include_mixed: bool, +) -> torch.Tensor: + """Compute multiple derivative families, fusing Warp launches when possible. + + For ``order=2``, combined requests use fused kernels (with custom-op + backward when gradients are required), including optional mixed second + derivatives. ``order=4`` requests continue to compose single-order Warp calls. + """ + _validate_field(field) + spacing_tuple = _normalize_spacing(spacing, field.ndim) + for dx in spacing_tuple: + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + order = _validate_order(order) + + if include_mixed and 2 not in derivative_orders: + raise ValueError("include_mixed requires requesting 2nd derivatives") + if include_mixed and field.ndim < 2: + raise ValueError("mixed derivatives require at least 2D inputs") + + spacing_meta = torch.tensor(spacing_tuple, dtype=torch.float32, device="cpu") + + ### Mixed requests compose existing single-order autograd-safe calls. 
+ if include_mixed: + if order == 2: + fused = uniform_grid_derivatives_order2_fused_impl( + field, spacing_meta, True + ) + outputs: list[torch.Tensor] = [] + n_dims = field.ndim + n_mixed = 1 if n_dims == 2 else 3 + if 1 in derivative_orders: + outputs.extend(fused[:n_dims].unbind(0)) + if 2 in derivative_orders: + outputs.extend(fused[n_dims : 2 * n_dims].unbind(0)) + outputs.extend(fused[2 * n_dims : 2 * n_dims + n_mixed].unbind(0)) + return torch.stack(outputs, dim=0) + + outputs: list[torch.Tensor] = [] + if 1 in derivative_orders: + outputs.extend( + uniform_grid_gradient_warp( + field=field, + spacing=spacing_tuple, + order=order, + derivative_order=1, + include_mixed=False, + ).unbind(0) + ) + if 2 in derivative_orders: + pure_second = uniform_grid_gradient_warp( + field=field, + spacing=spacing_tuple, + order=order, + derivative_order=2, + include_mixed=False, + ) + outputs.extend(pure_second.unbind(0)) + first_terms = uniform_grid_gradient_warp( + field=field, + spacing=spacing_tuple, + order=order, + derivative_order=1, + include_mixed=False, + ) + for axis_i in range(field.ndim): + for axis_j in range(axis_i + 1, field.ndim): + mixed_ij = uniform_grid_gradient_warp( + field=first_terms[axis_i], + spacing=spacing_tuple, + order=order, + derivative_order=1, + include_mixed=False, + )[axis_j] + outputs.append(mixed_ij) + return torch.stack(outputs, dim=0) + + ### Single-order requests should use the direct single-order custom op path. + if len(derivative_orders) == 1: + return uniform_grid_gradient_warp( + field=field, + spacing=spacing_tuple, + order=order, + derivative_order=int(derivative_orders[0]), + include_mixed=False, + ) + + ### Backward-capable fused path for order-2 combined first+second requests. + if ( + order == 2 + and field.requires_grad + and 1 in derivative_orders + and 2 in derivative_orders + ): + fused = uniform_grid_derivatives_order2_fused_impl(field, spacing_meta, False) + outputs: list[torch.Tensor] = [] + n_dims = field.ndim + if 1 in derivative_orders: + outputs.extend(fused[:n_dims].unbind(0)) + if 2 in derivative_orders: + outputs.extend(fused[n_dims:].unbind(0)) + return torch.stack(outputs, dim=0) + + ### Fallback path: preserve full autograd behavior and order=4 support. + if field.requires_grad or order != 2: + outputs: list[torch.Tensor] = [] + if 1 in derivative_orders: + outputs.extend( + uniform_grid_gradient_warp( + field=field, + spacing=spacing_tuple, + order=order, + derivative_order=1, + include_mixed=False, + ).unbind(0) + ) + if 2 in derivative_orders: + pure_second = uniform_grid_gradient_warp( + field=field, + spacing=spacing_tuple, + order=order, + derivative_order=2, + include_mixed=False, + ) + outputs.extend(pure_second.unbind(0)) + return torch.stack(outputs, dim=0) + + ### Fused forward path (order=2 only, no-mixed, no autograd required). 
+ field_fp32 = ( + field + if field.dtype == torch.float32 and field.is_contiguous() + else field.to(dtype=torch.float32).contiguous() + ) + n_dims = field_fp32.ndim + first_terms = [ + torch.empty_like(field_fp32, dtype=torch.float32) for _ in range(n_dims) + ] + second_terms = [ + torch.empty_like(field_fp32, dtype=torch.float32) for _ in range(n_dims) + ] + + wp_device, wp_stream = _warp_launch_context(field_fp32) + _launch_forward_fused_order2( + field_fp32=field_fp32, + spacing_tuple=spacing_tuple, + first_components=first_terms, + second_components=second_terms, + mixed_components=[], + include_mixed=False, + wp_device=wp_device, + wp_stream=wp_stream, + ) + + outputs: list[torch.Tensor] = [] + if 1 in derivative_orders: + outputs.extend(first_terms) + if 2 in derivative_orders: + outputs.extend(second_terms) + + output = torch.stack(outputs, dim=0) + if output.dtype != field.dtype: + output = output.to(dtype=field.dtype) + return output diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/utils.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/utils.py new file mode 100644 index 0000000000..aed5df51b3 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/utils.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import warp as wp + +_SUPPORTED_ORDERS = (2, 4) +_SUPPORTED_DERIVATIVE_ORDERS = (1, 2) + +### Warp runtime initialization for custom kernels. +wp.init() +wp.config.quiet = True + +### Optional launch block size override; <=0 uses Warp default autotuning. +_WARP_BLOCK_DIM = -1 + + +def _normalize_spacing( + spacing: float | Sequence[float], ndim: int +) -> tuple[float, ...]: + ### Normalize scalar/list spacing into one value per axis. + if isinstance(spacing, (float, int)): + return tuple(float(spacing) for _ in range(ndim)) + spacing_tuple = tuple(float(x) for x in spacing) + if len(spacing_tuple) != ndim: + raise ValueError( + f"spacing must have {ndim} entries for a {ndim}D field, got {len(spacing_tuple)}" + ) + return spacing_tuple + + +def _validate_order(order: int) -> int: + ### Validate finite-difference order selection. + if not isinstance(order, int): + raise TypeError(f"order must be an integer, got {type(order)}") + if order not in _SUPPORTED_ORDERS: + raise ValueError( + f"uniform_grid_gradient supports {list(_SUPPORTED_ORDERS)} central orders, got order={order}" + ) + return order + + +def _validate_derivative_order(derivative_order: int) -> int: + ### Validate derivative-order selection (first vs pure second derivative). 
+ if not isinstance(derivative_order, int): + raise TypeError( + f"derivative_order must be an integer, got {type(derivative_order)}" + ) + if derivative_order not in _SUPPORTED_DERIVATIVE_ORDERS: + raise ValueError( + "uniform_grid_gradient supports derivative_order in [1, 2], " + f"got derivative_order={derivative_order}" + ) + return derivative_order + + +def _validate_include_mixed( + *, + derivative_order: int, + include_mixed: bool, +) -> None: + ### Phase-1 guard: mixed second derivatives are intentionally not yet exposed. + if not isinstance(include_mixed, bool): + raise TypeError(f"include_mixed must be a bool, got {type(include_mixed)}") + if include_mixed and derivative_order != 2: + raise ValueError("include_mixed is only valid when derivative_order=2") + if include_mixed: + raise NotImplementedError( + "include_mixed=True is not yet supported; phase-1 supports pure axis-wise " + "second derivatives only" + ) + + +def _validate_field(field: torch.Tensor) -> None: + ### Validate field shape and dtype. + if field.ndim < 1 or field.ndim > 3: + raise ValueError( + f"uniform_grid_gradient supports 1D-3D fields, got {field.shape=}" + ) + if not torch.is_floating_point(field): + raise TypeError("field must be a floating-point tensor") + + +def _wp_launch( + *, + kernel, + dim, + inputs, + device, + stream, +) -> None: + ### Launch a Warp kernel, optionally overriding block size. + if _WARP_BLOCK_DIM > 0: + wp.launch( + kernel=kernel, + dim=dim, + inputs=inputs, + device=device, + stream=stream, + block_dim=_WARP_BLOCK_DIM, + ) + return + wp.launch( + kernel=kernel, + dim=dim, + inputs=inputs, + device=device, + stream=stream, + ) + + +def _warp_launch_context(field: torch.Tensor): + ### Resolve warp launch context without per-call dynamic imports. + if field.device.type == "cuda": + return None, wp.stream_from_torch(torch.cuda.current_stream(field.device)) + return "cpu", None diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/uniform_grid_gradient.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/uniform_grid_gradient.py new file mode 100644 index 0000000000..41c7b85eac --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/uniform_grid_gradient.py @@ -0,0 +1,394 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence +from itertools import combinations + +import torch + +from physicsnemo.core.function_spec import FunctionSpec + +from .._request_utils import ( + normalize_derivative_orders, + normalize_include_mixed, + validate_mixed_request, +) +from ._torch_impl import uniform_grid_gradient_torch +from ._warp_impl import uniform_grid_gradient_warp, uniform_grid_gradient_warp_multi + +### Auto-dispatch crossover threshold for 3D CUDA fields. 
+### <= TORCH_MAX uses eager torch, larger fields use warp. +_AUTO_3D_TORCH_MAX_NUMEL = 48 * 48 * 48 + + +class UniformGridGradient(FunctionSpec): + r"""Compute periodic central-difference gradients on a uniform grid. + + This functional computes first-order and/or second-order + derivatives of a scalar field defined on a 1D/2D/3D uniform Cartesian + grid with periodic indexing. + + For each axis :math:`k`, the first derivative is: + + .. math:: + + \partial_k f(\mathbf{i}) \approx + \frac{f(\mathbf{i}+\hat{e}_k) - f(\mathbf{i}-\hat{e}_k)}{2\,\Delta x_k} + + and the pure second derivative is: + + .. math:: + + \partial_{kk} f(\mathbf{i}) \approx + \frac{f(\mathbf{i}+\hat{e}_k)-2f(\mathbf{i})+f(\mathbf{i}-\hat{e}_k)} + {\Delta x_k^2} + + with periodic wrap-around at boundaries. + + Parameters + ---------- + field : torch.Tensor + Scalar grid field with shape ``(n0,)``, ``(n0,n1)``, or ``(n0,n1,n2)``. + spacing : float | Sequence[float], optional + Uniform spacing per axis. Use a scalar for isotropic spacing or a + sequence matching field dimensionality. + order : int, optional + Central-difference accuracy order. Supported values are ``2`` and ``4``. + derivative_orders : int | Sequence[int], optional + Derivative orders to compute. Supported values are ``1``, ``2``, or + ``(1, 2)``. + include_mixed : bool, optional + Include mixed second derivatives when requesting second derivatives. + Mixed terms are appended in axis-pair order ``(x,y)``, ``(x,z)``, + ``(y,z)``. + implementation : {"warp", "torch"} or None + Explicit backend selection. When ``None``, ``uniform_grid_gradient`` + applies a shape-aware auto-dispatch heuristic. + + Returns + ------- + torch.Tensor + Gradient tensor of shape ``(num_derivatives, *field.shape)``. + """ + + ### Benchmark input presets (small -> large workload). 
+ _BENCHMARK_CASES = ( + ("1d-n8192-o2-d1", (8192,), 0.01, 2, 1), + ("1d-n8192-o4-d1", (8192,), 0.01, 4, 1), + ("2d-512x512-o2-d1", (512, 512), (0.01, 0.02), 2, 1), + ("2d-512x512-o2-d2", (512, 512), (0.01, 0.02), 2, 2), + ("3d-128x128x128-o2-d1", (128, 128, 128), 0.02, 2, 1), + ("3d-96x96x96-o2-d2", (96, 96, 96), 0.02, 2, 2), + ) + + _COMPARE_ATOL = 1e-5 + _COMPARE_RTOL = 1e-5 + _COMPARE_BACKWARD_ATOL = 1e-5 + _COMPARE_BACKWARD_RTOL = 1e-5 + + @FunctionSpec.register(name="warp", required_imports=("warp>=0.6.0",), rank=0) + def warp_forward( + field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + derivative_orders: int | Sequence[int] = 1, + include_mixed: bool = False, + ) -> torch.Tensor: + """Dispatch uniform-grid gradients to the Warp backend.""" + return _dispatch_uniform_grid_requests( + backend_fn=uniform_grid_gradient_warp, + field=field, + spacing=spacing, + order=order, + derivative_orders=derivative_orders, + include_mixed=include_mixed, + ) + + @FunctionSpec.register(name="torch", rank=2, baseline=True) + def torch_forward( + field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + derivative_orders: int | Sequence[int] = 1, + include_mixed: bool = False, + ) -> torch.Tensor: + """Dispatch uniform-grid gradients to eager PyTorch.""" + return _dispatch_uniform_grid_requests( + backend_fn=uniform_grid_gradient_torch, + field=field, + spacing=spacing, + order=order, + derivative_orders=derivative_orders, + include_mixed=include_mixed, + ) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield representative forward benchmark and parity input cases.""" + device = torch.device(device) + + ### Build periodic analytic fields for benchmark and parity coverage. + for label, shape, spacing, order, derivative_order in cls._BENCHMARK_CASES: + if len(shape) == 1: + x = torch.linspace(0.0, 1.0, shape[0], device=device) + field = torch.sin(2.0 * torch.pi * x) + elif len(shape) == 2: + x0 = torch.linspace(0.0, 1.0, shape[0], device=device) + x1 = torch.linspace(0.0, 1.0, shape[1], device=device) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + field = torch.sin(2.0 * torch.pi * xx) + 0.5 * torch.cos( + 2.0 * torch.pi * yy + ) + else: + x0 = torch.linspace(0.0, 1.0, shape[0], device=device) + x1 = torch.linspace(0.0, 1.0, shape[1], device=device) + x2 = torch.linspace(0.0, 1.0, shape[2], device=device) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + field = ( + torch.sin(2.0 * torch.pi * xx) + + 0.5 * torch.cos(2.0 * torch.pi * yy) + + 0.25 * torch.sin(2.0 * torch.pi * zz) + ) + + ### Yield the labeled functional input case. + yield ( + label, + (field.to(torch.float32),), + { + "spacing": spacing, + "order": order, + "derivative_orders": derivative_order, + "include_mixed": False, + }, + ) + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield representative backward benchmark and parity input cases.""" + device = torch.device(device) + + ### Build representative differentiable fields for backward parity. 
+ backward_cases = ( + ("1d-grad-n4096-o2-d1", (4096,), 0.01, 2, 1), + ("2d-grad-256x256-o2-d1", (256, 256), (0.01, 0.02), 2, 1), + ("2d-grad-256x256-o2-d2", (256, 256), (0.01, 0.02), 2, 2), + ("2d-grad-256x256-o4-d1", (256, 256), (0.01, 0.02), 4, 1), + ("3d-grad-96x96x96-o2-d1", (96, 96, 96), 0.02, 2, 1), + ("3d-grad-64x64x64-o2-d2", (64, 64, 64), 0.02, 2, 2), + ) + + for label, shape, spacing, order, derivative_order in backward_cases: + if len(shape) == 1: + x = torch.linspace(0.0, 1.0, shape[0], device=device) + field = torch.sin(2.0 * torch.pi * x) + elif len(shape) == 2: + x0 = torch.linspace(0.0, 1.0, shape[0], device=device) + x1 = torch.linspace(0.0, 1.0, shape[1], device=device) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + field = torch.sin(2.0 * torch.pi * xx) + 0.5 * torch.cos( + 2.0 * torch.pi * yy + ) + else: + x0 = torch.linspace(0.0, 1.0, shape[0], device=device) + x1 = torch.linspace(0.0, 1.0, shape[1], device=device) + x2 = torch.linspace(0.0, 1.0, shape[2], device=device) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + field = ( + torch.sin(2.0 * torch.pi * xx) + + 0.5 * torch.cos(2.0 * torch.pi * yy) + + 0.25 * torch.sin(2.0 * torch.pi * zz) + ) + + ### Yield differentiable field inputs for backward dispatch. + yield ( + label, + (field.to(torch.float32).detach().clone().requires_grad_(True),), + { + "spacing": spacing, + "order": order, + "derivative_orders": derivative_order, + "include_mixed": False, + }, + ) + + @classmethod + def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare forward outputs across implementations.""" + ### Validate forward parity across backends. + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_ATOL, + rtol=cls._COMPARE_RTOL, + ) + + @classmethod + def compare_backward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare backward gradients across implementations.""" + ### Validate backward parity across backends. + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_BACKWARD_ATOL, + rtol=cls._COMPARE_BACKWARD_RTOL, + ) + + +def _auto_select_implementation(field: torch.Tensor) -> str: + ### Select backend by dimensionality/size on CUDA and by capability on CPU. + available = set(UniformGridGradient.available_implementations()) + if "warp" not in available: + return "torch" + + if field.device.type != "cuda": + return "torch" + + ### Autograd paths should prefer the explicit Warp autograd kernels. + if field.requires_grad: + return "warp" + + ### 1D/2D generally favor eager torch in current measurements. + if field.ndim in (1, 2): + if "torch" in available: + return "torch" + return "warp" + + ### 3D uses a single-threshold crossover: torch -> warp. + numel = field.numel() + if numel <= _AUTO_3D_TORCH_MAX_NUMEL and "torch" in available: + return "torch" + return "warp" + + +_uniform_grid_gradient_dispatch = UniformGridGradient.make_function( + "_uniform_grid_gradient_dispatch" +) + + +def uniform_grid_gradient( + field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + derivative_orders: int | Sequence[int] = 1, + include_mixed: bool = False, + implementation: str | None = None, +) -> torch.Tensor: + """Compute periodic first and/or second derivatives on a uniform grid. + + When ``implementation`` is ``None``, a shape-aware backend heuristic is + used: on CUDA, 1D/2D fields prefer ``torch``; 3D fields use a single-threshold + crossover (``torch`` -> ``warp``) as problem size + grows. 
Inputs requiring gradients prefer ``warp`` to use the explicit + custom backward kernels. + """ + if implementation is None: + implementation = _auto_select_implementation(field) + return _uniform_grid_gradient_dispatch( + field, + spacing=spacing, + order=order, + derivative_orders=derivative_orders, + include_mixed=include_mixed, + implementation=implementation, + ) + + +__all__ = ["UniformGridGradient", "uniform_grid_gradient"] + + +def _dispatch_uniform_grid_requests( + *, + backend_fn, + field: torch.Tensor, + spacing: float | Sequence[float], + order: int, + derivative_orders: int | Sequence[int], + include_mixed: bool, +) -> torch.Tensor: + """Resolve unified derivative requests and dispatch to backend kernels.""" + requested_orders = normalize_derivative_orders( + derivative_orders=derivative_orders, + function_name="uniform_grid_gradient", + ) + mixed_terms = normalize_include_mixed( + include_mixed=include_mixed, + function_name="uniform_grid_gradient", + ) + validate_mixed_request( + derivative_orders=requested_orders, + include_mixed=mixed_terms, + ndim=field.ndim, + function_name="uniform_grid_gradient", + ) + + if backend_fn is uniform_grid_gradient_warp and ( + len(requested_orders) > 1 or mixed_terms + ): + return uniform_grid_gradient_warp_multi( + field=field, + spacing=spacing, + order=order, + derivative_orders=requested_orders, + include_mixed=mixed_terms, + ) + + outputs: list[torch.Tensor] = [] + first_terms: torch.Tensor | None = None + + if 1 in requested_orders: + first_terms = backend_fn( + field=field, + spacing=spacing, + order=order, + derivative_order=1, + include_mixed=False, + ) + outputs.extend(first_terms.unbind(0)) + + if 2 in requested_orders: + pure_second_terms = backend_fn( + field=field, + spacing=spacing, + order=order, + derivative_order=2, + include_mixed=False, + ) + outputs.extend(pure_second_terms.unbind(0)) + + if mixed_terms: + if first_terms is None: + first_terms = backend_fn( + field=field, + spacing=spacing, + order=order, + derivative_order=1, + include_mixed=False, + ) + + for axis_i, axis_j in combinations(range(field.ndim), 2): + axis_i_first = first_terms[axis_i] + mixed_ij = backend_fn( + field=axis_i_first, + spacing=spacing, + order=order, + derivative_order=1, + include_mixed=False, + )[axis_j] + outputs.append(mixed_ij) + + return torch.stack(outputs, dim=0) diff --git a/test/nn/functional/derivatives/test_mesh_green_gauss_gradient.py b/test/nn/functional/derivatives/test_mesh_green_gauss_gradient.py new file mode 100644 index 0000000000..c96133c5fc --- /dev/null +++ b/test/nn/functional/derivatives/test_mesh_green_gauss_gradient.py @@ -0,0 +1,370 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
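For reviewer reference only (not part of the patch): a minimal usage sketch of the `uniform_grid_gradient` entry point introduced above. The grid size, spacing values, and toy field below are illustrative assumptions, not values taken from this diff.

```python
import torch

from physicsnemo.nn.functional import uniform_grid_gradient

# Periodic scalar field sampled on an illustrative 128x128 uniform grid.
x = torch.linspace(0.0, 1.0, 128)
xx, yy = torch.meshgrid(x, x, indexing="ij")
field = torch.sin(2.0 * torch.pi * xx) + 0.5 * torch.cos(2.0 * torch.pi * yy)

# With implementation=None, the shape-aware heuristic picks the backend.
# Derivative terms are stacked along dim 0 (one term per axis here).
grads = uniform_grid_gradient(field, spacing=(0.01, 0.01), derivative_orders=1)

# A backend can also be requested explicitly; for combined orders the output
# stacks first-derivative terms followed by pure second-derivative terms.
grads_both = uniform_grid_gradient(
    field,
    spacing=(0.01, 0.01),
    derivative_orders=(1, 2),
    implementation="torch",
)
```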
+ +import pytest +import torch + +from physicsnemo.nn.functional import mesh_green_gauss_gradient +from physicsnemo.nn.functional.derivatives import MeshGreenGaussGradient +from physicsnemo.nn.functional.derivatives.mesh_green_gauss_gradient.utils import ( + build_neighbors, +) +from test.conftest import requires_module +from test.nn.functional._parity_utils import clone_case + + +# Build a deterministic structured triangular mesh. +def _build_case(device: str, nx: int = 36, ny: int = 32): + torch_device = torch.device(device) + x = torch.linspace(0.0, 1.0, nx, device=torch_device, dtype=torch.float32) + y = torch.linspace(0.0, 1.0, ny, device=torch_device, dtype=torch.float32) + xx, yy = torch.meshgrid(x, y, indexing="ij") + points = torch.stack((xx.reshape(-1), yy.reshape(-1)), dim=-1) + + cells = [] + for i in range(nx - 1): + for j in range(ny - 1): + p00 = i * ny + j + p10 = (i + 1) * ny + j + p01 = i * ny + (j + 1) + p11 = (i + 1) * ny + (j + 1) + cells.append((p00, p10, p11)) + cells.append((p00, p11, p01)) + cells = torch.tensor(cells, device=torch_device, dtype=torch.int64) + return points.contiguous(), cells.contiguous() + + +def _build_case_3d(device: str, nx: int = 12, ny: int = 10, nz: int = 8): + torch_device = torch.device(device) + x = torch.linspace(0.0, 1.0, nx, device=torch_device, dtype=torch.float32) + y = torch.linspace(0.0, 1.0, ny, device=torch_device, dtype=torch.float32) + z = torch.linspace(0.0, 1.0, nz, device=torch_device, dtype=torch.float32) + xx, yy, zz = torch.meshgrid(x, y, z, indexing="ij") + points = torch.stack((xx.reshape(-1), yy.reshape(-1), zz.reshape(-1)), dim=-1) + + def _idx(i: int, j: int, k: int) -> int: + return (i * ny + j) * nz + k + + cells = [] + for i in range(nx - 1): + for j in range(ny - 1): + for k in range(nz - 1): + p000 = _idx(i, j, k) + p100 = _idx(i + 1, j, k) + p010 = _idx(i, j + 1, k) + p110 = _idx(i + 1, j + 1, k) + p001 = _idx(i, j, k + 1) + p101 = _idx(i + 1, j, k + 1) + p011 = _idx(i, j + 1, k + 1) + p111 = _idx(i + 1, j + 1, k + 1) + + cells.append((p000, p100, p110, p111)) + cells.append((p000, p110, p010, p111)) + cells.append((p000, p010, p011, p111)) + cells.append((p000, p011, p001, p111)) + cells.append((p000, p001, p101, p111)) + cells.append((p000, p101, p100, p111)) + + cells = torch.tensor(cells, device=torch_device, dtype=torch.int64) + return points.contiguous(), cells.contiguous() + + +# Validate torch Green-Gauss reconstruction on a linear field. 
+def test_mesh_green_gauss_gradient_torch(device: str): + points, cells = _build_case(device=device, nx=40, ny=34) + neighbors = build_neighbors(cells) + centroids = points[cells].mean(dim=1) + coeff = torch.tensor([2.0, -3.0], device=points.device, dtype=torch.float32) + values = (centroids * coeff).sum(dim=-1) + + output = MeshGreenGaussGradient.dispatch( + points, + cells, + neighbors, + values, + implementation="torch", + ) + interior = (neighbors >= 0).all(dim=1) + expected = coeff.view(1, -1).expand(interior.sum(), -1) + torch.testing.assert_close(output[interior], expected, atol=5e-2, rtol=5e-2) + + +def test_mesh_green_gauss_gradient_torch_3d(device: str): + points, cells = _build_case_3d(device=device, nx=11, ny=9, nz=7) + neighbors = build_neighbors(cells) + values = torch.sin(points[cells].mean(dim=1).sum(dim=-1)) + + output = MeshGreenGaussGradient.dispatch( + points, + cells, + neighbors, + values, + implementation="torch", + ) + assert output.shape == (cells.shape[0], points.shape[1]) + assert torch.isfinite(output).all() + + +# Validate warp Green-Gauss reconstruction on a linear field. +@requires_module("warp") +def test_mesh_green_gauss_gradient_warp(device: str): + points, cells = _build_case(device=device, nx=40, ny=34) + neighbors = build_neighbors(cells) + centroids = points[cells].mean(dim=1) + coeff = torch.tensor([2.0, -3.0], device=points.device, dtype=torch.float32) + values = (centroids * coeff).sum(dim=-1) + + output = MeshGreenGaussGradient.dispatch( + points, + cells, + neighbors, + values, + implementation="warp", + ) + interior = (neighbors >= 0).all(dim=1) + expected = coeff.view(1, -1).expand(interior.sum(), -1) + torch.testing.assert_close(output[interior], expected, atol=5e-2, rtol=5e-2) + + +@requires_module("warp") +def test_mesh_green_gauss_gradient_warp_3d(device: str): + points, cells = _build_case_3d(device=device, nx=11, ny=9, nz=7) + neighbors = build_neighbors(cells) + values = torch.sin(points[cells].mean(dim=1).sum(dim=-1)) + + output_torch = MeshGreenGaussGradient.dispatch( + points, + cells, + neighbors, + values, + implementation="torch", + ) + + output_warp = MeshGreenGaussGradient.dispatch( + points, + cells, + neighbors, + values, + implementation="warp", + ) + MeshGreenGaussGradient.compare_forward(output_warp, output_torch) + + +# Validate warp backend forward parity against torch across benchmark cases. +@requires_module("warp") +def test_mesh_green_gauss_gradient_backend_forward_parity(device: str): + for _label, args, kwargs in MeshGreenGaussGradient.make_inputs_forward( + device=device + ): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = MeshGreenGaussGradient.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_warp = MeshGreenGaussGradient.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + MeshGreenGaussGradient.compare_forward(out_warp, out_torch) + + +# Validate warp backend backward parity against torch on value gradients. 
+@requires_module("warp") +def test_mesh_green_gauss_gradient_backend_backward_parity(device: str): + for _label, args, kwargs in MeshGreenGaussGradient.make_inputs_backward( + device=device + ): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = MeshGreenGaussGradient.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_torch.square().mean().backward() + grad_torch = args_torch[3].grad + assert grad_torch is not None + + out_warp = MeshGreenGaussGradient.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + out_warp.square().mean().backward() + grad_warp = args_warp[3].grad + assert grad_warp is not None + + MeshGreenGaussGradient.compare_backward(grad_warp, grad_torch) + + +@requires_module("warp") +def test_mesh_green_gauss_gradient_warp_supports_point_gradients(device: str): + points, cells = _build_case(device=device, nx=26, ny=22) + neighbors = build_neighbors(cells) + centroids = points[cells].mean(dim=1) + base_values = ( + torch.sin(2.0 * torch.pi * centroids[:, 0]) + + 0.25 * torch.cos(2.0 * torch.pi * centroids[:, 1]) + ).to(torch.float32) + + points_warp = points.detach().clone().requires_grad_(True) + values_warp = base_values.detach().clone().requires_grad_(True) + out_warp = MeshGreenGaussGradient.dispatch( + points_warp, + cells, + neighbors, + values_warp, + implementation="warp", + ) + out_warp.square().mean().backward() + grad_points_warp = points_warp.grad + grad_values_warp = values_warp.grad + assert grad_points_warp is not None + assert grad_values_warp is not None + assert torch.isfinite(grad_points_warp).all() + assert torch.isfinite(grad_values_warp).all() + assert torch.any(grad_points_warp != 0.0) + + +# Validate benchmark input generation contract for forward inputs. +def test_mesh_green_gauss_gradient_make_inputs_forward(device: str): + label, args, kwargs = next( + iter(MeshGreenGaussGradient.make_inputs_forward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + points, cells, neighbors, values = args + assert points.ndim == 2 + assert cells.ndim == 2 + assert neighbors.shape == (cells.shape[0], cells.shape[1]) + assert values.shape[0] == cells.shape[0] + + output = MeshGreenGaussGradient.dispatch( + *args, + implementation="torch", + **kwargs, + ) + assert output.shape[0] == cells.shape[0] + assert output.shape[1] == points.shape[1] + + +# Validate benchmark input generation contract for backward inputs. +def test_mesh_green_gauss_gradient_make_inputs_backward(device: str): + label, args, kwargs = next( + iter(MeshGreenGaussGradient.make_inputs_backward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + values = args[3] + assert values.requires_grad + + output = MeshGreenGaussGradient.dispatch( + *args, + implementation="torch", + **kwargs, + ) + output.square().mean().backward() + assert values.grad is not None + + +# Validate compare-forward hook contract. +def test_mesh_green_gauss_gradient_compare_forward_contract(device: str): + _label, args, kwargs = next( + iter(MeshGreenGaussGradient.make_inputs_forward(device=device)) + ) + output = MeshGreenGaussGradient.dispatch(*args, implementation="torch", **kwargs) + reference = output.detach().clone() + MeshGreenGaussGradient.compare_forward(output, reference) + + +# Validate compare-backward hook contract. 
+def test_mesh_green_gauss_gradient_compare_backward_contract(device: str): + _label, args, kwargs = next( + iter(MeshGreenGaussGradient.make_inputs_backward(device=device)) + ) + values = args[3] + + output = MeshGreenGaussGradient.dispatch(*args, implementation="torch", **kwargs) + output.square().mean().backward() + + assert values.grad is not None + MeshGreenGaussGradient.compare_backward(values.grad, values.grad.detach().clone()) + + +# Validate exported API and input validation paths. +def test_mesh_green_gauss_gradient_error_handling(device: str): + points, cells = _build_case(device=device, nx=16, ny=14) + values = torch.randn(cells.shape[0], device=points.device, dtype=torch.float32) + + neighbors = build_neighbors(cells) + output = mesh_green_gauss_gradient(points, cells, neighbors, values) + assert output.shape[0] == cells.shape[0] + assert output.shape[1] == points.shape[1] + + with pytest.raises(ValueError, match="supports dims in"): + bad_points = torch.randn( + points.shape[0], 4, device=points.device, dtype=torch.float32 + ) + MeshGreenGaussGradient.dispatch( + bad_points, + cells, + neighbors, + values, + implementation="torch", + ) + + with pytest.raises(ValueError, match="must contain 3 vertices"): + bad_cells = torch.randint( + 0, + points.shape[0], + (cells.shape[0], 4), + device=points.device, + dtype=torch.int64, + ) + MeshGreenGaussGradient.dispatch( + points, + bad_cells, + neighbors, + values, + implementation="torch", + ) + + with pytest.raises(ValueError, match="leading dimension must match n_cells"): + MeshGreenGaussGradient.dispatch( + points, + cells, + neighbors, + values[:-1], + implementation="torch", + ) + + with pytest.raises(ValueError, match="neighbors shape must match"): + MeshGreenGaussGradient.dispatch( + points, + cells, + neighbors[:, :-1], + values, + implementation="torch", + ) diff --git a/test/nn/functional/derivatives/test_mesh_lsq_gradient.py b/test/nn/functional/derivatives/test_mesh_lsq_gradient.py new file mode 100644 index 0000000000..7f91225137 --- /dev/null +++ b/test/nn/functional/derivatives/test_mesh_lsq_gradient.py @@ -0,0 +1,375 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from physicsnemo.nn.functional import mesh_lsq_gradient +from physicsnemo.nn.functional.derivatives import MeshLSQGradient +from test.conftest import requires_module +from test.nn.functional._parity_utils import clone_case + + +# Build deterministic KNN-CSR test data on random points. 
+def _make_case(device: str, n_entities: int, n_dims: int, k_neighbors: int): + torch_device = torch.device(device) + generator = torch.Generator(device=torch_device) + generator.manual_seed(1234 + n_entities + n_dims) + + points = torch.rand((n_entities, n_dims), generator=generator, device=torch_device) + dists = torch.cdist(points, points) + knn = torch.topk(dists, k=k_neighbors + 1, largest=False, dim=1).indices[:, 1:] + + offsets = torch.arange( + 0, + n_entities * k_neighbors + 1, + k_neighbors, + device=torch_device, + dtype=torch.int64, + ) + indices = knn.reshape(-1).to(torch.int64) + return points, offsets, indices + + +# Validate torch LSQ reconstruction on an affine scalar field. +@pytest.mark.parametrize("n_dims", [1, 2, 3]) +def test_mesh_lsq_gradient_torch(device: str, n_dims: int): + points, offsets, indices = _make_case( + device, n_entities=1024, n_dims=n_dims, k_neighbors=16 + ) + + coeff = torch.arange(1, n_dims + 1, device=points.device, dtype=torch.float32) + values = (points * coeff).sum(dim=-1) + + output = MeshLSQGradient.dispatch( + points, + values, + offsets, + indices, + implementation="torch", + ) + + expected = coeff.view(1, -1).expand(points.shape[0], -1) + torch.testing.assert_close(output, expected, atol=3e-3, rtol=3e-3) + + +# Validate warp backend parity against torch across benchmark representative inputs. +@requires_module("warp") +def test_mesh_lsq_gradient_backend_forward_parity(device: str): + for _label, args, kwargs in MeshLSQGradient.make_inputs_forward(device=device): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = MeshLSQGradient.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_warp = MeshLSQGradient.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + MeshLSQGradient.compare_forward(out_warp, out_torch) + + +# Validate warp backward parity against torch on differentiable value fields. 
+@requires_module("warp") +def test_mesh_lsq_gradient_backend_backward_parity(device: str): + for _label, args, kwargs in MeshLSQGradient.make_inputs_backward(device=device): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = MeshLSQGradient.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_torch.square().mean().backward() + grad_torch = args_torch[1].grad + assert grad_torch is not None + + out_warp = MeshLSQGradient.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + out_warp.square().mean().backward() + grad_warp = args_warp[1].grad + assert grad_warp is not None + + MeshLSQGradient.compare_backward(grad_warp, grad_torch) + + +@requires_module("warp") +def test_mesh_lsq_gradient_warp_supports_point_gradients(device: str): + points, offsets, indices = _make_case( + device, n_entities=768, n_dims=3, k_neighbors=12 + ) + base_values = ( + torch.sin(2.0 * torch.pi * points[:, 0]) + + 0.4 * torch.cos(2.0 * torch.pi * points[:, 1]) + + 0.2 * points[:, 2].square() + ).to(torch.float32) + + points_warp = points.detach().clone().to(torch.float32).requires_grad_(True) + values_warp = base_values.detach().clone().requires_grad_(True) + out_warp = MeshLSQGradient.dispatch( + points_warp, + values_warp, + offsets, + indices, + implementation="warp", + ) + out_warp.square().mean().backward() + grad_points_warp = points_warp.grad + grad_values_warp = values_warp.grad + assert grad_points_warp is not None + assert grad_values_warp is not None + assert torch.isfinite(grad_points_warp).all() + assert torch.isfinite(grad_values_warp).all() + + +# Validate warp backend on 1D input parity against torch. +@requires_module("warp") +def test_mesh_lsq_gradient_warp(device: str): + points, offsets, indices = _make_case( + device, n_entities=512, n_dims=1, k_neighbors=16 + ) + values = torch.sin(2.0 * torch.pi * points[:, 0]).to(torch.float32) + + out_torch = MeshLSQGradient.dispatch( + points, + values, + offsets, + indices, + implementation="torch", + ) + out_warp = MeshLSQGradient.dispatch( + points, + values, + offsets, + indices, + implementation="warp", + ) + MeshLSQGradient.compare_forward(out_warp, out_torch) + + +# Validate benchmark input generation contract for forward inputs. +def test_mesh_lsq_gradient_make_inputs_forward(device: str): + label, args, kwargs = next(iter(MeshLSQGradient.make_inputs_forward(device=device))) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + points, values, offsets, indices = args + assert points.ndim == 2 + assert values.shape[0] == points.shape[0] + assert offsets.ndim == 1 + assert indices.ndim == 1 + + output = MeshLSQGradient.dispatch( + *args, + implementation="torch", + **kwargs, + ) + assert output.shape[0] == points.shape[0] + assert output.shape[1] == points.shape[1] + + +# Validate benchmark input generation contract for backward inputs. +def test_mesh_lsq_gradient_make_inputs_backward(device: str): + label, args, kwargs = next( + iter(MeshLSQGradient.make_inputs_backward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + values = args[1] + assert values.requires_grad + + output = MeshLSQGradient.dispatch( + *args, + implementation="torch", + **kwargs, + ) + output.square().mean().backward() + assert values.grad is not None + + +# Validate compare-forward hook contract. 
+def test_mesh_lsq_gradient_compare_forward_contract(device: str): + _label, args, kwargs = next( + iter(MeshLSQGradient.make_inputs_forward(device=device)) + ) + output = MeshLSQGradient.dispatch(*args, implementation="torch", **kwargs) + reference = output.detach().clone() + MeshLSQGradient.compare_forward(output, reference) + + +# Validate compare-backward hook contract. +def test_mesh_lsq_gradient_compare_backward_contract(device: str): + _label, args, kwargs = next( + iter(MeshLSQGradient.make_inputs_backward(device=device)) + ) + values = args[1] + + output = MeshLSQGradient.dispatch(*args, implementation="torch", **kwargs) + output.square().mean().backward() + + assert values.grad is not None + MeshLSQGradient.compare_backward(values.grad, values.grad.detach().clone()) + + +# Validate exported API and input validation paths. +def test_mesh_lsq_gradient_error_handling(device: str): + points, offsets, indices = _make_case( + device, n_entities=128, n_dims=3, k_neighbors=8 + ) + values = torch.sin(points[:, 0]) + + output = mesh_lsq_gradient(points, values, offsets, indices) + assert output.shape == (points.shape[0], points.shape[1]) + assert output.dtype == torch.float32 + + with pytest.raises(ValueError, match=r"must have shape \(n_entities \+ 1,\)"): + MeshLSQGradient.dispatch( + points, + values, + offsets[:-1], + indices, + implementation="torch", + ) + + with pytest.raises(ValueError, match=r"must equal len\(neighbor_indices\)"): + bad_offsets = offsets.clone() + bad_offsets[-1] = bad_offsets[-1] - 1 + MeshLSQGradient.dispatch( + points, + values, + bad_offsets, + indices, + implementation="torch", + ) + + with pytest.raises(ValueError, match="neighbor_offsets must be non-decreasing"): + bad_offsets = offsets.clone() + mid = bad_offsets.shape[0] // 2 + bad_offsets[mid] = bad_offsets[mid - 1] - 1 + MeshLSQGradient.dispatch( + points, + values, + bad_offsets, + indices, + implementation="torch", + ) + + with pytest.raises(ValueError, match="values leading dimension must match points"): + MeshLSQGradient.dispatch( + points, + values[:-1], + offsets, + indices, + implementation="torch", + ) + + with pytest.raises(TypeError, match="neighbor_offsets must be int32 or int64"): + MeshLSQGradient.dispatch( + points, + values, + offsets.to(torch.float32), + indices, + implementation="torch", + ) + + with pytest.raises(ValueError, match="must satisfy 0 <= index < n_entities"): + bad_indices = indices.clone() + bad_indices[0] = points.shape[0] + MeshLSQGradient.dispatch( + points, + values, + offsets, + bad_indices, + implementation="torch", + ) + + with pytest.raises( + ValueError, match="safe_epsilon must be a finite positive value" + ): + MeshLSQGradient.dispatch( + points, + values, + offsets, + indices, + safe_epsilon=0.0, + implementation="torch", + ) + + if torch.cuda.is_available(): + other_device = torch.device("cuda" if points.device.type == "cpu" else "cpu") + with pytest.raises(ValueError, match="must be on the same device"): + MeshLSQGradient.dispatch( + points, + values, + offsets.to(other_device), + indices, + implementation="torch", + ) + + +# Validate warp backend input validation paths mirror torch behavior. 
+@requires_module("warp") +def test_mesh_lsq_gradient_error_handling_warp(device: str): + points, offsets, indices = _make_case( + device, n_entities=128, n_dims=3, k_neighbors=8 + ) + values = torch.sin(points[:, 0]) + + with pytest.raises(ValueError, match="neighbor_offsets must be non-decreasing"): + bad_offsets = offsets.clone() + mid = bad_offsets.shape[0] // 2 + bad_offsets[mid] = bad_offsets[mid - 1] - 1 + MeshLSQGradient.dispatch( + points, + values, + bad_offsets, + indices, + implementation="warp", + ) + + with pytest.raises( + ValueError, match="safe_epsilon must be a finite positive value" + ): + MeshLSQGradient.dispatch( + points, + values, + offsets, + indices, + safe_epsilon=-1.0, + implementation="warp", + ) + + if torch.cuda.is_available(): + other_device = torch.device("cuda" if points.device.type == "cpu" else "cpu") + with pytest.raises(ValueError, match="must be on the same device"): + MeshLSQGradient.dispatch( + points, + values, + offsets.to(other_device), + indices, + implementation="warp", + ) diff --git a/test/nn/functional/derivatives/test_meshless_finite_difference.py b/test/nn/functional/derivatives/test_meshless_finite_difference.py new file mode 100644 index 0000000000..5f57d50bb5 --- /dev/null +++ b/test/nn/functional/derivatives/test_meshless_finite_difference.py @@ -0,0 +1,385 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
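For reviewer reference only (not part of the patch): a minimal sketch of the meshless finite-difference workflow exercised by the test file below, using the same imports as the tests. The point count, spacing, and toy field are illustrative assumptions.

```python
import torch

from physicsnemo.nn.functional import meshless_fd_derivatives
from physicsnemo.nn.functional.derivatives.meshless_finite_difference._torch_impl import (
    meshless_fd_stencil_points_torch,
)

# Scattered 2D query points and per-axis spacing (illustrative).
points = torch.rand(64, 2)
spacing = (0.01, 0.015)

# One 3x3 stencil per point: stencil_points has shape (64, 9, 2).
stencil_points = meshless_fd_stencil_points_torch(points, spacing=spacing)

# Evaluate the field of interest at every stencil location (toy scalar field).
values = torch.sin(stencil_points[..., 0]) * torch.cos(stencil_points[..., 1])

# First derivatives per axis for a single-channel field: shape (2, 64).
derivs = meshless_fd_derivatives(values, spacing=spacing, derivative_orders=1)
```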
+ +from __future__ import annotations + +import pytest +import torch + +from physicsnemo.nn.functional import meshless_fd_derivatives +from physicsnemo.nn.functional.derivatives import MeshlessFDDerivatives +from physicsnemo.nn.functional.derivatives.meshless_finite_difference._torch_impl import ( + meshless_fd_stencil_points_torch, +) + + +def _spacing_for_dim(dim: int) -> float | tuple[float, ...]: + """Return representative spacing values by dimensionality.""" + if dim == 1: + return 0.01 + if dim == 2: + return (0.01, 0.015) + return (0.01, 0.015, 0.02) + + +def _analytic_values(points: torch.Tensor) -> torch.Tensor: + """Evaluate a smooth two-channel scalar field at arbitrary points.""" + x = points[..., 0] + if points.shape[-1] == 1: + ch0 = torch.sin(2.0 * x) + 0.3 * x.square() + ch1 = torch.cos(1.5 * x) - 0.1 * x + return torch.stack((ch0, ch1), dim=-1) + + y = points[..., 1] + if points.shape[-1] == 2: + ch0 = torch.sin(1.4 * x) * torch.cos(0.7 * y) + 0.2 * x * y + ch1 = x.square() + y.pow(3) + return torch.stack((ch0, ch1), dim=-1) + + z = points[..., 2] + ch0 = torch.sin(1.2 * x) * torch.cos(0.8 * y) * torch.sin(0.6 * z) + 0.1 * x * y * z + ch1 = x.square() + 0.5 * y.square() - z + return torch.stack((ch0, ch1), dim=-1) + + +def _analytic_first_derivatives(points: torch.Tensor) -> torch.Tensor: + """Evaluate analytic first derivatives for the two-channel test field.""" + x = points[:, 0] + if points.shape[1] == 1: + dx = torch.stack( + (2.0 * torch.cos(2.0 * x) + 0.6 * x, -1.5 * torch.sin(1.5 * x) - 0.1), + dim=-1, + ) + return dx.unsqueeze(0) + + y = points[:, 1] + if points.shape[1] == 2: + dfdx = torch.stack( + ( + 1.4 * torch.cos(1.4 * x) * torch.cos(0.7 * y) + 0.2 * y, + 2.0 * x, + ), + dim=-1, + ) + dfdy = torch.stack( + ( + -0.7 * torch.sin(1.4 * x) * torch.sin(0.7 * y) + 0.2 * x, + 3.0 * y.square(), + ), + dim=-1, + ) + return torch.stack((dfdx, dfdy), dim=0) + + z = points[:, 2] + dfdx = torch.stack( + ( + 1.2 * torch.cos(1.2 * x) * torch.cos(0.8 * y) * torch.sin(0.6 * z) + + 0.1 * y * z, + 2.0 * x, + ), + dim=-1, + ) + dfdy = torch.stack( + ( + -0.8 * torch.sin(1.2 * x) * torch.sin(0.8 * y) * torch.sin(0.6 * z) + + 0.1 * x * z, + y, + ), + dim=-1, + ) + dfdz = torch.stack( + ( + 0.6 * torch.sin(1.2 * x) * torch.cos(0.8 * y) * torch.cos(0.6 * z) + + 0.1 * x * y, + -torch.ones_like(x), + ), + dim=-1, + ) + return torch.stack((dfdx, dfdy, dfdz), dim=0) + + +def _analytic_second_derivatives(points: torch.Tensor) -> torch.Tensor: + """Evaluate analytic second derivatives (pure then mixed axis pairs).""" + x = points[:, 0] + if points.shape[1] == 1: + d2xx = torch.stack( + ( + -4.0 * torch.sin(2.0 * x) + 0.6, + -2.25 * torch.cos(1.5 * x), + ), + dim=-1, + ) + return d2xx.unsqueeze(0) + + y = points[:, 1] + if points.shape[1] == 2: + d2xx = torch.stack( + ( + -1.96 * torch.sin(1.4 * x) * torch.cos(0.7 * y), + 2.0 * torch.ones_like(x), + ), + dim=-1, + ) + d2yy = torch.stack( + ( + -0.49 * torch.sin(1.4 * x) * torch.cos(0.7 * y), + 6.0 * y, + ), + dim=-1, + ) + d2xy = torch.stack( + ( + -0.98 * torch.cos(1.4 * x) * torch.sin(0.7 * y) + 0.2, + torch.zeros_like(x), + ), + dim=-1, + ) + return torch.stack((d2xx, d2yy, d2xy), dim=0) + + z = points[:, 2] + d2xx = torch.stack( + ( + -1.44 * torch.sin(1.2 * x) * torch.cos(0.8 * y) * torch.sin(0.6 * z), + 2.0 * torch.ones_like(x), + ), + dim=-1, + ) + d2yy = torch.stack( + ( + -0.64 * torch.sin(1.2 * x) * torch.cos(0.8 * y) * torch.sin(0.6 * z), + torch.ones_like(x), + ), + dim=-1, + ) + d2zz = torch.stack( + ( + -0.36 * torch.sin(1.2 
* x) * torch.cos(0.8 * y) * torch.sin(0.6 * z), + torch.zeros_like(x), + ), + dim=-1, + ) + d2xy = torch.stack( + ( + -0.96 * torch.cos(1.2 * x) * torch.sin(0.8 * y) * torch.sin(0.6 * z) + + 0.1 * z, + torch.zeros_like(x), + ), + dim=-1, + ) + d2xz = torch.stack( + ( + 0.72 * torch.cos(1.2 * x) * torch.cos(0.8 * y) * torch.cos(0.6 * z) + + 0.1 * y, + torch.zeros_like(x), + ), + dim=-1, + ) + d2yz = torch.stack( + ( + -0.48 * torch.sin(1.2 * x) * torch.sin(0.8 * y) * torch.cos(0.6 * z) + + 0.1 * x, + torch.zeros_like(x), + ), + dim=-1, + ) + return torch.stack((d2xx, d2yy, d2zz, d2xy, d2xz, d2yz), dim=0) + + +# Validate stencil-point generation for representative dimensions. +@pytest.mark.parametrize("dim", [1, 2, 3]) +def test_meshless_fd_derivatives_stencil_points_torch(device: str, dim: int): + points = torch.rand(32, dim, device=device, dtype=torch.float32) + spacing = _spacing_for_dim(dim) + + stencil_points = meshless_fd_stencil_points_torch( + points, + spacing=spacing, + include_center=True, + ) + + assert stencil_points.shape == (32, 3**dim, dim) + center_index = (3**dim) // 2 + torch.testing.assert_close(stencil_points[:, center_index], points) + + +# Validate meshless first derivatives against analytic derivatives. +@pytest.mark.parametrize("dim", [1, 2, 3]) +def test_meshless_fd_derivatives_torch_first_order(device: str, dim: int): + points = torch.rand(128, dim, device=device, dtype=torch.float32) + spacing = _spacing_for_dim(dim) + stencil_points = meshless_fd_stencil_points_torch(points, spacing=spacing) + stencil_values = _analytic_values(stencil_points) + expected = _analytic_first_derivatives(points) + + output = MeshlessFDDerivatives.dispatch( + stencil_values, + spacing=spacing, + derivative_orders=1, + include_mixed=False, + implementation="torch", + ) + torch.testing.assert_close(output, expected, atol=5e-3, rtol=5e-3) + + +# Validate meshless second derivatives, including mixed terms, against analytics. +@pytest.mark.parametrize("dim", [1, 2, 3]) +def test_meshless_fd_derivatives_torch_second_order(device: str, dim: int): + points = torch.rand(128, dim, device=device, dtype=torch.float32) + spacing = _spacing_for_dim(dim) + stencil_points = meshless_fd_stencil_points_torch(points, spacing=spacing) + stencil_values = _analytic_values(stencil_points) + expected = _analytic_second_derivatives(points) + + output = MeshlessFDDerivatives.dispatch( + stencil_values, + spacing=spacing, + derivative_orders=2, + include_mixed=(dim > 1), + implementation="torch", + ) + torch.testing.assert_close(output, expected, atol=7e-3, rtol=7e-3) + + +# Validate unified derivative-order requests concatenate outputs deterministically. +@pytest.mark.parametrize("dim", [1, 2, 3]) +def test_meshless_fd_derivatives_torch_combined_orders(device: str, dim: int): + points = torch.rand(128, dim, device=device, dtype=torch.float32) + spacing = _spacing_for_dim(dim) + stencil_points = meshless_fd_stencil_points_torch(points, spacing=spacing) + stencil_values = _analytic_values(stencil_points) + expected_first = _analytic_first_derivatives(points) + expected_second = _analytic_second_derivatives(points) + + output = MeshlessFDDerivatives.dispatch( + stencil_values, + spacing=spacing, + derivative_orders=(1, 2), + include_mixed=(dim > 1), + implementation="torch", + ) + expected = torch.cat((expected_first, expected_second), dim=0) + torch.testing.assert_close(output, expected, atol=7e-3, rtol=7e-3) + + +# Validate benchmark input generation contract for derivative forward inputs. 
+def test_meshless_fd_derivatives_make_inputs_forward(device: str): + label, args, kwargs = next( + iter(MeshlessFDDerivatives.make_inputs_forward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + output = MeshlessFDDerivatives.dispatch(*args, implementation="torch", **kwargs) + assert output.ndim in (2, 3) + + +# Validate benchmark input generation contract for derivative backward inputs. +def test_meshless_fd_derivatives_make_inputs_backward(device: str): + label, args, kwargs = next( + iter(MeshlessFDDerivatives.make_inputs_backward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + stencil_values = args[0] + assert stencil_values.requires_grad + + output = MeshlessFDDerivatives.dispatch(*args, implementation="torch", **kwargs) + output.square().mean().backward() + assert stencil_values.grad is not None + + +# Validate exported API and error handling branches for meshless FD functionals. +def test_meshless_fd_derivatives_error_handling(device: str): + points = torch.rand(16, 2, device=device, dtype=torch.float32) + stencil_points = meshless_fd_stencil_points_torch(points, spacing=(0.01, 0.02)) + assert stencil_points.shape == (16, 9, 2) + + values = _analytic_values(stencil_points) + derivs = meshless_fd_derivatives(values, spacing=(0.01, 0.02), derivative_orders=1) + assert derivs.shape == (2, 16, 2) + + with pytest.raises(ValueError, match="shape"): + meshless_fd_stencil_points_torch( + torch.rand(16, device=device, dtype=torch.float32) + ) + + with pytest.raises(TypeError, match="floating-point"): + meshless_fd_stencil_points_torch( + torch.ones(16, 2, device=device, dtype=torch.int32) + ) + + with pytest.raises(ValueError, match="must have 2 entries"): + meshless_fd_stencil_points_torch( + torch.rand(16, 2, device=device, dtype=torch.float32), + spacing=(0.1,), + ) + + with pytest.raises(ValueError, match="strictly positive"): + meshless_fd_stencil_points_torch( + torch.rand(16, 2, device=device, dtype=torch.float32), + spacing=(0.1, 0.0), + ) + + with pytest.raises(ValueError, match="must have shape"): + MeshlessFDDerivatives.dispatch( + torch.rand(16, device=device, dtype=torch.float32), + implementation="torch", + ) + + with pytest.raises(TypeError, match="floating-point"): + MeshlessFDDerivatives.dispatch( + torch.ones(16, 9, device=device, dtype=torch.int32), + implementation="torch", + ) + + with pytest.raises(ValueError, match="must be 3, 9, or 27"): + MeshlessFDDerivatives.dispatch( + torch.rand(16, 5, device=device, dtype=torch.float32), + implementation="torch", + ) + + with pytest.raises(ValueError, match="supports derivative orders"): + MeshlessFDDerivatives.dispatch( + torch.rand(16, 9, device=device, dtype=torch.float32), + derivative_orders=3, + implementation="torch", + ) + + with pytest.raises(ValueError, match="only valid when requesting 2nd derivatives"): + MeshlessFDDerivatives.dispatch( + torch.rand(16, 9, device=device, dtype=torch.float32), + derivative_orders=1, + include_mixed=True, + implementation="torch", + ) + + with pytest.raises(ValueError, match="must have 2 entries"): + MeshlessFDDerivatives.dispatch( + torch.rand(16, 9, device=device, dtype=torch.float32), + spacing=(0.1,), + implementation="torch", + ) + + with pytest.raises(ValueError, match="mixed derivatives require at least 2D"): + MeshlessFDDerivatives.dispatch( + torch.rand(16, 3, device=device, dtype=torch.float32), + derivative_orders=2, + 
include_mixed=True, + implementation="torch", + ) diff --git a/test/nn/functional/derivatives/test_rectilinear_grid_gradient.py b/test/nn/functional/derivatives/test_rectilinear_grid_gradient.py new file mode 100644 index 0000000000..5b4be10d63 --- /dev/null +++ b/test/nn/functional/derivatives/test_rectilinear_grid_gradient.py @@ -0,0 +1,475 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from physicsnemo.nn.functional import rectilinear_grid_gradient +from physicsnemo.nn.functional.derivatives import RectilinearGridGradient +from test.conftest import requires_module +from test.nn.functional._parity_utils import clone_case + + +# Build analytic periodic fields on nonuniform rectilinear coordinates. +def _make_periodic_case(device: str, dims: int, derivative_order: int): + torch_device = torch.device(device) + amp0 = 0.04 if derivative_order == 1 else 0.0 + amp1 = 0.03 if derivative_order == 1 else 0.0 + amp2 = 0.02 if derivative_order == 1 else 0.0 + + if dims == 1: + n0 = 1024 + s0 = torch.linspace(0.0, 1.0, n0 + 1, device=torch_device)[:-1] + x0 = s0 + amp0 * torch.sin(2.0 * torch.pi * s0) + field = torch.sin(2.0 * torch.pi * x0) + if derivative_order == 1: + expected = (2.0 * torch.pi) * torch.cos(2.0 * torch.pi * x0).unsqueeze(0) + else: + expected = ( + -((2.0 * torch.pi) ** 2) * torch.sin(2.0 * torch.pi * x0) + ).unsqueeze(0) + return field, (x0.to(torch.float32),), 1.0, expected + + if dims == 2: + n0, n1 = 320, 256 + s0 = torch.linspace(0.0, 1.0, n0 + 1, device=torch_device)[:-1] + s1 = torch.linspace(0.0, 1.0, n1 + 1, device=torch_device)[:-1] + x0 = s0 + amp0 * torch.sin(2.0 * torch.pi * s0) + x1 = s1 + amp1 * torch.sin(2.0 * torch.pi * s1) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + field = torch.sin(2.0 * torch.pi * xx) + 0.5 * torch.cos(2.0 * torch.pi * yy) + if derivative_order == 1: + deriv_x = (2.0 * torch.pi) * torch.cos(2.0 * torch.pi * xx) + deriv_y = -1.0 * torch.pi * torch.sin(2.0 * torch.pi * yy) + else: + deriv_x = -((2.0 * torch.pi) ** 2) * torch.sin(2.0 * torch.pi * xx) + deriv_y = -2.0 * (torch.pi**2) * torch.cos(2.0 * torch.pi * yy) + expected = torch.stack((deriv_x, deriv_y), dim=0) + return field, (x0.to(torch.float32), x1.to(torch.float32)), (1.0, 1.0), expected + + n0, n1, n2 = 120, 96, 80 + s0 = torch.linspace(0.0, 1.0, n0 + 1, device=torch_device)[:-1] + s1 = torch.linspace(0.0, 1.0, n1 + 1, device=torch_device)[:-1] + s2 = torch.linspace(0.0, 1.0, n2 + 1, device=torch_device)[:-1] + x0 = s0 + amp0 * torch.sin(2.0 * torch.pi * s0) + x1 = s1 + amp1 * torch.sin(2.0 * torch.pi * s1) + x2 = s2 + amp2 * torch.sin(2.0 * torch.pi * s2) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + field = ( + torch.sin(2.0 * torch.pi * xx) + + 0.5 * torch.cos(2.0 * torch.pi * yy) + + 0.25 * torch.sin(2.0 * torch.pi * zz) + ) + if derivative_order == 1: + 
deriv_x = (2.0 * torch.pi) * torch.cos(2.0 * torch.pi * xx) + deriv_y = -1.0 * torch.pi * torch.sin(2.0 * torch.pi * yy) + deriv_z = 0.5 * torch.pi * torch.cos(2.0 * torch.pi * zz) + else: + deriv_x = -((2.0 * torch.pi) ** 2) * torch.sin(2.0 * torch.pi * xx) + deriv_y = -2.0 * (torch.pi**2) * torch.cos(2.0 * torch.pi * yy) + deriv_z = -(torch.pi**2) * torch.sin(2.0 * torch.pi * zz) + expected = torch.stack((deriv_x, deriv_y, deriv_z), dim=0) + return ( + field, + (x0.to(torch.float32), x1.to(torch.float32), x2.to(torch.float32)), + (1.0, 1.0, 1.0), + expected, + ) + + +# Validate torch backend against analytic periodic derivatives. +@pytest.mark.parametrize("dims", [1, 2, 3]) +@pytest.mark.parametrize("derivative_order", [1, 2]) +def test_rectilinear_grid_gradient_torch(device: str, dims: int, derivative_order: int): + field, coordinates, periods, expected = _make_periodic_case( + device, dims, derivative_order + ) + output = RectilinearGridGradient.dispatch( + field.to(torch.float32), + coordinates, + periods=periods, + derivative_orders=derivative_order, + implementation="torch", + ) + atol, rtol = (6e-1, 1e-1) if derivative_order == 2 and dims == 1 else (3e-2, 3e-2) + torch.testing.assert_close(output, expected, atol=atol, rtol=rtol) + + +# Validate unified derivative-order requests concatenate outputs deterministically. +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_rectilinear_grid_gradient_torch_combined_orders(device: str, dims: int): + field, coordinates, periods, _expected_first = _make_periodic_case( + device, dims, derivative_order=1 + ) + first_only = RectilinearGridGradient.dispatch( + field.to(torch.float32), + coordinates, + periods=periods, + derivative_orders=1, + include_mixed=False, + implementation="torch", + ) + second_only = RectilinearGridGradient.dispatch( + field.to(torch.float32), + coordinates, + periods=periods, + derivative_orders=2, + include_mixed=False, + implementation="torch", + ) + output = RectilinearGridGradient.dispatch( + field.to(torch.float32), + coordinates, + periods=periods, + derivative_orders=(1, 2), + include_mixed=False, + implementation="torch", + ) + expected = torch.cat((first_only, second_only), dim=0) + torch.testing.assert_close(output, expected, atol=3e-2, rtol=3e-2) + + +# Validate mixed second derivatives are available through unified API. +@pytest.mark.parametrize("dims", [2, 3]) +def test_rectilinear_grid_gradient_torch_second_order_mixed(device: str, dims: int): + field, coordinates, periods, _ = _make_periodic_case( + device, dims, derivative_order=2 + ) + output = RectilinearGridGradient.dispatch( + field.to(torch.float32), + coordinates, + periods=periods, + derivative_orders=2, + include_mixed=True, + implementation="torch", + ) + expected_count = dims + (dims * (dims - 1)) // 2 + assert output.shape[0] == expected_count + + +# Validate warp backend against analytic periodic derivatives. 
+@requires_module("warp") +@pytest.mark.parametrize("dims", [1, 2, 3]) +@pytest.mark.parametrize("derivative_order", [1, 2]) +def test_rectilinear_grid_gradient_warp(device: str, dims: int, derivative_order: int): + field, coordinates, periods, expected = _make_periodic_case( + device, dims, derivative_order + ) + output = RectilinearGridGradient.dispatch( + field.to(torch.float32), + coordinates, + periods=periods, + derivative_orders=derivative_order, + implementation="warp", + ) + atol, rtol = (6e-1, 1e-1) if derivative_order == 2 and dims == 1 else (4e-2, 4e-2) + torch.testing.assert_close(output, expected, atol=atol, rtol=rtol) + + +# Validate warp backend forward parity against torch across benchmark cases. +@requires_module("warp") +def test_rectilinear_grid_gradient_backend_forward_parity(device: str): + for _label, args, kwargs in RectilinearGridGradient.make_inputs_forward( + device=device + ): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = RectilinearGridGradient.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_warp = RectilinearGridGradient.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + RectilinearGridGradient.compare_forward(out_warp, out_torch) + + +# Validate warp fused first+second path parity against torch. +@requires_module("warp") +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_rectilinear_grid_gradient_backend_forward_combined_orders( + device: str, dims: int +): + field, coordinates, periods, _ = _make_periodic_case( + device, dims, derivative_order=1 + ) + out_torch = RectilinearGridGradient.dispatch( + field.to(torch.float32), + coordinates, + periods=periods, + derivative_orders=(1, 2), + include_mixed=False, + implementation="torch", + ) + out_warp = RectilinearGridGradient.dispatch( + field.to(torch.float32), + coordinates, + periods=periods, + derivative_orders=(1, 2), + include_mixed=False, + implementation="warp", + ) + RectilinearGridGradient.compare_forward(out_warp, out_torch) + + +# Validate warp backend backward parity against torch. +@requires_module("warp") +def test_rectilinear_grid_gradient_backend_backward_parity(device: str): + for _label, args, kwargs in RectilinearGridGradient.make_inputs_backward( + device=device + ): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = RectilinearGridGradient.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + grad_seed = torch.randn_like(out_torch) + grad_torch = torch.autograd.grad( + outputs=out_torch, + inputs=args_torch[0], + grad_outputs=grad_seed, + create_graph=False, + retain_graph=False, + allow_unused=False, + )[0] + assert grad_torch is not None + + out_warp = RectilinearGridGradient.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + grad_warp = torch.autograd.grad( + outputs=out_warp, + inputs=args_warp[0], + grad_outputs=grad_seed, + create_graph=False, + retain_graph=False, + allow_unused=False, + )[0] + assert grad_warp is not None + + RectilinearGridGradient.compare_backward(grad_warp, grad_torch) + + +# Validate warp backend backward parity for fused combined-order requests. 
+@requires_module("warp") +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_rectilinear_grid_gradient_backend_backward_combined_orders( + device: str, dims: int +): + field, coordinates, periods, _ = _make_periodic_case( + device, dims, derivative_order=1 + ) + + field_torch = field.to(torch.float32).detach().clone().requires_grad_(True) + field_warp = field.to(torch.float32).detach().clone().requires_grad_(True) + + out_torch = RectilinearGridGradient.dispatch( + field_torch, + coordinates, + periods=periods, + derivative_orders=(1, 2), + include_mixed=False, + implementation="torch", + ) + out_warp = RectilinearGridGradient.dispatch( + field_warp, + coordinates, + periods=periods, + derivative_orders=(1, 2), + include_mixed=False, + implementation="warp", + ) + + grad_seed = torch.randn_like(out_torch) + grad_torch = torch.autograd.grad( + outputs=out_torch, + inputs=field_torch, + grad_outputs=grad_seed, + create_graph=False, + retain_graph=False, + allow_unused=False, + )[0] + grad_warp = torch.autograd.grad( + outputs=out_warp, + inputs=field_warp, + grad_outputs=grad_seed, + create_graph=False, + retain_graph=False, + allow_unused=False, + )[0] + + assert grad_torch is not None + assert grad_warp is not None + torch.testing.assert_close(grad_warp, grad_torch, atol=7e-2, rtol=7e-2) + + +# Validate benchmark input generation contract for forward inputs. +def test_rectilinear_grid_gradient_make_inputs_forward(device: str): + label, args, kwargs = next( + iter(RectilinearGridGradient.make_inputs_forward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + field, coordinates = args + assert field.ndim in (1, 2, 3) + assert len(coordinates) == field.ndim + + output = RectilinearGridGradient.dispatch( + *args, + implementation="torch", + **kwargs, + ) + assert output.shape[0] == field.ndim + + +# Validate benchmark input generation contract for backward inputs. +def test_rectilinear_grid_gradient_make_inputs_backward(device: str): + label, args, kwargs = next( + iter(RectilinearGridGradient.make_inputs_backward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + field = args[0] + assert field.requires_grad + + output = RectilinearGridGradient.dispatch( + *args, + implementation="torch", + **kwargs, + ) + output.square().mean().backward() + assert field.grad is not None + + +# Validate compare-forward hook contract. +def test_rectilinear_grid_gradient_compare_forward_contract(device: str): + _label, args, kwargs = next( + iter(RectilinearGridGradient.make_inputs_forward(device=device)) + ) + output = RectilinearGridGradient.dispatch(*args, implementation="torch", **kwargs) + reference = output.detach().clone() + RectilinearGridGradient.compare_forward(output, reference) + + +# Validate compare-backward hook contract. +def test_rectilinear_grid_gradient_compare_backward_contract(device: str): + _label, args, kwargs = next( + iter(RectilinearGridGradient.make_inputs_backward(device=device)) + ) + field = args[0] + + output = RectilinearGridGradient.dispatch(*args, implementation="torch", **kwargs) + output.square().mean().backward() + + assert field.grad is not None + RectilinearGridGradient.compare_backward(field.grad, field.grad.detach().clone()) + + +# Validate exported API and input validation paths. 
+def test_rectilinear_grid_gradient_error_handling(device: str): + x = torch.linspace(0.0, 1.0, 17, device=device)[:-1] + field = torch.sin(2.0 * torch.pi * x).to(torch.float32) + + output = rectilinear_grid_gradient(field, (x.to(torch.float32),), periods=1.0) + assert output.shape == (1, 16) + + with pytest.raises(ValueError, match="supports 1D-3D fields"): + RectilinearGridGradient.dispatch( + torch.randn(2, 2, 2, 2, device=device, dtype=torch.float32), + (x, x, x, x), + periods=1.0, + implementation="torch", + ) + + with pytest.raises(ValueError, match="must contain one axis tensor"): + RectilinearGridGradient.dispatch( + torch.randn(32, 32, device=device, dtype=torch.float32), + (torch.linspace(0.0, 1.0, 32, device=device),), + periods=1.0, + implementation="torch", + ) + + with pytest.raises(ValueError, match="strictly increasing"): + bad_x = torch.tensor([0.0, 0.3, 0.2, 0.8], device=device, dtype=torch.float32) + bad_f = torch.randn(4, device=device, dtype=torch.float32) + RectilinearGridGradient.dispatch( + bad_f, + (bad_x,), + periods=1.0, + implementation="torch", + ) + + with pytest.raises(ValueError, match="must be larger than coordinate span"): + RectilinearGridGradient.dispatch( + torch.randn(16, device=device, dtype=torch.float32), + (torch.linspace(0.0, 1.0, 16, device=device, dtype=torch.float32),), + periods=0.8, + implementation="torch", + ) + + with pytest.raises(ValueError, match="supports derivative orders"): + RectilinearGridGradient.dispatch( + torch.randn(16, device=device, dtype=torch.float32), + (torch.linspace(0.0, 1.0, 16, device=device, dtype=torch.float32),), + periods=1.0, + derivative_orders=3, + implementation="torch", + ) + + with pytest.raises(TypeError, match="include_mixed must be a bool"): + RectilinearGridGradient.dispatch( + torch.randn(16, device=device, dtype=torch.float32), + (torch.linspace(0.0, 1.0, 16, device=device, dtype=torch.float32),), + periods=1.0, + derivative_orders=2, + include_mixed=1, # type: ignore[arg-type] + implementation="torch", + ) + + with pytest.raises(ValueError, match="only valid when requesting 2nd derivatives"): + RectilinearGridGradient.dispatch( + torch.randn(16, device=device, dtype=torch.float32), + (torch.linspace(0.0, 1.0, 16, device=device, dtype=torch.float32),), + periods=1.0, + derivative_orders=1, + include_mixed=True, + implementation="torch", + ) + + with pytest.raises( + ValueError, match="mixed derivatives require at least 2D inputs" + ): + RectilinearGridGradient.dispatch( + torch.randn(16, device=device, dtype=torch.float32), + (torch.linspace(0.0, 1.0, 16, device=device, dtype=torch.float32),), + periods=1.0, + derivative_orders=2, + include_mixed=True, + implementation="torch", + ) diff --git a/test/nn/functional/derivatives/test_spectral_grid_gradient.py b/test/nn/functional/derivatives/test_spectral_grid_gradient.py new file mode 100644 index 0000000000..54df34d0bd --- /dev/null +++ b/test/nn/functional/derivatives/test_spectral_grid_gradient.py @@ -0,0 +1,244 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pytest +import torch + +from physicsnemo.nn.functional import spectral_grid_gradient +from physicsnemo.nn.functional.derivatives import SpectralGridGradient + + +def _make_periodic_test_case( + device: str, + dim: int, +) -> tuple[torch.Tensor, tuple[float, ...], torch.Tensor, torch.Tensor]: + """Build periodic fields with analytic first and second derivatives.""" + torch_device = torch.device(device) + + if dim == 1: + n0 = 512 + l0 = 2.0 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float64) * (l0 / n0) + k0 = 2.0 * torch.pi / l0 + + field = torch.sin(k0 * x0) + 0.25 * torch.cos(2.0 * k0 * x0) + first = torch.stack( + [ + k0 * torch.cos(k0 * x0) - 0.5 * k0 * torch.sin(2.0 * k0 * x0), + ], + dim=0, + ) + second = torch.stack( + [ + -(k0 * k0) * torch.sin(k0 * x0) - (k0 * k0) * torch.cos(2.0 * k0 * x0), + ], + dim=0, + ) + return field, (l0,), first, second + + if dim == 2: + n0, n1 = 256, 224 + l0, l1 = 2.0, 1.5 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float64) * (l0 / n0) + x1 = torch.arange(n1, device=torch_device, dtype=torch.float64) * (l1 / n1) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + k0 = 2.0 * torch.pi / l0 + k1 = 2.0 * torch.pi / l1 + + ax = k0 * xx + 0.3 + by = k1 * yy - 0.2 + field = torch.sin(ax) * torch.cos(by) + + dfdx = k0 * torch.cos(ax) * torch.cos(by) + dfdy = -k1 * torch.sin(ax) * torch.sin(by) + d2fdx2 = -(k0 * k0) * torch.sin(ax) * torch.cos(by) + d2fdy2 = -(k1 * k1) * torch.sin(ax) * torch.cos(by) + d2fdxdy = -k0 * k1 * torch.cos(ax) * torch.sin(by) + + first = torch.stack((dfdx, dfdy), dim=0) + second = torch.stack((d2fdx2, d2fdy2, d2fdxdy), dim=0) + return field, (l0, l1), first, second + + n0, n1, n2 = 128, 112, 96 + l0, l1, l2 = 2.0, 1.5, 1.25 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float64) * (l0 / n0) + x1 = torch.arange(n1, device=torch_device, dtype=torch.float64) * (l1 / n1) + x2 = torch.arange(n2, device=torch_device, dtype=torch.float64) * (l2 / n2) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + k0 = 2.0 * torch.pi / l0 + k1 = 2.0 * torch.pi / l1 + k2 = 2.0 * torch.pi / l2 + + ax = k0 * xx + 0.2 + by = k1 * yy - 0.4 + cz = k2 * zz + 0.1 + field = torch.sin(ax) * torch.cos(by) * torch.sin(cz) + + dfdx = k0 * torch.cos(ax) * torch.cos(by) * torch.sin(cz) + dfdy = -k1 * torch.sin(ax) * torch.sin(by) * torch.sin(cz) + dfdz = k2 * torch.sin(ax) * torch.cos(by) * torch.cos(cz) + + d2fdx2 = -(k0 * k0) * torch.sin(ax) * torch.cos(by) * torch.sin(cz) + d2fdy2 = -(k1 * k1) * torch.sin(ax) * torch.cos(by) * torch.sin(cz) + d2fdz2 = -(k2 * k2) * torch.sin(ax) * torch.cos(by) * torch.sin(cz) + d2fdxdy = -k0 * k1 * torch.cos(ax) * torch.sin(by) * torch.sin(cz) + d2fdxdz = k0 * k2 * torch.cos(ax) * torch.cos(by) * torch.cos(cz) + d2fdydz = -k1 * k2 * torch.sin(ax) * torch.sin(by) * torch.cos(cz) + + first = torch.stack((dfdx, dfdy, dfdz), dim=0) + second = torch.stack((d2fdx2, d2fdy2, d2fdz2, d2fdxdy, d2fdxdz, d2fdydz), dim=0) + return field, (l0, l1, l2), first, second + + +# Validate torch backend first-order derivatives against analytic periodic 
fields. +@pytest.mark.parametrize("dim", [1, 2, 3]) +def test_spectral_grid_gradient_torch_first_order(device: str, dim: int): + field, lengths, first_expected, _ = _make_periodic_test_case(device=device, dim=dim) + output = SpectralGridGradient.dispatch( + field, + lengths=lengths, + derivative_orders=1, + include_mixed=False, + implementation="torch", + ) + torch.testing.assert_close(output, first_expected, atol=1e-4, rtol=1e-4) + + +# Validate torch backend second-order derivatives against analytic periodic fields. +@pytest.mark.parametrize("dim", [1, 2, 3]) +def test_spectral_grid_gradient_torch_second_order(device: str, dim: int): + field, lengths, _, second_expected = _make_periodic_test_case( + device=device, dim=dim + ) + output = SpectralGridGradient.dispatch( + field, + lengths=lengths, + derivative_orders=2, + include_mixed=(dim > 1), + implementation="torch", + ) + torch.testing.assert_close(output, second_expected, atol=1e-4, rtol=1e-4) + + +# Validate unified derivative-order requests concatenate outputs deterministically. +@pytest.mark.parametrize("dim", [1, 2, 3]) +def test_spectral_grid_gradient_torch_combined_orders(device: str, dim: int): + field, lengths, first_expected, second_expected = _make_periodic_test_case( + device=device, dim=dim + ) + output = SpectralGridGradient.dispatch( + field, + lengths=lengths, + derivative_orders=(1, 2), + include_mixed=(dim > 1), + implementation="torch", + ) + expected = torch.cat((first_expected, second_expected), dim=0) + torch.testing.assert_close(output, expected, atol=1e-4, rtol=1e-4) + + +# Validate benchmark input generation contract for forward inputs. +def test_spectral_grid_gradient_make_inputs_forward(device: str): + label, args, kwargs = next( + iter(SpectralGridGradient.make_inputs_forward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + field = args[0] + output = SpectralGridGradient.dispatch(*args, implementation="torch", **kwargs) + expected_count = field.ndim + if kwargs["derivative_orders"] == 2 and kwargs["include_mixed"]: + expected_count += (field.ndim * (field.ndim - 1)) // 2 + assert output.shape[0] == expected_count + + +# Validate benchmark input generation contract for backward inputs. +def test_spectral_grid_gradient_make_inputs_backward(device: str): + label, args, kwargs = next( + iter(SpectralGridGradient.make_inputs_backward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + field = args[0] + assert field.requires_grad + + output = SpectralGridGradient.dispatch(*args, implementation="torch", **kwargs) + output.square().mean().backward() + assert field.grad is not None + + +# Validate exported API and validation error paths. 
+def test_spectral_grid_gradient_error_handling(device: str): + field = torch.randn(64, device=device, dtype=torch.float32) + output = spectral_grid_gradient(field, lengths=2.0, derivative_orders=1) + assert output.shape == (1, 64) + + with pytest.raises(ValueError, match="supports 1D-3D fields"): + SpectralGridGradient.dispatch( + torch.randn(4, 4, 4, 4, device=device, dtype=torch.float32), + implementation="torch", + ) + + with pytest.raises(TypeError, match="floating-point"): + SpectralGridGradient.dispatch( + torch.ones(64, device=device, dtype=torch.int64), + implementation="torch", + ) + + with pytest.raises(ValueError, match="supports derivative orders"): + SpectralGridGradient.dispatch( + field, + derivative_orders=3, + implementation="torch", + ) + + with pytest.raises(ValueError, match="only valid when requesting 2nd derivatives"): + SpectralGridGradient.dispatch( + field, + derivative_orders=1, + include_mixed=True, + implementation="torch", + ) + + with pytest.raises( + ValueError, match="mixed derivatives require at least 2D inputs" + ): + SpectralGridGradient.dispatch( + field, + derivative_orders=2, + include_mixed=True, + implementation="torch", + ) + + with pytest.raises(ValueError, match="must have 2 entries"): + SpectralGridGradient.dispatch( + torch.randn(32, 32, device=device, dtype=torch.float32), + lengths=(1.0,), + implementation="torch", + ) + + with pytest.raises(ValueError, match="strictly positive"): + SpectralGridGradient.dispatch( + torch.randn(32, 32, device=device, dtype=torch.float32), + lengths=(1.0, 0.0), + implementation="torch", + ) diff --git a/test/nn/functional/derivatives/test_uniform_grid_gradient.py b/test/nn/functional/derivatives/test_uniform_grid_gradient.py new file mode 100644 index 0000000000..7392862042 --- /dev/null +++ b/test/nn/functional/derivatives/test_uniform_grid_gradient.py @@ -0,0 +1,532 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from physicsnemo.nn.functional import uniform_grid_gradient +from physicsnemo.nn.functional.derivatives import UniformGridGradient +from physicsnemo.nn.functional.derivatives.uniform_grid_gradient.uniform_grid_gradient import ( + _AUTO_3D_TORCH_MAX_NUMEL, + _auto_select_implementation, +) +from test.conftest import requires_module +from test.nn.functional._parity_utils import clone_case + + +# Build periodic analytic fields for derivative correctness checks. 
+def _make_periodic_field(device: str, dims: int, derivative_order: int): + torch_device = torch.device(device) + + if dims == 1: + n0 = 512 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + field = torch.sin(2.0 * torch.pi * x0) + spacing = 1.0 / float(n0) + if derivative_order == 1: + expected = (2.0 * torch.pi) * torch.cos(2.0 * torch.pi * x0).unsqueeze(0) + else: + expected = ( + -((2.0 * torch.pi) ** 2) * torch.sin(2.0 * torch.pi * x0) + ).unsqueeze(0) + return field, spacing, expected + + if dims == 2: + n0, n1 = 192, 160 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + x1 = torch.arange(n1, device=torch_device, dtype=torch.float32) / float(n1) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + field = torch.sin(2.0 * torch.pi * xx) + 0.5 * torch.cos(4.0 * torch.pi * yy) + spacing = (1.0 / float(n0), 1.0 / float(n1)) + if derivative_order == 1: + deriv_x = (2.0 * torch.pi) * torch.cos(2.0 * torch.pi * xx) + deriv_y = -2.0 * torch.pi * torch.sin(4.0 * torch.pi * yy) + else: + deriv_x = -((2.0 * torch.pi) ** 2) * torch.sin(2.0 * torch.pi * xx) + deriv_y = -8.0 * (torch.pi**2) * torch.cos(4.0 * torch.pi * yy) + expected = torch.stack((deriv_x, deriv_y), dim=0) + return field, spacing, expected + + n0, n1, n2 = 80, 72, 64 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + x1 = torch.arange(n1, device=torch_device, dtype=torch.float32) / float(n1) + x2 = torch.arange(n2, device=torch_device, dtype=torch.float32) / float(n2) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + field = ( + torch.sin(2.0 * torch.pi * xx) + + 0.5 * torch.cos(2.0 * torch.pi * yy) + + 0.25 * torch.sin(4.0 * torch.pi * zz) + ) + spacing = (1.0 / float(n0), 1.0 / float(n1), 1.0 / float(n2)) + if derivative_order == 1: + deriv_x = (2.0 * torch.pi) * torch.cos(2.0 * torch.pi * xx) + deriv_y = -1.0 * torch.pi * torch.sin(2.0 * torch.pi * yy) + deriv_z = 1.0 * torch.pi * torch.cos(4.0 * torch.pi * zz) + else: + deriv_x = -((2.0 * torch.pi) ** 2) * torch.sin(2.0 * torch.pi * xx) + deriv_y = -2.0 * (torch.pi**2) * torch.cos(2.0 * torch.pi * yy) + deriv_z = -4.0 * (torch.pi**2) * torch.sin(4.0 * torch.pi * zz) + expected = torch.stack((deriv_x, deriv_y, deriv_z), dim=0) + return field, spacing, expected + + +# Validate torch backend against analytic periodic derivatives. +@pytest.mark.parametrize("dims", [1, 2, 3]) +@pytest.mark.parametrize("derivative_order", [1, 2]) +@pytest.mark.parametrize("order", [2, 4]) +def test_uniform_grid_gradient_torch( + device: str, dims: int, derivative_order: int, order: int +): + field, spacing, expected = _make_periodic_field(device, dims, derivative_order) + output = UniformGridGradient.dispatch( + field, + spacing=spacing, + order=order, + derivative_orders=derivative_order, + implementation="torch", + ) + torch.testing.assert_close(output, expected, atol=5e-2, rtol=5e-2) + + +# Validate unified derivative-order requests concatenate outputs deterministically. 
+@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_uniform_grid_gradient_torch_combined_orders(device: str, dims: int): + field, spacing, expected_first = _make_periodic_field( + device, dims, derivative_order=1 + ) + _, _, expected_second = _make_periodic_field(device, dims, derivative_order=2) + + output = UniformGridGradient.dispatch( + field, + spacing=spacing, + order=2, + derivative_orders=(1, 2), + include_mixed=False, + implementation="torch", + ) + expected = torch.cat((expected_first, expected_second), dim=0) + torch.testing.assert_close(output, expected, atol=5e-2, rtol=5e-2) + + +# Validate mixed second derivatives are available through unified API. +@pytest.mark.parametrize("dims", [2, 3]) +def test_uniform_grid_gradient_torch_second_order_mixed(device: str, dims: int): + field, spacing, _expected = _make_periodic_field(device, dims, derivative_order=2) + output = UniformGridGradient.dispatch( + field, + spacing=spacing, + order=2, + derivative_orders=2, + include_mixed=True, + implementation="torch", + ) + expected_count = dims + (dims * (dims - 1)) // 2 + assert output.shape[0] == expected_count + + +# Validate higher-order stencil improves analytic error for smooth fields. +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_uniform_grid_gradient_torch_order4_more_accurate(device: str, dims: int): + field, spacing, expected = _make_periodic_field(device, dims, derivative_order=1) + out_o2 = UniformGridGradient.dispatch( + field, + spacing=spacing, + order=2, + derivative_orders=1, + implementation="torch", + ) + out_o4 = UniformGridGradient.dispatch( + field, + spacing=spacing, + order=4, + derivative_orders=1, + implementation="torch", + ) + + err_o2 = torch.linalg.vector_norm((out_o2 - expected).reshape(-1)).item() + err_o4 = torch.linalg.vector_norm((out_o4 - expected).reshape(-1)).item() + assert err_o4 < err_o2 + + +# Validate warp backend forward parity against torch for representative cases. +@requires_module("warp") +def test_uniform_grid_gradient_backend_forward_parity(device: str): + for _label, args, kwargs in UniformGridGradient.make_inputs_forward(device=device): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = UniformGridGradient.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_warp = UniformGridGradient.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + UniformGridGradient.compare_forward(out_warp, out_torch) + + +# Validate warp backend against analytic periodic derivatives. +@requires_module("warp") +@pytest.mark.parametrize("dims", [1, 2, 3]) +@pytest.mark.parametrize("derivative_order", [1, 2]) +def test_uniform_grid_gradient_warp(device: str, dims: int, derivative_order: int): + field, spacing, expected = _make_periodic_field(device, dims, derivative_order) + output = UniformGridGradient.dispatch( + field, + spacing=spacing, + order=2, + derivative_orders=derivative_order, + implementation="warp", + ) + torch.testing.assert_close(output, expected, atol=7e-2, rtol=7e-2) + + +# Validate warp backend backward parity against torch for representative workloads. 
+@requires_module("warp") +def test_uniform_grid_gradient_backend_backward_parity(device: str): + for _label, args, kwargs in UniformGridGradient.make_inputs_backward(device=device): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = UniformGridGradient.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_torch.square().mean().backward() + grad_torch = args_torch[0].grad + assert grad_torch is not None + + out_warp = UniformGridGradient.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + out_warp.square().mean().backward() + grad_warp = args_warp[0].grad + assert grad_warp is not None + + UniformGridGradient.compare_backward(grad_warp, grad_torch) + + +# Validate warp backend backward parity for fused combined-order requests. +@requires_module("warp") +@pytest.mark.parametrize("dims", [1, 2, 3]) +def test_uniform_grid_gradient_backend_backward_combined_orders(device: str, dims: int): + field, spacing, _ = _make_periodic_field(device, dims, derivative_order=1) + + field_torch = field.to(torch.float32).detach().clone().requires_grad_(True) + field_warp = field.to(torch.float32).detach().clone().requires_grad_(True) + + out_torch = UniformGridGradient.dispatch( + field_torch, + spacing=spacing, + order=2, + derivative_orders=(1, 2), + include_mixed=False, + implementation="torch", + ) + out_warp = UniformGridGradient.dispatch( + field_warp, + spacing=spacing, + order=2, + derivative_orders=(1, 2), + include_mixed=False, + implementation="warp", + ) + + grad_seed = torch.randn_like(out_torch) + grad_torch = torch.autograd.grad( + outputs=out_torch, + inputs=field_torch, + grad_outputs=grad_seed, + create_graph=False, + retain_graph=False, + allow_unused=False, + )[0] + grad_warp = torch.autograd.grad( + outputs=out_warp, + inputs=field_warp, + grad_outputs=grad_seed, + create_graph=False, + retain_graph=False, + allow_unused=False, + )[0] + + assert grad_torch is not None + assert grad_warp is not None + torch.testing.assert_close(grad_warp, grad_torch, atol=7e-2, rtol=7e-2) + + +# Validate warp backend mixed second-derivative backward parity against torch. +@requires_module("warp") +@pytest.mark.parametrize("dims", [2, 3]) +def test_uniform_grid_gradient_backend_backward_mixed_orders(device: str, dims: int): + field, spacing, _ = _make_periodic_field(device, dims, derivative_order=1) + + field_torch = field.to(torch.float32).detach().clone().requires_grad_(True) + field_warp = field.to(torch.float32).detach().clone().requires_grad_(True) + + out_torch = UniformGridGradient.dispatch( + field_torch, + spacing=spacing, + order=2, + derivative_orders=2, + include_mixed=True, + implementation="torch", + ) + out_warp = UniformGridGradient.dispatch( + field_warp, + spacing=spacing, + order=2, + derivative_orders=2, + include_mixed=True, + implementation="warp", + ) + + grad_seed = torch.randn_like(out_torch) + grad_torch = torch.autograd.grad( + outputs=out_torch, + inputs=field_torch, + grad_outputs=grad_seed, + create_graph=False, + retain_graph=False, + allow_unused=False, + )[0] + grad_warp = torch.autograd.grad( + outputs=out_warp, + inputs=field_warp, + grad_outputs=grad_seed, + create_graph=False, + retain_graph=False, + allow_unused=False, + )[0] + + assert grad_torch is not None + assert grad_warp is not None + torch.testing.assert_close(grad_warp, grad_torch, atol=7e-2, rtol=7e-2) + + +# Validate benchmark input generation contract for forward inputs. 
+def test_uniform_grid_gradient_make_inputs_forward(device: str): + label, args, kwargs = next( + iter(UniformGridGradient.make_inputs_forward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + field = args[0] + assert field.ndim in (1, 2, 3) + assert torch.is_floating_point(field) + + output = UniformGridGradient.dispatch( + *args, + implementation="torch", + **kwargs, + ) + assert output.shape[0] == field.ndim + + +# Validate benchmark input generation contract for backward inputs. +def test_uniform_grid_gradient_make_inputs_backward(device: str): + label, args, kwargs = next( + iter(UniformGridGradient.make_inputs_backward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + field = args[0] + assert field.requires_grad + + output = UniformGridGradient.dispatch( + *args, + implementation="torch", + **kwargs, + ) + output.square().mean().backward() + assert field.grad is not None + + +# Validate compare-forward hook contract. +def test_uniform_grid_gradient_compare_forward_contract(device: str): + _label, args, kwargs = next( + iter(UniformGridGradient.make_inputs_forward(device=device)) + ) + output = UniformGridGradient.dispatch(*args, implementation="torch", **kwargs) + reference = output.detach().clone() + UniformGridGradient.compare_forward(output, reference) + + +# Validate compare-backward hook contract. +def test_uniform_grid_gradient_compare_backward_contract(device: str): + _label, args, kwargs = next( + iter(UniformGridGradient.make_inputs_backward(device=device)) + ) + field = args[0] + + output = UniformGridGradient.dispatch(*args, implementation="torch", **kwargs) + output.square().mean().backward() + + assert field.grad is not None + UniformGridGradient.compare_backward(field.grad, field.grad.detach().clone()) + + +# Validate auto-dispatch default path matches explicit selected implementation. +def test_uniform_grid_gradient_dispatch_auto_matches_selected(device: str): + field = torch.randn(64, 64, device=device, dtype=torch.float32) + implementation = _auto_select_implementation(field) + + output_auto = uniform_grid_gradient( + field, + spacing=(1.0, 1.0), + derivative_orders=1, + include_mixed=False, + ) + output_explicit = UniformGridGradient.dispatch( + field, + spacing=(1.0, 1.0), + derivative_orders=1, + include_mixed=False, + implementation=implementation, + ) + torch.testing.assert_close(output_auto, output_explicit) + + +# Validate CUDA auto-dispatch heuristic structure across dimensions/sizes. 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_uniform_grid_gradient_dispatch_auto_heuristic_cuda(): + field_1d = torch.randn(4096, device="cuda", dtype=torch.float32) + assert _auto_select_implementation(field_1d) == "torch" + + field_2d = torch.randn(512, 512, device="cuda", dtype=torch.float32) + assert _auto_select_implementation(field_2d) == "torch" + + torch_n = int(round(_AUTO_3D_TORCH_MAX_NUMEL ** (1.0 / 3.0))) + field_3d_torch = torch.randn( + torch_n, + torch_n, + torch_n, + device="cuda", + dtype=torch.float32, + ) + assert field_3d_torch.numel() <= _AUTO_3D_TORCH_MAX_NUMEL + assert _auto_select_implementation(field_3d_torch) == "torch" + + large_n = torch_n + 1 + field_3d_large = torch.randn( + large_n, + large_n, + large_n, + device="cuda", + dtype=torch.float32, + ) + assert field_3d_large.numel() > _AUTO_3D_TORCH_MAX_NUMEL + assert _auto_select_implementation(field_3d_large) == "warp" + + field_grad = torch.randn( + 64, + 64, + 64, + device="cuda", + dtype=torch.float32, + requires_grad=True, + ) + assert _auto_select_implementation(field_grad) == "warp" + + +# Validate exported functional API and error handling paths. +def test_uniform_grid_gradient_error_handling(device: str): + field = torch.randn(16, device=device, dtype=torch.float32) + + output = uniform_grid_gradient(field, spacing=1.0) + assert output.shape == (1, 16) + assert output.dtype == torch.float32 + + with pytest.raises(ValueError, match="supports 1D-3D fields"): + UniformGridGradient.dispatch( + torch.randn(4, 4, 4, 4, device=device, dtype=torch.float32), + implementation="torch", + ) + + with pytest.raises(TypeError, match="floating-point"): + UniformGridGradient.dispatch( + torch.ones(8, device=device, dtype=torch.int32), + implementation="torch", + ) + + with pytest.raises(ValueError, match="spacing must have"): + UniformGridGradient.dispatch( + torch.randn(8, 8, device=device, dtype=torch.float32), + spacing=(1.0,), + implementation="torch", + ) + + with pytest.raises(ValueError, match="strictly positive"): + UniformGridGradient.dispatch( + torch.randn(8, 8, device=device, dtype=torch.float32), + spacing=(1.0, 0.0), + implementation="torch", + ) + + with pytest.raises(ValueError, match="supports"): + UniformGridGradient.dispatch( + torch.randn(8, 8, device=device, dtype=torch.float32), + order=6, + implementation="torch", + ) + + with pytest.raises(TypeError, match="integer"): + UniformGridGradient.dispatch( + torch.randn(8, 8, device=device, dtype=torch.float32), + order=2.0, # type: ignore[arg-type] + implementation="torch", + ) + + with pytest.raises(ValueError, match="supports derivative orders"): + UniformGridGradient.dispatch( + torch.randn(8, 8, device=device, dtype=torch.float32), + derivative_orders=3, + implementation="torch", + ) + + with pytest.raises(TypeError, match="include_mixed must be a bool"): + UniformGridGradient.dispatch( + torch.randn(8, 8, device=device, dtype=torch.float32), + derivative_orders=2, + include_mixed=1, # type: ignore[arg-type] + implementation="torch", + ) + + with pytest.raises(ValueError, match="only valid when requesting 2nd derivatives"): + UniformGridGradient.dispatch( + torch.randn(8, 8, device=device, dtype=torch.float32), + derivative_orders=1, + include_mixed=True, + implementation="torch", + ) + + out = UniformGridGradient.dispatch( + torch.randn(8, 8, device=device, dtype=torch.float32), + derivative_orders=2, + include_mixed=True, + implementation="torch", + ) + assert out.shape[0] == 3
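The shape assertions in these tests all follow the same channel-count convention on the leading output dimension: one channel per spatial axis for each requested derivative order, plus `ndim * (ndim - 1) / 2` additional channels for the unique mixed second derivatives when `include_mixed=True`, with combined-order requests concatenated along that dimension. A minimal sketch of that arithmetic, using a hypothetical helper that is not part of this diff, for cross-checking expected output shapes:

```python
from __future__ import annotations


def expected_gradient_channels(
    ndim: int,
    derivative_orders: int | tuple[int, ...],
    include_mixed: bool = False,
) -> int:
    """Channel count asserted on the leading output dimension in the tests above."""
    orders = (
        (derivative_orders,)
        if isinstance(derivative_orders, int)
        else tuple(derivative_orders)
    )
    count = 0
    for order in orders:
        # One channel per spatial axis for each requested order.
        count += ndim
        if order == 2 and include_mixed and ndim > 1:
            # Unique mixed second derivatives (upper triangle of the Hessian).
            count += (ndim * (ndim - 1)) // 2
    return count


# Cases matching the assertions in the tests above:
assert expected_gradient_channels(2, 2, include_mixed=True) == 3
assert expected_gradient_channels(3, 2, include_mixed=True) == 6
assert expected_gradient_channels(3, (1, 2), include_mixed=False) == 6
```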