
Add Qwen3.5 MoE (35B-A3B) model export and runner for CUDA backend#18169

Merged
mergennachin merged 2 commits into main from mnachin/qwen3_5_moe
Mar 20, 2026

Conversation

Contributor

@mergennachin mergennachin commented Mar 13, 2026

Memory-efficient loading using meta-device construction + lazy
safetensors shard-by-shard loading + assign=True state dict loading,
following the voxtral_realtime pattern. Peak CPU memory during loading
is ~1x model size instead of ~3x.

Expert weights are structured as grouped nn.Linear modules (16 groups
of 16 experts each) so quantize_model_() handles them automatically.
Layer-by-layer quantization on CUDA avoids loading the full bf16 model
onto GPU at once.
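A toy sketch of that expert grouping (the real model has 256 experts in 16 groups of 16; sizes here are shrunk and names illustrative):

```python
import torch
import torch.nn as nn

# Each group is one nn.Linear whose out_features stacks its experts,
# so generic quantization passes that match on nn.Linear cover the
# experts with no MoE-specific handling.
HIDDEN, FFN, PER_GROUP, GROUPS = 8, 4, 4, 2

groups = nn.ModuleList(
    nn.Linear(HIDDEN, FFN * PER_GROUP, bias=False) for _ in range(GROUPS)
)

x = torch.randn(3, HIDDEN)  # 3 tokens
# Compute all groups, concatenate, then view per expert:
out = torch.cat([g(x) for g in groups], dim=-1).view(3, GROUPS * PER_GROUP, FFN)
```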

Includes C++ runner using the shared TextLLMRunner, Makefile target,
and CMake presets.

Reference implementations:
- https://github.com/mergennachin/nano_qwen35_moe/
- vLLM: vllm/model_executor/models/qwen3_5.py

Screenshot 2026-03-18 at 8 03 56 PM

6.5 token/s for decode -- can improve later

@pytorch-bot

pytorch-bot bot commented Mar 13, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18169

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 13, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@mergennachin mergennachin force-pushed the mnachin/qwen3_5_moe branch 3 times, most recently from 2bd84b3 to 54055da Compare March 19, 2026 00:17
@mergennachin mergennachin marked this pull request as ready for review March 19, 2026 00:19
Copilot AI review requested due to automatic review settings March 19, 2026 00:19
@mergennachin mergennachin force-pushed the mnachin/qwen3_5_moe branch 2 times, most recently from 9c07cbf to 3f5e7a5 Compare March 19, 2026 00:23
Contributor

Copilot AI left a comment


Pull request overview

Adds an ExecuTorch export + C++ runner pipeline for the Qwen3.5 MoE (35B-A3B) model targeting the CUDA backend, including build system integration and model-specific loading/quantization/export code.

Changes:

  • Introduce a new examples/models/qwen3_5_moe/ package with model definition, HF safetensors loading/remapping, and CUDA export script.
  • Add a C++ runner using TextLLMRunner, plus CMake presets/lists to build it.
  • Wire up a new top-level make qwen3_5_moe-cuda target and document usage in a new README.

Reviewed changes

Copilot reviewed 8 out of 9 changed files in this pull request and generated 6 comments.

Show a summary per file

| File | Description |
| --- | --- |
| Makefile | Adds qwen3_5_moe-cuda build target and help text entry. |
| examples/models/qwen3_5_moe/requirements.txt | Adds extra Python dependency for the Triton kernel path (FLA). |
| examples/models/qwen3_5_moe/README.md | Documents export/build/run workflow and model details. |
| examples/models/qwen3_5_moe/model.py | Implements export-friendly Qwen3.5 MoE model + HF checkpoint remapping/loading. |
| examples/models/qwen3_5_moe/export.py | Exports + lowers the model to CUDA ExecuTorch .pte/.ptd with optional quantization. |
| examples/models/qwen3_5_moe/main.cpp | Adds a minimal C++ CLI runner using TextLLMRunner + HF tokenizer. |
| examples/models/qwen3_5_moe/CMakePresets.json | Adds CMake presets/workflow for building the runner with CUDA. |
| examples/models/qwen3_5_moe/CMakeLists.txt | Adds standalone CMake build for the runner linked against ExecuTorch + CUDA delegate. |
| examples/models/qwen3_5_moe/__init__.py | Marks the directory as a Python package for imports. |


Comment thread .github/workflows/cuda.yml Outdated
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="

# Build Qwen3.5 MoE runner
make qwen3_5_moe-cuda
Contributor Author

Will add e2e (export and run the model) later. Currently, export takes a long time; want to optimize it first.

return model, config


def _quantize(model, config, args):
Contributor

can this be put somewhere generic. It seems non model specific

Contributor Author

We're already using quantize_model in extension/llm/export/quantize.py. Agree it's generic in principle. Currently only this model needs it (35B params don't fit on GPU at once). Happy to extract it into extension/llm/export/quantize.py if/when a second model needs the pattern.

@@ -0,0 +1,835 @@
"""
Contributor

Ideally we should merge these with llama_transformer.py Do you want to punt that to a later PR?

Contributor Author

Yes, ideally. There are so many moving parts to enable MoE. Once we find a common pattern, we'll "upstream" to llama_transformer.py


# Split via slicing (torch.split produces split_copy which lacks AOTI fallback)
kd = self.key_dim
q = qkv_conv[..., :kd].reshape(B, T, self.num_k_heads, self.head_k_dim)
Contributor

this slicing is sort of expensive after functionalization. Probably fine for AOTI based flows since they try and lower before functionalization iirc, but mlx lowering would have issues.
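A quick numerical check of the slicing-vs-split tradeoff discussed here: the two forms are equivalent, and slicing is preferred in this PR because torch.split lowers to split_copy, which (per the code comment) lacks an AOTI fallback. Shapes are toy values, and `kd` stands in for `self.key_dim`:

```python
import torch

qkv = torch.randn(2, 5, 12)  # (B, T, fused qkv dim)
kd = 4  # stands in for self.key_dim

q_slice = qkv[..., :kd]                                        # PR's form
q_split = torch.split(qkv, [kd, qkv.shape[-1] - kd], dim=-1)[0]  # avoided form
assert torch.equal(q_slice, q_split)
```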

Comment thread examples/models/qwen3_5_moe/model.py Outdated

# Materialize remaining meta-device buffers (KV caches, conv/recurrent
# state, causal masks, RoPE inv_freq) as zeros on CPU
for fqn, buf in list(model.named_buffers()):
Contributor

@JacobSzwejbka JacobSzwejbka Mar 19, 2026


why materialize them? If we are just tracing why is meta not ok?

Contributor Author

You're right, buffers aren't quantized so they don't need to be materialized during loading. Moved materialization out of from_hf_checkpoint (buffers stay on meta) and into export.py right before torch.export. Unfortunately torch.export(strict=True) can't handle meta buffers — dynamo fails to create FakeTensors from them when there are in-place updates (e.g. cache[:, input_pos] = h), so we still need to materialize before tracing.
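The materialization step described in that reply can be sketched as follows (the `Cache` module and buffer name are illustrative): right before torch.export, walk `named_buffers()` and replace any remaining meta-device buffer with CPU zeros of the same shape and dtype.

```python
import torch
import torch.nn as nn

class Cache(nn.Module):
    def __init__(self):
        super().__init__()
        # Buffer created on meta during checkpoint loading
        self.register_buffer("kv", torch.empty(4, 8, device="meta"))

model = Cache()
for fqn, buf in list(model.named_buffers()):
    if buf.is_meta:
        # Walk to the owning submodule, then re-register as real zeros
        mod = model
        *path, name = fqn.split(".")
        for part in path:
            mod = getattr(mod, part)
        mod.register_buffer(name, torch.zeros(buf.shape, dtype=buf.dtype))
```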

Comment thread examples/models/qwen3_5_moe/README.md Outdated
| `--max-seq-len` | `4096` | KV cache length |
| `--qlinear` | (none) | Linear layer quantization: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear-group-size` | `32` | Group size for linear quantization |
| `--qlinear-packing-format` | (none) | Packing format for 4w: `tile_packed_to_4d` |
Contributor

Why is the packing format a cli arg?

Contributor Author

Removed the CLI arg.

| `--model-dir` | (required) | HuggingFace model directory with `config.json` + safetensors |
| `--output-dir` | `./qwen35_moe_exports` | Output directory |
| `--max-seq-len` | `4096` | KV cache length |
| `--qlinear` | (none) | Linear layer quantization: `4w`, `8w`, `8da4w`, `8da8w` |
Contributor

nit: can we combine this and the group size into one arg, e.g. `4w,32`?

Contributor Author

I will leave as-is, since it is consistent with other CUDA models

executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)

# Needed for cpuinfo where it uses android specific log lib
if(ANDROID)
Contributor

Why do we have this for a CUDA binary?

endif()

# On Windows, copy required DLLs to the executable directory
if(MSVC AND EXECUTORCH_BUILD_CUDA)
Contributor

Suggested change (assert EXECUTORCH_BUILD_CUDA==1 above):
-if(MSVC AND EXECUTORCH_BUILD_CUDA)
+if(MSVC)

Contributor Author

find_package(CUDAToolkit REQUIRED) already handles it

`model.pte` file with int4 quantization. At inference time, the C++ runner
loads the `.pte`, `.ptd`, and a HuggingFace tokenizer, then generates text.

## Prerequisites
Contributor

Add FLA?


Export produces a `model.pte` and `aoti_cuda_blob.ptd` containing the
compiled CUDA kernels and quantized weights. Int4 quantization is
recommended — the model is too large to fit in VRAM at bf16.
Contributor

for which GPU?
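Back-of-envelope weight-only arithmetic behind the int4 recommendation (this ignores activations, KV cache, and quantization scale overhead; the follow-up commit pins the README to an A100 80GB):

```python
params = 35e9  # 35B parameters

bf16_gib = params * 2 / 2**30    # 2 bytes per param ≈ 65 GiB
int4_gib = params * 0.5 / 2**30  # 0.5 bytes per param ≈ 16 GiB

# bf16 weights alone nearly fill an 80 GB card before any runtime
# state; int4 leaves comfortable headroom.
assert bf16_gib > 60 and int4_gib < 20
```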

Memory-efficient loading using meta-device construction + lazy
safetensors shard-by-shard loading + assign=True state dict loading,
following the voxtral_realtime pattern. Peak CPU memory during loading
is ~1x model size instead of ~3x.

Expert weights are structured as grouped nn.Linear modules (16 groups
of 16 experts each) so quantize_model_() handles them automatically.
Layer-by-layer quantization on CUDA avoids loading the full bf16 model
onto GPU at once.

Includes C++ runner using the shared TextLLMRunner, Makefile target,
and CMake presets.

Reference implementations:
- https://github.com/mergennachin/nano_qwen35_moe/
- vLLM: vllm/model_executor/models/qwen3_5.py
]
)

def forward(self, x, expert_indices):
Contributor

Do you know how this is lowered? If I were to guess, this is where the majority of the perf is to be gained. You should be able to find fused-MoE kernels that leverage grouped GEMM kernels.

Contributor Author

Yes, that's next step

# Untie lm_head/embedding so they can be quantized independently:
# embedding uses index lookup (8w), lm_head uses matmul (4w).
if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr():
model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone())
Contributor

instead of untying, what if we run lm_head as int8? Memory should be better, but int8 is also slower, just curious.

Comment thread examples/models/qwen3_5_moe/export.py Outdated
qlinear_group_size=args.qlinear_group_size,
qlinear_packing_format=args.qlinear_packing_format,
)
layer.to(device="cpu")
Contributor

I guess this is needed for us to fit in vram, but does it slow down lowering?

Contributor Author

This doesn't slow down the lowering in my profiling.

Comment on lines +152 to +153
exported = exported.run_decompositions(
{torch.ops.aten.conv1d.default: conv1d_to_conv2d}
Contributor

Should this pass be folded inside the partitioner?

config.temperature = FLAGS_temperature;
config.max_new_tokens = FLAGS_max_new_tokens;

auto error = runner->generate(FLAGS_prompt.c_str(), config);
Contributor

Do you want to add basic benchmarking like llama.cpp::llama_bench?

within tinygemm int4 packing limits. 256 experts / 16 = 16 groups, giving
32 matmul nodes per layer instead of 768 with per-expert linears.

Forward pass: compute all groups → cat → gather top-k → SwiGLU → compute
Contributor

This has redundant compute, wonder how much we can speed up if we don't compute all the experts and throw them away later.
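A toy version of the "compute all groups → cat → gather top-k" flow under discussion (sizes shrunk, names illustrative): every expert runs on every token and the router's top-k only selects rows afterwards, which is export-friendly but does roughly E/K times more expert FFN work than a routed/grouped-GEMM kernel would.

```python
import torch

T, H, F, E, K = 3, 8, 4, 6, 2  # tokens, hidden, ffn, experts, top-k
x = torch.randn(T, H)
w = torch.randn(E, F, H)       # one weight per expert (up-projection only)

all_out = torch.einsum("th,efh->tef", x, w)              # all experts, all tokens
topk = torch.topk(torch.randn(T, E), K, dim=-1).indices  # router stand-in
picked = torch.gather(all_out, 1, topk[..., None].expand(T, K, F))
```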

Contributor

@digantdesai digantdesai left a comment


Left a few comments, but if it's functional then we should merge this and improve later. Thanks @mergennachin.

Copilot AI review requested due to automatic review settings March 20, 2026 18:06
Contributor

Copilot AI left a comment


Pull request overview

Adds a new ExecuTorch example for exporting and running the Qwen3.5 MoE (35B-A3B) text model on the CUDA backend, including a memory-efficient safetensors loader, a C++ runner using the shared TextLLMRunner, and build/CI wiring.

Changes:

  • Introduces a self-contained Qwen3.5 MoE model implementation + safetensors checkpoint remapping/loading.
  • Adds a CUDA-only export pipeline with optional int4/int8 quantization and CUDA backend lowering.
  • Adds C++ runner + CMake presets + Makefile target, and builds the runner in CUDA CI.

Reviewed changes

Copilot reviewed 10 out of 11 changed files in this pull request and generated 7 comments.

Show a summary per file

| File | Description |
| --- | --- |
| Makefile | Adds qwen3_5_moe-cuda build target and help text entry. |
| examples/models/qwen3_5_moe/requirements.txt | Adds flash-linear-attention dependency pin. |
| examples/models/qwen3_5_moe/README.md | Documents export/build/run workflow for the new model. |
| examples/models/qwen3_5_moe/model.py | Implements export-friendly Qwen3.5 MoE model + safetensors loading/remapping. |
| examples/models/qwen3_5_moe/model.md | Architecture/design notes for the implementation and export strategy. |
| examples/models/qwen3_5_moe/main.cpp | Adds a minimal TextLLMRunner-based C++ runner. |
| examples/models/qwen3_5_moe/export.py | Adds export + (int4/int8) quantization + CUDA lowering pipeline. |
| examples/models/qwen3_5_moe/CMakePresets.json | Adds presets to configure/build the runner with CUDA. |
| examples/models/qwen3_5_moe/CMakeLists.txt | Adds runner build definition and CUDA backend linkage. |
| .github/workflows/cuda.yml | Builds the new runner in the CUDA CI job. |


Comment thread examples/models/qwen3_5_moe/main.cpp Outdated
Comment on lines +20 to +46
DEFINE_string(
data_path,
"",
"Comma-separated data files (.ptd) for CUDA backend.");
DEFINE_string(tokenizer_path, "", "HuggingFace tokenizer.json path.");
DEFINE_string(prompt, "Hello", "Prompt text.");
DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");

namespace llm = ::executorch::extension::llm;

int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

if (FLAGS_model_path.empty()) {
ET_LOG(Error, "Must specify --model_path");
return 1;
}
if (FLAGS_tokenizer_path.empty()) {
ET_LOG(Error, "Must specify --tokenizer_path");
return 1;
}

std::vector<std::string> data_files;
if (!FLAGS_data_path.empty()) {
data_files.push_back(FLAGS_data_path);
}

Copilot AI Mar 20, 2026


--data_path is documented as “Comma-separated data files”, but the runner currently pushes the entire string as a single entry. This breaks multi-.ptd usage (and is inconsistent with the flag description). Split FLAGS_data_path on commas (and trim whitespace) before passing data_files to create_text_llm_runner().

Comment on lines +38 to +43
if args.qlinear or args.qembedding:
_quantize(model, config, args)
else:
model.to(dtype=torch.bfloat16)

return model, config

Copilot AI Mar 20, 2026


Quantization runs before _materialize_buffers(), while the model still contains meta-device buffers from from_hf_checkpoint(). As a result, _to_device_skip_meta() skips moving key submodules (e.g., attention/GatedDeltaNet) to CUDA because they have meta buffers, so quantize_model_() may run on CPU weights or fail to apply CUDA packing. Materialize/replace meta buffers before the layer-by-layer CUDA quantization (or update _to_device_skip_meta() to move parameters even when a module has meta buffers).

# Dynamic shapes
example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
example_input_pos = torch.tensor([0, 1], dtype=torch.long)
seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)

Copilot AI Mar 20, 2026


The dynamic shape upper bound uses max=config.max_seq_len - 1, which caps exported sequence length at one token less than the KV cache size (e.g., 4095 for a 4096 cache). Since max_seq_len is the intended context length, this should likely be max=config.max_seq_len to avoid needlessly reducing supported prompt length.

Suggested change:
-seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
+seq_dim = Dim("seq_len", min=1, max=config.max_seq_len)
Comment on lines +270 to +271
k = k.repeat_interleave(self.n_kv_groups, dim=1)
v = v.repeat_interleave(self.n_kv_groups, dim=1)

Copilot AI Mar 20, 2026


repeat_interleave() on the full KV cache creates a new (B, n_heads, max_seq_len, head_dim) tensor every forward, which is a large allocation and can dominate decode-time memory/bandwidth. Consider using an expand-based view (e.g., add a group dimension then reshape) to broadcast K/V to n_heads without copying, if SDPA accepts the resulting strides.

Suggested change:
-k = k.repeat_interleave(self.n_kv_groups, dim=1)
-v = v.repeat_interleave(self.n_kv_groups, dim=1)
+B_kv, n_kv_heads, S_kv, D_kv = k.shape
+# Broadcast K/V across groups without materializing full repeats
+k = k.unsqueeze(2).expand(B_kv, n_kv_heads, self.n_kv_groups, S_kv, D_kv)
+v = v.unsqueeze(2).expand(B_kv, n_kv_heads, self.n_kv_groups, S_kv, D_kv)
+k = k.reshape(B_kv, n_kv_heads * self.n_kv_groups, S_kv, D_kv)
+v = v.reshape(B_kv, n_kv_heads * self.n_kv_groups, S_kv, D_kv)
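A quick check that the expand-based broadcast suggested above matches repeat_interleave numerically for GQA K/V replication (toy sizes):

```python
import torch

B, n_kv, groups, S, D = 1, 2, 3, 4, 5
k = torch.randn(B, n_kv, S, D)

repeated = k.repeat_interleave(groups, dim=1)  # copies each head `groups` times
expanded = (
    k.unsqueeze(2)                       # add a group dim as a view
    .expand(B, n_kv, groups, S, D)       # broadcast, no copy yet
    .reshape(B, n_kv * groups, S, D)     # same layout as repeat_interleave
)
assert torch.equal(repeated, expanded)
```

Note the final reshape still materializes a copy for a non-contiguous view; the actual saving only lands if SDPA accepts the expanded strides directly, as the comment hedges.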
Comment thread examples/models/qwen3_5_moe/model.py Outdated
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

# FLA Triton kernel
state = self.recurrent_state[:B].clone()

Copilot AI Mar 20, 2026


state = self.recurrent_state[:B].clone() allocates a full copy of the recurrent state on every forward call. The chunk_gated_delta_rule op does not mutate its initial_state argument (it returns final_state separately), so you can pass self.recurrent_state[:B] directly and avoid the per-step clone.

Suggested change:
-state = self.recurrent_state[:B].clone()
+state = self.recurrent_state[:B]
Comment on lines +60 to +61
target_link_options_gc_sections(qwen3_5_moe_runner)
target_link_options(qwen3_5_moe_runner PRIVATE "LINKER:-s")

Copilot AI Mar 20, 2026


This target unconditionally adds the linker option LINKER:-s when not Debug. That flag is not valid under MSVC, but the presets explicitly allow Windows. Guard LINKER:-s (and potentially target_link_options_gc_sections) behind if(NOT MSVC)/platform checks (similar to voxtral_realtime), and consider adding the post-build DLL copy for aoti_cuda_shims on MSVC so the runner can execute on Windows.

Suggested change:
-target_link_options_gc_sections(qwen3_5_moe_runner)
-target_link_options(qwen3_5_moe_runner PRIVATE "LINKER:-s")
+if(NOT MSVC)
+  target_link_options_gc_sections(qwen3_5_moe_runner)
+  target_link_options(qwen3_5_moe_runner PRIVATE "LINKER:-s")
+endif()
Comment thread .github/workflows/cuda.yml Outdated
Comment on lines +138 to +139
# Build Qwen3.5 MoE runner
make qwen3_5_moe-cuda

Copilot AI Mar 20, 2026


make qwen3_5_moe-cuda re-runs cmake --workflow --preset llm-release-cuda (per the Makefile target), but this job already built ExecuTorch with CUDA just above. This duplicates a full configure/build/install step and will add significant CI time; consider invoking only the runner build workflow (cd examples/models/qwen3_5_moe && cmake --workflow --preset qwen3-5-moe-cuda) here, or adjusting the Makefile target to skip rebuilding ExecuTorch when cmake-out is already populated.

Suggested change:
-# Build Qwen3.5 MoE runner
-make qwen3_5_moe-cuda
+# Build Qwen3.5 MoE runner (without rebuilding ExecuTorch)
+pushd examples/models/qwen3_5_moe
+cmake --workflow --preset qwen3-5-moe-cuda
+popd
- Defer buffer materialization from model loading to export time
  (from_hf_checkpoint now returns model with meta buffers, materialized
  in export.py before torch.export)
- Use _to_device_skip_meta() to move only non-meta submodules to CUDA
  during quantization, correctly handling tensor subclasses
- Remove --qlinear-packing-format CLI arg, auto-set tile_packed_to_4d
  for 4w quantization
- Remove Android, MSVC, and Apple platform guards from CMakeLists.txt
  (CUDA-on-Linux only runner)
- Add GPU spec (A100 80GB) and pip install requirements.txt to README
@mergennachin mergennachin merged commit 36c22e9 into main Mar 20, 2026
167 of 173 checks passed
@mergennachin mergennachin deleted the mnachin/qwen3_5_moe branch March 20, 2026 19:33