Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions examples/apple/coreml/scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ def parse_args() -> argparse.ArgumentParser:
parser.add_argument("--use_partitioner", action=argparse.BooleanOptionalAction)
parser.add_argument("--generate_etrecord", action=argparse.BooleanOptionalAction)
parser.add_argument("--save_processed_bytes", action=argparse.BooleanOptionalAction)
parser.add_argument(
"--checkpoint",
required=False,
default=None,
help="checkpoing for llama model",
)

parser.add_argument(
"--params",
required=False,
default=None,
help="params for llama model",
)


args = parser.parse_args()
return args
Expand Down Expand Up @@ -163,28 +177,40 @@ def generate_compile_specs_from_args(args):
f"Valid compute units are {valid_compute_units}."
)

model_config = {}
model_config["module_name"] = MODEL_NAME_TO_MODEL[args.model_name][0]
model_config["model_class_name"] = MODEL_NAME_TO_MODEL[args.model_name][1]

if args.model_name == "llama2":
if args.checkpoint:
model_config["checkpoint"] = args.checkpoint
if args.params:
model_config["params"] = args.params
model_config["use_kv_cache"] = True
model, example_inputs, _ = EagerModelFactory.create_model(
*MODEL_NAME_TO_MODEL[args.model_name]
**model_config
)

compile_specs = generate_compile_specs_from_args(args)
lowered_module = None

if args.use_partitioner:
model.eval()
exir_program_aten = torch.export.export(model, example_inputs)
edge_program_manager = exir.to_edge(exir_program_aten)
with torch.no_grad():
exir_program_aten = torch.export.export(model, example_inputs)
edge_program_manager = exir.to_edge(exir_program_aten, compile_config=_EDGE_COMPILE_CONFIG)
edge_copy = copy.deepcopy(edge_program_manager)
partitioner = CoreMLPartitioner(
skip_ops_for_coreml_delegation=None, compile_specs=compile_specs
)
delegated_program_manager = edge_program_manager.to_backend(partitioner)
exec_program = delegated_program_manager.to_executorch()
else:
lowered_module, edge_copy = lower_module_to_coreml(
module=model,
compile_specs=compile_specs,
)
with torch.no_grad():
lowered_module, edge_copy = lower_module_to_coreml(
module=model,
compile_specs=compile_specs,
)
exec_program = export_lowered_module_to_executorch_program(
lowered_module,
example_inputs,
Expand Down
Loading