
Auto Differentiation Implementation 1/3 #1226

Draft
mar-yan24 wants to merge 1 commit into google-deepmind:main from mar-yan24:mark/autodifferentiation

Conversation

@mar-yan24

Autodifferentiation Support 1/3

Overview

A while ago I became interested in adding automatic differentiation to MJWarp, since I want to do a project with differentiable contact geometry. With some time on my hands, I started working on a personal implementation of AD support in MJWarp. I'll probably continue working on this over the coming month, but I'm opening a draft here to get some maintainer feedback and to discuss whether there is still community desire for this, as referenced in issue #500.

These changes add reverse-mode AD support for the smooth-dynamics pipeline of MuJoCo Warp. Users can compute gradients of scalar loss functions with respect to qpos, qvel, ctrl, and other state variables by recording a wp.Tape over kinematics -> fwd_velocity -> fwd_actuation -> euler.
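The wp.Tape mechanics behind this can be illustrated with a minimal pure-Python sketch. This is a toy stand-in, not Warp code (wp.Tape records kernel launches rather than Python closures), but the record-forward/replay-adjoints-in-reverse pattern is the same:

```python
# Toy reverse-mode "tape": record each op's backward function during the
# forward pass, then replay them in reverse to accumulate gradients.
# Illustrative only -- MJWarp uses wp.Tape over Warp kernels.
import math

class Tape:
    def __init__(self):
        self.ops = []  # backward closures, recorded in forward order

    def record(self, backward_fn):
        self.ops.append(backward_fn)

    def backward(self):
        for fn in reversed(self.ops):  # adjoints run in reverse order
            fn()

class Var:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0

def mul(tape, a, b):
    out = Var(a.value * b.value)
    def backward():
        a.grad += b.value * out.grad  # d(ab)/da = b
        b.grad += a.value * out.grad  # d(ab)/db = a
    tape.record(backward)
    return out

def sin(tape, a):
    out = Var(math.sin(a.value))
    def backward():
        a.grad += math.cos(a.value) * out.grad  # d sin(a)/da = cos(a)
    tape.record(backward)
    return out

# "Forward pass" recorded on the tape, then gradients pulled back.
tape = Tape()
x = Var(2.0)
y = sin(tape, mul(tape, x, x))  # y = sin(x^2)
y.grad = 1.0                    # seed dL/dy = 1
tape.backward()
# x.grad now holds dy/dx = 2x * cos(x^2)
```

The same shape applies here: record the smooth-dynamics kernels on a `wp.Tape`, seed the adjoint of the scalar loss, and call `tape.backward()` to populate gradients of qpos, qvel, and ctrl.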

The implementation follows a selective enable_backward strategy: only the four modules that participate in the differentiable smooth-dynamics path have backward code generation enabled. All other modules (collision, constraint, solver, sensor, render, ray, etc.) remain at enable_backward: False. This should keep compilation time and binary size close to current levels.

Architecture

Selective backward generation

| Module | enable_backward | Notes |
| --- | --- | --- |
| smooth.py | True | kinematics, crb, rne, com_vel, etc. |
| forward.py | True | fwd_velocity, fwd_actuation, euler/rk4 |
| passive.py | True | passive forces (spring, damper, fluid) |
| derivative.py | True | analytical derivatives (qDeriv) |
| All others (13+) | False | collision, constraint, solver, sensor, ... |

Within the enabled modules, tile kernels (wp.launch_tiled) still have per-kernel enable_backward=False overrides (smooth.py lines 1053/2825/2903, forward.py line 309) because cuSolverDx LTO compilation does not support adjoint generation.

Kernel compilation time

@Kenny-Vilella raised a good concern in discussion #993 about enabling backward globally. As mentioned above, one of the issues we want to avoid is very long compile times. This selective approach generates adjoint kernels for only ~30 smooth-dynamics kernels out of 100+ total. Warp caches compiled kernels, so the cost is one-time per kernel signature.

New modules

  • grad.py - coordination layer: enable_grad(), disable_grad(), make_diff_data(), diff_step(), diff_forward(), SMOOTH_GRAD_FIELDS.
  • adjoint.py - centralizes @wp.func_grad registrations. Phase 1 provides a custom adjoint for quat_integrate (avoids gradient singularity at zero angular velocity).
  • grad_test.py - AD test suite: kinematics, fwd_velocity, fwd_actuation, euler_step, quaternion integration, and utility tests.
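To illustrate why quat_integrate needs a custom adjoint: the exponential-map quaternion update contains a sin(|omega| dt / 2) / |omega| factor that tends to dt/2 as |omega| -> 0, but whose naive derivative is a 0/0 expression exactly at zero angular velocity. A pure-Python sketch of the guarded formula (hypothetical function names; the actual MJWarp routine operates on Warp quaternion types):

```python
# Sketch of exponential-map quaternion integration with a guard at
# zero angular velocity. Hypothetical stand-in for the routine whose
# custom adjoint adjoint.py registers via @wp.func_grad.
import math

def qmul(a, b):
    """Hamilton product of quaternions (w, x, y, z)."""
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return (aw*bw - ax*bx - ay*by - az*bz,
            aw*bx + ax*bw + ay*bz - az*by,
            aw*by - ax*bz + ay*bw + az*bx,
            aw*bz + ax*by - ay*bx + az*bw)

def quat_integrate(q, omega, dt):
    """Rotate q by angular velocity omega over dt via the exponential map."""
    wx, wy, wz = omega
    w = math.sqrt(wx*wx + wy*wy + wz*wz)
    half = 0.5 * w * dt
    if w < 1e-8:
        # Limit of sin(w*dt/2)/w as w -> 0 is dt/2; the unguarded 0/0
        # expression is exactly where the naive adjoint blows up.
        s = 0.5 * dt
    else:
        s = math.sin(half) / w
    dq = (math.cos(half), wx * s, wy * s, wz * s)
    return qmul(q, dq)
```

A custom adjoint applies the same Taylor-limit reasoning to the backward pass, so gradients stay finite for a body at rest.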

Summary

As a whole, my goal here is to get some initial feedback and review from both the community and the maintainers on whether this project and implementation are feasible. Any advice and feedback is appreciated!

@thowell thowell linked an issue Mar 16, 2026 that may be closed by this pull request
@thowell thowell self-requested a review March 16, 2026 09:53
@thowell
Collaborator

thowell commented Mar 16, 2026

@mar-yan24 thank you for this contribution!

Thanks for scoping this to the smooth dynamics for now. Have you identified any key blockers for adding differentiation support to other parts of the code, like the collision pipeline or constraint solver?

@thowell
Collaborator

thowell commented Mar 16, 2026

@erikfrey @adenzler-nvidia
How do we want to think about an API for differentiation?

@adenzler-nvidia @Kenny-Vilella
What are the performance implications of utilizing wp.clone? Are these calls something we should consider guarding with wp.static?
What considerations should be made for tile operations and differentiability?

@mar-yan24
Author

mar-yan24 commented Mar 16, 2026

Thanks for the review @thowell! Yes, I scoped to smooth dynamics after doing a rough survey of the collision and solver code. Within the collision pipeline and constraint solver there are some blockers I still need to investigate and test, but here is what I found so far.

For the collision pipeline, the fundamental issue is that collision detection is a discrete geometric query: several integer configuration variables (cltype/clface/clcorner) control branching with no gradient, and atomic counters make the contact count data-dependent. There are also some algorithmic non-differentiabilities, which I'm fairly sure cannot be fixed just by enabling backward (enable_backward) on the existing code. The discrete pipeline would probably have to be bypassed with smooth distance proxies and custom adjoints in adjoint.py.
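As a sketch of what a smooth distance proxy could look like (purely illustrative, not code from this PR): replacing the hard penetration `max(0, -dist)` with a softplus keeps the same value in the limit of a sharp transition, but carries a well-defined, nonzero gradient even when geoms are slightly separated.

```python
# Hypothetical smooth contact proxy: softplus-smoothed penetration depth.
# beta controls sharpness; as beta -> inf this approaches max(0, -dist).
import math

def hard_penetration(dist):
    # Discrete pipeline analogue: zero (and zero-gradient) for dist > 0,
    # so gradients give no signal for nearly-touching geometry.
    return max(0.0, -dist)

def soft_penetration(dist, beta=10.0):
    # Numerically stable softplus(-beta*dist) / beta.
    z = -beta * dist
    return (max(z, 0.0) + math.log1p(math.exp(-abs(z)))) / beta

def d_soft_penetration(dist, beta=10.0):
    # Analytic derivative: a smooth sigmoid ramp instead of a step.
    return -1.0 / (1.0 + math.exp(beta * dist))
```

At dist = 0 the soft proxy already reports log(2)/beta depth with slope -1/2, which is what lets gradients flow through grazing contacts.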

For the constraint solver I need to spend some more time looking into it, but from what I've seen, the biggest blockers for enabling backward are:

  • wp.capture_while - runs until all worlds converge, so the number of iterations varies per world, while the tape needs a fixed computation graph to replay backwards.
  • Constraint activation - whether a constraint is active depends on pos < 0; the discontinuity needs to be avoided (maybe by adding a small perturbation to qpos?).
  • wp.tile_cholesky - does not support LTO adjoint generation, so there is no gradient flow.
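The wp.capture_while issue can be sketched in plain Python (a toy fixed-point solve standing in for the constraint solver; these function names are hypothetical): a data-dependent loop executes a different number of operations per input, whereas a fixed unroll gives every world the same replayable graph.

```python
# Toy fixed-point solve x = cos(x), contrasting a data-dependent loop
# (analogous to wp.capture_while) with a tape-friendly fixed unroll.
import math

def solve_while(x0, tol=1e-10):
    # Iterates until the residual converges: the iteration count depends
    # on x0, so a recorded tape cannot replay a fixed backward graph.
    x, n = x0, 0
    while abs(math.cos(x) - x) > tol:
        x, n = math.cos(x), n + 1
    return x, n

def solve_fixed(x0, iters=60):
    # Fixed unroll: identical operation count for every input/world,
    # so the backward sweep can replay the graph step by step.
    x = x0
    for _ in range(iters):
        x = math.cos(x)
    return x
```

The usual trade-off is picking an iteration count large enough that all worlds converge, at the cost of wasted work in worlds that converge early; implicit-function-theorem adjoints are another option but are a bigger change.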

I also don't mind taking a look at wp.clone and tile operations for performance optimization. I have less experience with performance work, but as a student I have a decent amount of time on my hands.

If it would help, I can also write a Markdown file covering the larger changes I need to implement and a high-level roadmap for the whole implementation.

@Kenny-Vilella
Collaborator

Some tile operations have limitations on the adjoint calculation; these include:

  • wp.tile_cholesky
  • wp.tile_matmul
  • wp.tile_*_solve

But most of them should be OK.

For wp.clone, it's a memory operation, so it may be expensive depending on the size of the array.
One thing to consider is whether it would be faster to spin up a kernel instead, but it's not a major issue.

@mar-yan24 mar-yan24 force-pushed the mark/autodifferentiation branch from 47a5120 to eb256a4 on March 17, 2026 03:52
