Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/models/automodel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,14 @@ def test_create_model_dynamically(
mesh = jax.sharding.Mesh(jax.devices(), ("devices",))
naming_info = naming.ModelNaming(model_name=model_name)
automodel.create_model_from_safe_tensors(
model_name, "file_dir", "model_config", mesh
model_name, "file_dir", "model_config", mesh, "dtype", "mode"
)
mock_create_fn.assert_called_once_with(
file_dir="file_dir", config="model_config", mesh=mesh
file_dir="file_dir",
config="model_config",
mesh=mesh,
dtype="dtype",
mode="mode",
)

if naming_info.model_family in ("gemma", "gemma1p1", "gemma2", "gemma3"):
Expand Down
18 changes: 14 additions & 4 deletions tests/models/safetensors_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from safetensors import numpy as stnp
from tunix.models import safetensors_loader
from tunix.tests import test_common
from tunix.utils import env_utils


def key_mapping(config):
Expand Down Expand Up @@ -86,11 +87,17 @@ def setUpClass(cls):
)

@parameterized.named_parameters(
dict(testcase_name='absolute_path', path_type='abs'),
dict(testcase_name='relative_path', path_type='rel'),
dict(testcase_name='relative_dot_path', path_type='rel_dot'),
*(([dict(testcase_name='opt_loader_enabled', mode='optimized')]
if not env_utils.is_internal_env() else []) + [
dict(testcase_name='absolute_path', path_type='abs'),
dict(testcase_name='relative_path', path_type='rel'),
dict(testcase_name='relative_dot_path', path_type='rel_dot'),
dict(testcase_name='opt_loader_disabled', mode='original'),
])
)
def test_load_and_create_model(self, path_type):
def test_load_and_create_model(
self, path_type='abs', mode='optimized'
):
try:
st_dir_abs = self.create_tempdir().full_path
except Exception: # pylint: disable=broad-except
Expand Down Expand Up @@ -119,6 +126,7 @@ def test_load_and_create_model(self, path_type):
self.model.config,
key_mapping,
dtype=jnp.float32,
mode=mode,
)
loaded_state = nnx.state(loaded_model)
jax.tree.map(
Expand All @@ -128,6 +136,8 @@ def test_load_and_create_model(self, path_type):
)

def test_load_and_create_model_from_gcs(self):
if env_utils.is_internal_env():
self.skipTest('GCS is not supported in GOOGLE_INTERNAL_PACKAGE_PATH')
try:
st_dir_abs = self.create_tempdir().full_path
except Exception: # pylint: disable=broad-except
Expand Down
18 changes: 16 additions & 2 deletions tunix/models/automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,12 @@ def download_model(


def create_model_from_safe_tensors(
model_name: str, file_dir: str, model_config: Any, mesh: jax.sharding.Mesh
model_name: str,
file_dir: str,
model_config: Any,
mesh: jax.sharding.Mesh,
dtype: jnp.dtype | None = None,
mode: str = "auto",
) -> Any:
"""Dynamically imports the correct module and calls `create_model_from_safe_tensors` based on the model_name.

Expand All @@ -321,6 +326,9 @@ def create_model_from_safe_tensors(
file_dir: Directory containing the safe tensors.
model_config: Model configuration object.
mesh: Mesh object for device layout.
dtype: Optional dtype to cast the loaded tensors to.
mode: The mode to use for loading the model. Options are ('auto',
'optimized', 'original').

Returns:
The result of the create_model_from_safe_tensors call.
Expand All @@ -347,7 +355,13 @@ def create_model_from_safe_tensors(
logging.info(
'Calling %s.create_model_from_safe_tensors', params_module.__name__
)
return create_fn(file_dir=file_dir, config=model_config, mesh=mesh)
return create_fn(
file_dir=file_dir,
config=model_config,
mesh=mesh,
dtype=dtype,
mode=mode,
)


class AutoModel:
Expand Down
2 changes: 2 additions & 0 deletions tunix/models/gemma/params_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def create_model_from_safe_tensors(
config: model_lib.ModelConfig,
mesh: jax.sharding.Mesh | None = None,
dtype: jnp.dtype | None = None,
mode: str = "auto",
) -> model_lib.Gemma:
v_ckpt = _peek_vocab_size_from_safetensors(file_dir)
if v_ckpt != config.num_embed:
Expand All @@ -184,4 +185,5 @@ def create_model_from_safe_tensors(
mesh=mesh,
preprocess_fn=_make_preprocess_fn(config),
dtype=dtype,
mode=mode,
)
2 changes: 2 additions & 0 deletions tunix/models/gemma3/params_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def create_model_from_safe_tensors(
config: model_lib.ModelConfig,
mesh: jax.sharding.Mesh | None = None,
dtype: jnp.dtype | None = None,
mode: str = "auto",
):
return safetensors_loader.load_and_create_model(
file_dir=file_dir,
Expand All @@ -319,4 +320,5 @@ def create_model_from_safe_tensors(
mesh=mesh,
preprocess_fn=_make_preprocess_fn(config),
dtype=dtype,
mode=mode,
)
2 changes: 2 additions & 0 deletions tunix/models/llama3/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def create_model_from_safe_tensors(
config: model_lib.ModelConfig,
mesh: jax.sharding.Mesh | None = None,
dtype: jnp.dtype | None = None,
mode: str = "auto",
) -> model_lib.Llama3:
"""Load tensors from the safetensors file and create a Llama3 model."""
return safetensors_loader.load_and_create_model(
Expand All @@ -92,4 +93,5 @@ def create_model_from_safe_tensors(
mesh=mesh,
preprocess_fn=None,
dtype=dtype,
mode=mode,
)
2 changes: 2 additions & 0 deletions tunix/models/qwen2/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def create_model_from_safe_tensors(
config: model_lib.ModelConfig,
mesh: jax.sharding.Mesh | None = None,
dtype: jnp.dtype | None = None,
mode: str = "auto",
) -> model_lib.Qwen2:
"""Load tensors from the safetensors file and create a Qwen2 model."""
return safetensors_loader.load_and_create_model(
Expand All @@ -96,4 +97,5 @@ def create_model_from_safe_tensors(
mesh=mesh,
preprocess_fn=None,
dtype=dtype,
mode=mode,
)
2 changes: 2 additions & 0 deletions tunix/models/qwen3/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def create_model_from_safe_tensors(
config: model_lib.ModelConfig,
mesh: jax.sharding.Mesh | None = None,
dtype: jnp.dtype | None = None,
mode: str = "auto",
) -> model_lib.Qwen3:
"""Load tensors from the safetensors file and create a Qwen3 model."""
return safetensors_loader.load_and_create_model(
Expand All @@ -126,6 +127,7 @@ def create_model_from_safe_tensors(
mesh=mesh,
preprocess_fn=_stack_experts,
dtype=dtype,
mode=mode,
)


Expand Down
Loading
Loading