
Conversation


@aarushisingh04 commented Jan 7, 2026

What does this PR do?

This PR adds support for preferred_element_type and out_sharding to core NNX
Attention and RNN modules, aligning them with existing NNX layers such as
Linear and Conv.

The change improves API consistency and gives users fine-grained control over
numerical precision and distributed sharding at the module level, without
changing any default behavior.


Motivation

In NNX, low-level layers such as Linear and Conv already expose
preferred_element_type and out_sharding, but higher-level modules built on
top of them, such as MultiHeadAttention and the RNN cells, did not forward
or expose these parameters.

This made the following difficult:

  • Controlling numerical precision for attention and recurrent modules
  • Applying sharding constraints to module outputs in distributed settings
  • Writing consistent code when mixing Linear / Conv with Attention or RNNs

This PR closes that gap by plumbing the existing parameters through these
modules in a fully backward-compatible way.


Changes Made

1. Attention module enhancements (attention.py)

Updated:

  • dot_product_attention
  • dot_product_attention_weights
  • MultiHeadAttention

Changes:

  • Added preferred_element_type to function signatures and forwarded it to
    internal jnp.einsum calls (see the sketch after this list)
  • Added out_sharding to MultiHeadAttention.__call__ to allow sharding of the
    output projection
  • Updated docstrings to document the new parameters
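
To make the mechanism concrete: jnp.einsum already accepts
preferred_element_type (it is forwarded to the underlying dot_general), and
that is the hook these changes plumb through. A minimal standalone sketch,
not the PR diff itself; the shapes and einsum spec mirror the
[batch, length, num_heads, head_dim] layout used by NNX attention:

  import jax.numpy as jnp

  # bfloat16 activations; ask the contraction to accumulate and emit float32.
  query = jnp.ones((2, 8, 4, 16), dtype=jnp.bfloat16)
  key = jnp.ones((2, 8, 4, 16), dtype=jnp.bfloat16)

  attn_weights = jnp.einsum(
      '...qhd,...khd->...hqk',
      query,
      key,
      preferred_element_type=jnp.float32,
  )
  assert attn_weights.dtype == jnp.float32  # scores of shape (2, 4, 8, 8)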

2. RNN module enhancements (recurrent.py)

Updated:

  • LSTMCell
  • OptimizedLSTMCell
  • SimpleCell
  • GRUCell

Changes:

  • Added preferred_element_type to __init__ and forwarded it to internal
    Linear layers (usage sketch after this list)
  • Added out_sharding to __call__ for all cells
  • Standardized and completed docstrings (Args / Returns)
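
A rough usage sketch of the intent for the cells. This is hedged: the
parameter placement follows the description above and may differ from the
final diff; only the new preferred_element_type keyword is assumed here,
everything else is the standard NNX cell API:

  import jax.numpy as jnp
  from flax import nnx

  # Assumed per this PR: preferred_element_type is accepted at construction
  # and forwarded to the cell's internal nnx.Linear layers; out_sharding
  # would be accepted at call time, e.g. cell(carry, x, out_sharding=...).
  cell = nnx.LSTMCell(
      in_features=16,
      hidden_features=32,
      preferred_element_type=jnp.float32,  # new keyword from this PR
      rngs=nnx.Rngs(0),
  )
  x = jnp.ones((8, 16), dtype=jnp.bfloat16)
  carry = cell.initialize_carry(x.shape)
  carry, y = cell(carry, x)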

Backward Compatibility

  • No existing behavior is changed
  • All new parameters are optional
  • Default numerical precision and sharding remain unchanged unless explicitly
    specified by the user

Tests Added

A new targeted test suite was added:

  • tests/nnx/technical_support_test.py

This suite verifies:

  • preferred_element_type propagation for Attention and RNN modules
  • out_sharding exposure and forwarding in MultiHeadAttention and RNN cells

Specific tests include:

  • test_mha_preferred_element_type
  • test_mha_out_sharding_signature
  • test_lstm_preferred_element_type
  • test_gru_preferred_element_type
  • test_simple_cell_out_sharding_signature

All tests passed.
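
For illustration, a check of this kind might look roughly like the following.
This is a sketch, not the actual contents of
tests/nnx/technical_support_test.py; the preferred_element_type keyword on
dot_product_attention is exactly the new behavior under review here:

  import jax.numpy as jnp
  from flax import nnx

  def test_attention_preferred_element_type_sketch():
    # Shapes: [batch, length, num_heads, head_dim].
    q = jnp.ones((1, 4, 2, 8), dtype=jnp.bfloat16)
    k = jnp.ones((1, 4, 2, 8), dtype=jnp.bfloat16)
    v = jnp.ones((1, 4, 2, 8), dtype=jnp.bfloat16)
    out = nnx.dot_product_attention(
        q, k, v, preferred_element_type=jnp.float32
    )
    assert out.dtype == jnp.float32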


Regression Testing

Existing test suites were run to ensure no regressions:

  • tests/nnx/nn/attention_test.py
  • tests/nnx/nn/recurrent_test.py

All tests passed.


Verification Command

$env:PYTHONPATH="."; pytest -v `
  tests/nnx/technical_support_test.py `
  tests/nnx/nn/attention_test.py `
  tests/nnx/nn/recurrent_test.py

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a Github issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)


google-cla bot commented Jan 7, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@aarushisingh04 (Author)

hey @samanklesaria
could you check this PR out whenever possible?

@samanklesaria (Collaborator)

I'm a bit skeptical of adding out_sharding to MultiHeadAttention. As I see it, out_sharding is used to resolve ambiguity when the explicitly sharded tensors you pass to an operation are sharded inconsistently along some dimension. The MultiHeadAttention layer creates ("outputs") many different tensors when called: Q, K, V, the attention weights, and the final output. It seems like the out_sharding argument you add to MultiHeadAttention is used only when projecting the final output, so it can't be used to resolve sharding inconsistencies in any of the intermediate steps. Rather than offering the user an opportunity to resolve inconsistencies in only one of the many operations implied by MultiHeadAttention, I think it would be cleaner to keep the current functionality and not support it for any of them. If the user wants the output of MultiHeadAttention sharded in a specific way, they can just reshard it manually afterwards.
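
For reference, the manual-resharding alternative suggested here could look
roughly like this (a sketch; the mesh, axis names, and layer sizes are
illustrative, not taken from the PR):

  import jax
  import jax.numpy as jnp
  import numpy as np
  from flax import nnx
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(np.array(jax.devices()), axis_names=('data',))
  mha = nnx.MultiHeadAttention(
      num_heads=4, in_features=16, decode=False, rngs=nnx.Rngs(0)
  )

  @nnx.jit
  def forward(model, x):
    y = model(x)
    # Constrain only the final output's sharding, outside the layer itself.
    return jax.lax.with_sharding_constraint(
        y, NamedSharding(mesh, P('data', None, None))
    )

  out = forward(mha, jnp.ones((8, 10, 16)))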

@aarushisingh04 (Author)

@samanklesaria

thanks for the feedback, I've made the following updates:

  1. attention.py keeps preferred_element_type; out_sharding was not added
  2. recurrent.py has a duplicate import removed (it keeps both
     preferred_element_type and out_sharding)
  3. the tests were moved to attention_test.py and recurrent_test.py

@aarushisingh04 aarushisingh04 changed the title feat(nnx): add preferred_element_type and out_sharding support to Attention and RNN modules feat(nnx): add preferred_element_type to attention.py and recurrent.py, out_sharding to recurrent.py Jan 13, 2026

@samanklesaria (Collaborator) left a comment


Looks good to me.

@aarushisingh04 (Author)

hey @samanklesaria, just checking in: is this okay to merge? thanks!

@samanklesaria (Collaborator)

@aarushisingh04 It's in the merge queue. But PRs have to go through a separate testing process on the Google side before they get merged in, so it can take a bit.

Code excerpt under review (attention.py):

  )
  if len(query_shape) > 4:
    out = jnp.reshape(out, query_shape)
  return out

@chapman20j (Collaborator) commented Jan 22, 2026


Can you incorporate preferred_element_type here as well?
