Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ asv.conf.json @loliverhennigh

# Coding standards
CODING_STANDARDS/ @CharlelieLrt @coreyjadams @loliverhennigh
CODING_STANDARDS/FUNCTIONAL_APIS.md @loliverhennigh

# Interrogate "grandfathered exceptions" list
./test/ci_tests/interrogate_baseline.txt @ktangsali @coreyjadams @nickgeneva @peterdsharpe
Expand Down
56 changes: 56 additions & 0 deletions CODING_STANDARDS/FUNCTIONAL_APIS.md
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also add a CHANGELOG entry for this PR :)

Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ This document is structured in two main sections:
| [`FNC-005`](#fnc-005-benchmarking-hooks) | Benchmarking hooks | Implementing `make_inputs_forward`/`make_inputs_backward`/`compare_forward` |
| [`FNC-006`](#fnc-006-testing-functionals) | Testing functionals | Adding functional tests |
| [`FNC-007`](#fnc-007-benchmark-registry) | Benchmark registry | Adding a functional to ASV |
| [`FNC-008`](#fnc-008-warp-integration-must-use-torch-custom-ops) | Warp integration must use torch custom ops | Adding/refactoring Warp-backed functionals |

---

Expand Down Expand Up @@ -501,3 +502,58 @@ FUNCTIONAL_SPECS = (KNN, RadiusSearch)
# Adding a functional before input generators are implemented.
FUNCTIONAL_SPECS = (MyFunctionalWithoutInputs,)
```

---

### FNC-008: Warp integration must use torch custom ops

**Description:**

Warp-backed functionals in `physicsnemo/nn/functional/**` must be integrated
into PyTorch using `torch.library.custom_op`, `register_fake`, and (when
backward is supported) `register_autograd`. Do not use
`torch.autograd.Function` wrappers for Warp-backed functionals.

If a functional has no meaningful backward path, `register_autograd` is not
required; otherwise, the custom op must register a backward implementation.

**Rationale:**

`torch.library.custom_op` provides a consistent integration path for eager,
`torch.compile`, fake tensor propagation, and runtime dispatch behavior.
Avoiding per-functional `torch.autograd.Function` wrappers also keeps backend
integration uniform across functionals.

**Example:**

```python
@torch.library.custom_op("physicsnemo::my_warp_op", mutates_args=())
def my_warp_op_impl(x: torch.Tensor) -> torch.Tensor:
...
return y

@my_warp_op_impl.register_fake
def _my_warp_op_impl_fake(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)

def setup_my_warp_op_context(ctx, inputs, output):
...

def backward_my_warp_op(ctx, grad_output):
...
return grad_x

my_warp_op_impl.register_autograd(
backward_my_warp_op,
setup_context=setup_my_warp_op_context,
)
```

**Anti-pattern:**

```python
class _MyWarpAutograd(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
...
```
15 changes: 15 additions & 0 deletions benchmarks/physicsnemo/nn/functional/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@

"""Registry of FunctionSpec classes to benchmark with ASV."""

from physicsnemo.nn.functional.derivatives import (
MeshGreenGaussGradient,
MeshlessFDDerivatives,
MeshLSQGradient,
RectilinearGridGradient,
SpectralGridGradient,
UniformGridGradient,
)
from physicsnemo.nn.functional.fourier_spectral import (
IRFFT,
IRFFT2,
Expand All @@ -42,6 +50,13 @@
# Neighbor queries.
KNN,
RadiusSearch,
# Derivatives.
UniformGridGradient,
RectilinearGridGradient,
MeshLSQGradient,
MeshGreenGaussGradient,
SpectralGridGradient,
MeshlessFDDerivatives,
# Geometry.
SignedDistanceField,
# Interpolation.
Expand Down
34 changes: 34 additions & 0 deletions docs/api/nn/functionals/derivatives.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
Derivative Functionals
======================

.. autofunction:: physicsnemo.nn.functional.uniform_grid_gradient

.. figure:: /img/nn/functional/derivatives/uniform_grid_gradient.png
:alt: Uniform grid gradient example
:width: 100%

.. autofunction:: physicsnemo.nn.functional.rectilinear_grid_gradient

.. figure:: /img/nn/functional/derivatives/rectilinear_grid_gradient.png
:alt: Rectilinear grid gradient example
:width: 100%

.. autofunction:: physicsnemo.nn.functional.mesh_lsq_gradient

.. figure:: /img/nn/functional/derivatives/mesh_lsq_gradient.png
:alt: Mesh LSQ gradient example
:width: 100%

.. autofunction:: physicsnemo.nn.functional.mesh_green_gauss_gradient

.. figure:: /img/nn/functional/derivatives/mesh_green_gauss_gradient.png
:alt: Mesh Green-Gauss gradient example
:width: 100%

.. autofunction:: physicsnemo.nn.functional.spectral_grid_gradient

.. figure:: /img/nn/functional/derivatives/spectral_grid_gradient.png
:alt: Spectral grid gradient example
:width: 100%

.. autofunction:: physicsnemo.nn.functional.meshless_fd_derivatives
1 change: 1 addition & 0 deletions docs/api/physicsnemo.nn.functionals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ in the documentation for performance comparisons.
:name: PhysicsNeMo Functionals

nn/functionals/neighbors
nn/functionals/derivatives
nn/functionals/geometry
nn/functionals/fourier_spectral
nn/functionals/regularization_parameterization
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
138 changes: 31 additions & 107 deletions physicsnemo/mesh/calculus/_lsq_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,109 +29,19 @@

import torch

from physicsnemo.mesh.utilities._tolerances import safe_eps

if TYPE_CHECKING:
from physicsnemo.mesh.mesh import Mesh


def _solve_batched_lsq_gradients(
positions: torch.Tensor, # shape: (n_entities, n_spatial_dims)
values: torch.Tensor, # shape: (n_entities, ...)
adjacency, # Adjacency object
weight_power: float,
min_neighbors: int = 0,
def _to_mesh_gradient_layout(
gradients: torch.Tensor,
values: torch.Tensor,
) -> torch.Tensor:
"""Core batched LSQ gradient solver (shared by point and cell versions).

For each entity (point or cell), solves a weighted least-squares problem:
min_{∇φ} Σ_neighbors w_i ||∇φ·(x_i - x_0) - (φ_i - φ_0)||²

Parameters
----------
positions : torch.Tensor
Entity positions (points or cell centroids)
values : torch.Tensor
Values at entities (scalars or tensor fields)
adjacency
Adjacency structure (entity-to-entity neighbors)
weight_power : float
Exponent for inverse distance weighting
min_neighbors : int
Minimum neighbors required for gradient computation

Returns
-------
torch.Tensor
Gradients at entities, shape (n_entities, n_spatial_dims) for scalars,
or (n_entities, n_spatial_dims, ...) for tensor fields.
Entities with insufficient neighbors have zero gradients.
"""
n_entities = len(positions)
n_spatial_dims = positions.shape[1]
device = positions.device
dtype = values.dtype

### Determine output shape
is_scalar = values.ndim == 1
if is_scalar:
gradient_shape = (n_entities, n_spatial_dims)
else:
gradient_shape = (n_entities, n_spatial_dims) + values.shape[1:]

gradients = torch.zeros(gradient_shape, dtype=dtype, device=device)

### Process each neighbor-count group in parallel
from physicsnemo.mesh.calculus._neighborhoods import iter_neighborhood_batches

for batch in iter_neighborhood_batches(
positions, adjacency, min_neighbors=min_neighbors
):
entity_indices = batch.entity_indices
neighbors_flat = batch.neighbor_indices
A = batch.relative_positions # (n_group, n_neighbors, n_spatial_dims)
n_group = len(entity_indices)
n_neighbors = batch.n_neighbors

### Entities with no neighbors retain their zero-initialized gradient
if n_neighbors == 0:
continue

### Function differences (b vector)
b = values[neighbors_flat] - values[entity_indices].unsqueeze(1)

### Compute inverse-distance weights
distances = torch.linalg.vector_norm(A, dim=-1) # (n_group, n_neighbors)
weights = 1.0 / distances.pow(weight_power).clamp(min=safe_eps(distances.dtype))

### Apply sqrt-weights to system
sqrt_w = weights.sqrt().unsqueeze(-1) # (n_group, n_neighbors, 1)
A_weighted = sqrt_w * A # (n_group, n_neighbors, n_spatial_dims)

### Solve batched least-squares
if is_scalar:
b_weighted = sqrt_w.squeeze(-1) * b # (n_group, n_neighbors)
solution = torch.linalg.lstsq(
A_weighted, b_weighted.unsqueeze(-1), rcond=None
).solution.squeeze(-1) # (n_group, n_spatial_dims)

gradients[entity_indices] = solution
else:
# Tensor field: flatten extra dims, solve, reshape back
b_weighted = sqrt_w * b # (n_group, n_neighbors, ...)
orig_shape = b.shape[2:]
b_flat = b_weighted.reshape(n_group, n_neighbors, -1)

solution = torch.linalg.lstsq(
A_weighted, b_flat, rcond=None
).solution # (n_group, n_spatial_dims, n_components)

solution_reshaped = solution.reshape(n_group, n_spatial_dims, *orig_shape)
# Permute spatial_dims to second position
perm = [0] + list(range(2, solution_reshaped.ndim)) + [1]
gradients[entity_indices] = solution_reshaped.permute(*perm)

return gradients
"""Convert functional layout ``(n, dims, ...)`` to mesh layout ``(n, ..., dims)``."""
if values.ndim == 1:
return gradients
perm = [0] + list(range(2, gradients.ndim)) + [1]
return gradients.permute(*perm)


def compute_point_gradient_lsq(
Expand Down Expand Up @@ -184,14 +94,21 @@ def compute_point_gradient_lsq(
### Get point-to-point adjacency
adjacency = mesh.get_point_to_points_adjacency()

### Use shared batched LSQ solver
return _solve_batched_lsq_gradients(
positions=mesh.points,
### Delegate LSQ solve to the functional API using the torch backend.
from physicsnemo.nn.functional.derivatives.mesh_lsq_gradient import (
mesh_lsq_gradient,
)

gradients = mesh_lsq_gradient(
points=mesh.points,
values=point_values,
adjacency=adjacency,
neighbor_offsets=adjacency.offsets,
neighbor_indices=adjacency.indices,
weight_power=weight_power,
min_neighbors=min_neighbors,
implementation="torch",
)
return _to_mesh_gradient_layout(gradients, point_values)


def compute_cell_gradient_lsq(
Expand Down Expand Up @@ -230,11 +147,18 @@ def compute_cell_gradient_lsq(
### Get cell centroids
cell_centroids = mesh.cell_centroids # (n_cells, n_spatial_dims)

### Use shared batched LSQ solver
return _solve_batched_lsq_gradients(
positions=cell_centroids,
### Delegate LSQ solve to the functional API using the torch backend.
from physicsnemo.nn.functional.derivatives.mesh_lsq_gradient import (
mesh_lsq_gradient,
)

gradients = mesh_lsq_gradient(
points=cell_centroids,
values=cell_values,
adjacency=adjacency,
neighbor_offsets=adjacency.offsets,
neighbor_indices=adjacency.indices,
weight_power=weight_power,
min_neighbors=0, # Cells may have fewer neighbors than points
min_neighbors=0, # Cells may have fewer neighbors than points.
implementation="torch",
)
return _to_mesh_gradient_layout(gradients, cell_values)
14 changes: 14 additions & 0 deletions physicsnemo/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .derivatives import (
mesh_green_gauss_gradient,
mesh_lsq_gradient,
meshless_fd_derivatives,
rectilinear_grid_gradient,
spectral_grid_gradient,
uniform_grid_gradient,
)
from .equivariant_ops import (
legendre_polynomials,
polar_and_dipole_basis,
Expand All @@ -36,17 +44,23 @@
"interpolation",
"knn",
"legendre_polynomials",
"mesh_green_gauss_gradient",
"meshless_fd_derivatives",
"mesh_lsq_gradient",
"na1d",
"na2d",
"na3d",
"polar_and_dipole_basis",
"radius_search",
"real",
"rectilinear_grid_gradient",
"rfft",
"rfft2",
"signed_distance_field",
"smooth_log",
"spectral_grid_gradient",
"spherical_basis",
"uniform_grid_gradient",
"vector_project",
"view_as_complex",
"weight_fact",
Expand Down
43 changes: 43 additions & 0 deletions physicsnemo/nn/functional/derivatives/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .mesh_green_gauss_gradient import MeshGreenGaussGradient, mesh_green_gauss_gradient
from .mesh_lsq_gradient import MeshLSQGradient, mesh_lsq_gradient
from .meshless_finite_difference import (
MeshlessFDDerivatives,
meshless_fd_derivatives,
)
from .rectilinear_grid_gradient import (
RectilinearGridGradient,
rectilinear_grid_gradient,
)
from .spectral_grid_gradient import SpectralGridGradient, spectral_grid_gradient
from .uniform_grid_gradient import UniformGridGradient, uniform_grid_gradient

__all__ = [
"MeshGreenGaussGradient",
"MeshlessFDDerivatives",
"MeshLSQGradient",
"RectilinearGridGradient",
"SpectralGridGradient",
"UniformGridGradient",
"mesh_green_gauss_gradient",
"meshless_fd_derivatives",
"mesh_lsq_gradient",
"rectilinear_grid_gradient",
"spectral_grid_gradient",
"uniform_grid_gradient",
]
Loading
Loading