Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions tests/models/qwen3/qwen_params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jax.numpy as jnp
import numpy as np
import safetensors.numpy as safe_np
from flax import nnx
from tunix.models.qwen3 import model as qwen3_model
from tunix.models.qwen3 import params as qwen3_params
from tunix.tests import lora_params_test_base
Expand Down Expand Up @@ -200,6 +201,84 @@ def create_checkpoint(self, model) -> str:
return self.base_checkpoint_dir



class Qwen3ModelConfigTest(absltest.TestCase):
  """Tests specific to Qwen3 configurations and architectural flags."""

  def _tiny_config(self, *, use_tied_embedding):
    """Return a minimal ModelConfig for fast model-instantiation tests.

    Only ``use_tied_embedding`` varies between callers; every other
    hyperparameter is deliberately tiny so ``Qwen3`` builds quickly.

    Args:
      use_tied_embedding: Whether the output projection shares weights with
        the input embedding table.

    Returns:
      A ``qwen3_model.ModelConfig`` instance.
    """
    return qwen3_model.ModelConfig(
        num_layers=2,
        vocab_size=256,
        embed_dim=64,
        hidden_dim=128,
        num_heads=2,
        head_dim=32,
        num_kv_heads=2,
        rope_theta=10000,
        norm_eps=1e-6,
        use_tied_embedding=use_tied_embedding,
    )

  def test_tied_embeddings_config_values(self):
    """Verify that smaller Qwen3 models correctly default to tied embeddings."""
    # 0.6B model should use tied embeddings.
    config_0p6b = qwen3_model.ModelConfig.qwen3_0p6b()
    self.assertTrue(
        config_0p6b.use_tied_embedding,
        "qwen3_0p6b config should have use_tied_embedding=True",
    )

    # 1.7B model should use tied embeddings.
    config_1p7b = qwen3_model.ModelConfig.qwen3_1p7b()
    self.assertTrue(
        config_1p7b.use_tied_embedding,
        "qwen3_1p7b config should have use_tied_embedding=True",
    )

    # 8B model typically does not use tied embeddings.
    config_8b = qwen3_model.ModelConfig.qwen3_8b()
    self.assertFalse(
        config_8b.use_tied_embedding,
        "qwen3_8b config should have use_tied_embedding=False",
    )

  def test_model_instantiation_with_tied_embeddings(self):
    """Verify that the Qwen3 model omits the lm_head when embeddings are tied."""
    rngs = nnx.Rngs(params=0)
    tied_model = qwen3_model.Qwen3(
        self._tiny_config(use_tied_embedding=True), rngs=rngs
    )

    # The model should decode using the embedder, so lm_head shouldn't exist.
    self.assertFalse(
        hasattr(tied_model, "lm_head"),
        "Model should not have a separate lm_head when use_tied_embedding is True",
    )

  def test_model_instantiation_without_tied_embeddings(self):
    """Verify that the Qwen3 model includes the lm_head when embeddings are not tied."""
    rngs = nnx.Rngs(params=0)
    untied_model = qwen3_model.Qwen3(
        self._tiny_config(use_tied_embedding=False), rngs=rngs
    )

    # The model should have a distinct lm_head layer.
    self.assertTrue(
        hasattr(untied_model, "lm_head"),
        "Model must have a separate lm_head when use_tied_embedding is False",
    )


if __name__ == "__main__":
# Check if running in Jupyter/IPython environment
if test_common.is_running_in_colab():
Expand Down
2 changes: 2 additions & 0 deletions tunix/models/qwen3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def qwen3_0p6b(cls): # qwen3-0.6B
num_kv_heads=8,
norm_eps=1e-06,
rope_theta=1_000_000,
use_tied_embedding=True,
)

qwen3_0p6b_base = qwen3_0p6b # qwen3-0.6B-base
Expand All @@ -129,6 +130,7 @@ def qwen3_1p7b(cls): # qwen3-1.7B
num_kv_heads=8,
norm_eps=1e-06,
rope_theta=1_000_000,
use_tied_embedding=True,
)

qwen3_1p7b_base = qwen3_1p7b # qwen3-1.7B-base
Expand Down
Loading