diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index d669379b6a9..555d22f9e22 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -688,6 +688,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager: logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}") edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform( _get_source_transforms( + args=args, dtype_override=dtype_override, checkpoint=args.checkpoint, checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype), # type: ignore @@ -1243,6 +1244,7 @@ def _load_llama_model( def _get_source_transforms( # noqa + args: argparse.Namespace, dtype_override: DType, *, checkpoint: Optional[str] = None,