Conversation
Greptile Summary
This PR introduces seven derivative functionals (
Important Files Changed
```python
@FunctionSpec.register(name="torch", rank=0, baseline=True)
def torch_forward(
    points: Float[Tensor, "num_points dim"],
    spacing: float | Sequence[float] = 1.0,
    include_center: bool = True,
) -> Float[Tensor, "num_points stencil_size dim"]:
    """Dispatch stencil-point construction to the PyTorch backend."""
    return meshless_fd_stencil_points_torch(
        points=points,
        spacing=spacing,
        include_center=bool(include_center),
    )
```
**Inconsistent use of jaxtyping annotations across the PR**
`MeshlessFDStencilPoints`, `MeshlessFDDerivatives`, and `SpectralGridGradient` use jaxtyping (`Float[Tensor, "..."]`) in their `FunctionSpec.register`-decorated methods, while `MeshGreenGaussGradient`, `MeshLSQGradient`, `RectilinearGridGradient`, and `UniformGridGradient` in the same PR use plain `torch.Tensor`. A consistent annotation style across all functionals in this module would improve maintainability.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
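To make the suggestion concrete, here is a minimal sketch of the consistent style being asked for: every registered backend method spells tensor shapes with jaxtyping's `Float[Tensor, "..."]` rather than bare `torch.Tensor`. The function name and shapes below are illustrative, not the PR's actual signatures; with postponed annotation evaluation the names are never resolved, so the sketch runs even without torch or jaxtyping installed.

```python
from __future__ import annotations  # annotations stay as strings at runtime

# Illustrative only: `green_gauss_forward` and its shapes are assumptions,
# showing the jaxtyping style applied uniformly instead of plain torch.Tensor.
def green_gauss_forward(
    points: Float[Tensor, "num_points dim"],
    values: Float[Tensor, "num_points channels"],
) -> Float[Tensor, "num_points channels dim"]:
    ...

# The shape strings double as documentation readable from __annotations__.
print("num_points dim" in green_gauss_forward.__annotations__["points"])  # True
```

Besides readability, the shape strings give static checkers and runtime validators something to verify, which plain `torch.Tensor` annotations cannot.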
```diff
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
```
This is currently unused in the mesh module; however, I already had a really nice implementation of it and wanted to add it.
```python
)


def meshless_fd_stencil_points_torch(
```
This is just used for generating the test data. It could probably be removed, honestly, but it's fine for now...
```python
### Forward kernels (periodic central differences)
### ============================================================
```
I might break this file up...
**ktangsali** left a comment
Looks great, thanks @loliverhennigh! I also tested the PhysicsInformer integration with this PR and we have everything needed for supporting that higher level abstraction as well. Thank you for addressing the mixed derivatives comment and getting it to feature parity with Sym's implementations.
I have added a few comments. Some comments (like the docstring one for functions) are applicable to all the different gradient methods but I have just added the comment to one of the functions / implementations. Kindly take a look.
Overall this is looking in great shape! Don't forget to update the Changelog.
```python
    spacing: float | Sequence[float] = 1.0,
    derivative_orders: int | Sequence[int] = 1,
    include_mixed: bool = False,
) -> Float[Tensor, "num_derivs num_points channels"]:
```
As Greptile also pointed out, there is some inconsistency in jaxtyping usage across the different gradient methods.
```python
    neighbor_indices=neighbor_indices,
    min_neighbors=int(min_neighbors),
)
if points.requires_grad:
```
From my understanding, torch will support this, right? Any specific reason why we can't support this with Warp too?
```python
) -> torch.Tensor:
    """Compute Green-Gauss cell-centered gradients with Warp kernels."""
    validate_inputs(points=points, cells=cells, neighbors=neighbors, values=values)
    if points.requires_grad:
```
```python
    include_mixed=bool(include_mixed),
)
```

```python
orig_dtype = field.dtype
```
I believe this is a common pattern across the kernels, which makes sense to me, but maybe we should add it as a note somewhere? Something like:

> Note: Warp backends internally compute in float32. Float64 inputs are accepted, but derivative accuracy is limited to float32 precision.

Because we would be losing precision if someone decides to pass fp64 (which folks might do to get more accurate grads).
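To illustrate the precision concern with a stdlib-only sketch (not the PR's code): round-tripping a Python float (float64) through IEEE-754 float32 silently discards increments below float32's resolution, which is effectively what happens when a float64 field is handed to a kernel that computes in float32.

```python
import struct

def to_float32(x: float) -> float:
    # Round-trip a Python float (float64) through IEEE-754 binary32,
    # mimicking the downcast a float32-only kernel performs implicitly.
    return struct.unpack("f", struct.pack("f", x))[0]

# A perturbation representable in float64 but below float32 resolution
# (float32 machine epsilon is ~1.19e-7, so a 1e-8 step near 1.0 vanishes).
x = 1.0 + 1e-8
print(x != 1.0)              # True: distinguishable in float64
print(to_float32(x) == 1.0)  # True: the increment is lost at float32
```

This is why a docstring note warning fp64 users about the internal fp32 computation seems worthwhile.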
```python
_BENCHMARK_CASES = (
    ("2d-tri-24x24-scalar", 24, 24, False),
    ("2d-tri-36x36-scalar", 36, 36, False),
    ("2d-tri-36x36-vector", 36, 36, True),
```
Seems like we don't have a 3D test here? Would be good to test the 3D path too.
```
used: on CUDA, 1D/2D fields prefer ``torch``; 3D fields use a single-threshold
crossover (``torch`` -> ``warp``) as problem size grows. Inputs requiring
gradients prefer ``warp`` to use the explicit custom backward kernels.
```
This would be the public-facing API, right? Is my understanding correct? If so, can we add minimal examples to them? And maybe the parameters could be further expanded?
```
Examples
--------
>>> field = torch.sin(torch.linspace(0, 2*torch.pi, 64))
>>> grads = uniform_grid_gradient(field, spacing=2*torch.pi/64)
>>> grads.shape
torch.Size([1, 64])
```
Because I see that we are rendering these functions in the docs...
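For reference, the crossover heuristic that docstring describes could be sketched roughly like this. The function name and threshold value are assumptions for illustration, not the PR's actual dispatch code:

```python
# Rough sketch of the dispatch heuristic described in the docstring:
# on CUDA, 1D/2D fields prefer torch; 3D switches torch -> warp past a
# size threshold; inputs requiring grad prefer warp for its explicit
# custom backward kernels. The threshold is a made-up placeholder.
_3D_CROSSOVER_NUMEL = 1_000_000

def choose_backend(ndim: int, numel: int, requires_grad: bool, on_cuda: bool) -> str:
    if requires_grad:
        return "warp"   # explicit custom backward kernels
    if not on_cuda or ndim <= 2:
        return "torch"  # 1D/2D fields prefer torch on CUDA (CPU assumed torch)
    # 3D: single-threshold crossover as problem size grows
    return "warp" if numel >= _3D_CROSSOVER_NUMEL else "torch"

print(choose_backend(2, 4096, False, True))       # torch
print(choose_backend(3, 2_000_000, False, True))  # warp
```

Spelling the heuristic out like this in the docs (alongside the doctest example suggested above) would make the backend selection behavior much easier for users to predict.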
We should also add a CHANGELOG entry for this PR :)
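Agreed. A possible entry, sketched in Keep a Changelog style (the section placement is an assumption; the functional names are taken from this PR's diff discussion):

```markdown
## [Unreleased]

### Added

- Derivative functionals: `MeshlessFDStencilPoints`, `MeshlessFDDerivatives`,
  `SpectralGridGradient`, `MeshGreenGaussGradient`, `MeshLSQGradient`,
  `RectilinearGridGradient`, and `UniformGridGradient`.
```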
PhysicsNeMo Pull Request
Description
Various derivative functionals