feat(nnx): add preferred_element_type to attention.py and recurrent.py, out_sharding to recurrent.py #5179
Conversation
…ention and RNN modules
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.

hey @samanklesaria

I'm a bit skeptical of adding

thanks for the feedback
samanklesaria left a comment

Looks good to me.
hey @samanklesaria just checking in, is this okay to merge? thanks!

@aarushisingh04 It's in the merge queue. But PRs have to go through a separate testing process on the Google side before they get merged in, so it can take a bit.
```python
  )
  if len(query_shape) > 4:
    out = jnp.reshape(out, query_shape)
  return out
```
Can you incorporate preferred_element_type here as well?
What does this PR do?
This PR adds support for `preferred_element_type` and `out_sharding` in the core NNX Attention and RNN modules, aligning them with existing NNX layers such as `Linear` and `Conv`. The change improves API consistency and gives users fine-grained control over numerical precision and distributed sharding at the module level, without changing any default behavior.
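As a rough usage sketch (the exact argument placement, constructor vs. `__call__`, is an assumption based on this description rather than the merged signature):

```python
import jax.numpy as jnp
from flax import nnx

# Hypothetical usage based on this PR's description: keep params/activations in
# bfloat16 but ask the attention einsums to accumulate/emit float32.
mha = nnx.MultiHeadAttention(
    num_heads=4,
    in_features=64,
    dtype=jnp.bfloat16,
    decode=False,
    rngs=nnx.Rngs(0),
)
x = jnp.ones((2, 16, 64), dtype=jnp.bfloat16)
# preferred_element_type is the argument added by this PR; out_sharding (not
# shown here) would similarly be passed at call time to shard the output
# projection.
y = mha(x, preferred_element_type=jnp.float32)
```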
Motivation
In NNX, low-level layers like `Linear` and `Conv` already expose `preferred_element_type` and `out_sharding`, but higher-level modules that rely on these layers, such as `MultiHeadAttention` and the RNN cells, did not forward or expose these parameters.
Negatives/Difficulties: precision and sharding could not be controlled consistently when mixing `Linear`/`Conv` with Attention or RNNs.

This PR closes that gap by plumbing the existing parameters through these modules in a fully backward-compatible way.
Changes Made
1. Attention module enhancements (`attention.py`)

Updated:
- `dot_product_attention`
- `dot_product_attention_weights`
- `MultiHeadAttention`

Changes:
- Added `preferred_element_type` to the function signatures and forwarded it to the internal `jnp.einsum` calls (see the sketch below)
- Added `out_sharding` to `MultiHeadAttention.__call__` to allow sharding of the output projection
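A minimal sketch of the `preferred_element_type` forwarding described above. This is illustrative only, not the PR's actual diff; the contraction pattern is the one Flax already uses for attention weights.

```python
import jax.numpy as jnp

def attention_weights_sketch(query, key, preferred_element_type=None):
  # Same query/key contraction used for attention weights; the new argument is
  # simply forwarded to jnp.einsum, which accepts preferred_element_type as an
  # output/accumulation dtype hint.
  return jnp.einsum(
      '...qhd,...khd->...hqk',
      query,
      key,
      preferred_element_type=preferred_element_type,
  )
```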
2. RNN module enhancements (`recurrent.py`)

Updated:
- `LSTMCell`
- `OptimizedLSTMCell`
- `SimpleCell`
- `GRUCell`

Changes:
- Added `preferred_element_type` to `__init__` and forwarded it to the internal `Linear` layers (see the sketch below)
- Added `out_sharding` to `__call__` for all cells
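A minimal sketch of the RNN-cell plumbing, assuming (per the description above) that `nnx.Linear` already accepts `preferred_element_type`. Class and attribute names here are illustrative, not the actual `recurrent.py` diff.

```python
import jax.numpy as jnp
from flax import nnx

class SimpleCellSketch(nnx.Module):
  """Illustrative cell that stores preferred_element_type and forwards it."""

  def __init__(self, in_features, hidden_features, *,
               preferred_element_type=None, rngs: nnx.Rngs):
    # Forwarded to the internal Linear layers, mirroring the change this PR
    # describes for LSTMCell / OptimizedLSTMCell / SimpleCell / GRUCell.
    self.dense_i = nnx.Linear(
        in_features, hidden_features,
        preferred_element_type=preferred_element_type, rngs=rngs)
    self.dense_h = nnx.Linear(
        hidden_features, hidden_features,
        preferred_element_type=preferred_element_type, rngs=rngs)

  def __call__(self, carry, x):
    new_carry = jnp.tanh(self.dense_i(x) + self.dense_h(carry))
    return new_carry, new_carry
```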
Backward Compatibility

Default behavior is unchanged; the new parameters only take effect when explicitly specified by the user.
Tests Added
A new targeted test suite was added: `tests/nnx/technical_support_test.py`

This suite verifies:
- `preferred_element_type` propagation for the Attention and RNN modules
- `out_sharding` exposure and forwarding in `MultiHeadAttention` and the RNN cells

Specific tests include:
- `test_mha_preferred_element_type`
- `test_mha_out_sharding_signature`
- `test_lstm_preferred_element_type`
- `test_gru_preferred_element_type`
- `test_simple_cell_out_sharding_signature`

All tests passed.
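As an illustration, a test like `test_lstm_preferred_element_type` might look roughly like the sketch below. The constructor placement of `preferred_element_type` and the asserted output dtype are assumptions based on this description, not the actual test file.

```python
import jax.numpy as jnp
from flax import nnx

def test_lstm_preferred_element_type():
  cell = nnx.LSTMCell(
      in_features=4,
      hidden_features=8,
      dtype=jnp.bfloat16,
      preferred_element_type=jnp.float32,  # new argument from this PR
      rngs=nnx.Rngs(0),
  )
  x = jnp.ones((2, 4), dtype=jnp.bfloat16)
  # LSTM carry is a (c, h) pair, each of shape (batch, hidden_features).
  carry = (jnp.zeros((2, 8), jnp.bfloat16), jnp.zeros((2, 8), jnp.bfloat16))
  _, y = cell(carry, x)
  assert y.dtype == jnp.float32  # precision request propagated to the output
```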
Regression Testing
Existing test suites were run to ensure no regressions:
- `tests/nnx/nn/attention_test.py`
- `tests/nnx/nn/recurrent_test.py`

All tests passed.
Verification Command
Checklist