Skip to content

Commit 1523a74

Browse files
Move all insecure modelopt state loading to single function
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 08d35a7 commit 1523a74

File tree

8 files changed

+45
-27
lines changed

8 files changed

+45
-27
lines changed

docs/source/guides/2_save_load.rst

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,7 @@ Here is the example workflow of restoring the ModelOpt-modified model architectu
129129
model = ...
130130
131131
# Restore the model architecture using the saved `modelopt_state`
132-
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
133-
modelopt_state = torch.load("modelopt_state.pth", weights_only=False)
134-
model = mto.restore_from_modelopt_state(model, modelopt_state)
132+
model = mto.restore_from_modelopt_state(model, modelopt_state_path="modelopt_state.pth")
135133
136134
# Load the model weights separately after restoring the model architecture
137135
custom_method_to_load_model_weights(model)

examples/diffusers/distillation/distillation_trainer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -591,10 +591,9 @@ def _apply_modelopt_quantization(self) -> None:
591591
f"Resuming: restoring quantization architecture from "
592592
f"{modelopt_state_path} (weights loaded later by accelerator)"
593593
)
594-
# Security NOTE: weights_only=False is used on ModelOpt-generated state,
595-
# not on untrusted user input.
596-
state = torch.load(modelopt_state_path, weights_only=False, map_location="cpu")
597-
self._transformer = mto.restore_from_modelopt_state(self._transformer, state)
594+
self._transformer = mto.restore_from_modelopt_state(
595+
self._transformer, modelopt_state_path=modelopt_state_path
596+
)
598597
logger.info("Quantization architecture restored for resume")
599598
return
600599
else:

examples/llm_qat/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ torch.save(mto.modelopt_state(model), "modelopt_quantizer_states.pt")
8181

8282
# To resume training from a checkpoint or load the final QAT model for evaluation,
8383
# load the quantizer states before loading the model weights
84-
# mto.restore_from_modelopt_state(model, torch.load("modelopt_quantizer_states.pt", weights_only=False))
84+
# mto.restore_from_modelopt_state(model, modelopt_state_path="modelopt_quantizer_states.pt")
8585
# After loading the quantizer states, load the model weights
8686
# model.load_state_dict(state_dict_from_last_checkpoint)
8787

examples/llm_qat/export.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import warnings
1919
from pathlib import Path
2020

21-
import torch
2221
from transformers import AutoModelForCausalLM, AutoTokenizer
2322

2423
import modelopt.torch.opt as mto
@@ -51,8 +50,7 @@ def get_model(
5150

5251
# Restore modelopt state for LoRA models. For QAT/QAD models from_pretrained call handles this
5352
if hasattr(model, "peft_config"):
54-
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
55-
modelopt_state = torch.load(f"{ckpt_path}/modelopt_state_train.pth", weights_only=False)
53+
modelopt_state = mto.load_modelopt_state(f"{ckpt_path}/modelopt_state_train.pth")
5654
restore_from_modelopt_state(model, modelopt_state)
5755
print_rank_0("Restored modelopt state")
5856

modelopt/torch/opt/conversion.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
__all__ = [
5252
"ModeloptStateManager",
5353
"apply_mode",
54+
"load_modelopt_state",
5455
"modelopt_state",
5556
"restore",
5657
"restore_from_modelopt_state",
@@ -512,7 +513,29 @@ def save(model: nn.Module, f: str | os.PathLike | BinaryIO, **kwargs) -> None:
512513
torch.save(ckpt_dict, f, **kwargs)
513514

514515

515-
def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any]) -> nn.Module:
516+
def load_modelopt_state(modelopt_state_path: str | os.PathLike, **kwargs) -> dict[str, Any]:
517+
"""Load the modelopt state from a file.
518+
519+
Args:
520+
modelopt_state_path: Path to the saved modelopt state file to load from.
521+
**kwargs: additional args for ``torch.load()``.
522+
523+
Returns:
524+
A modelopt state dictionary describing the modifications to the model.
525+
"""
526+
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
527+
kwargs.setdefault("weights_only", False)
528+
kwargs.setdefault("map_location", "cpu")
529+
# TODO: Add some validation to ensure the file is a valid modelopt state file.
530+
modelopt_state = torch.load(modelopt_state_path, **kwargs)
531+
return modelopt_state
532+
533+
534+
def restore_from_modelopt_state(
535+
model: ModelLike,
536+
modelopt_state: dict[str, Any] | None = None,
537+
modelopt_state_path: str | os.PathLike | None = None,
538+
) -> nn.Module:
516539
"""Restore the model architecture from the modelopt state dictionary based on the user-provided model.
517540
518541
This method does not restore the model parameters such as weights, biases and quantization scales.
@@ -526,10 +549,7 @@ def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any]
526549
model = ... # Create the model-like object
527550
528551
# Restore the previously saved modelopt state followed by model weights
529-
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
530-
mto.restore_from_modelopt_state(
531-
model, torch.load("modelopt_state.pt", weights_only=False)
532-
) # Restore modelopt state
552+
mto.restore_from_modelopt_state(model, modelopt_state_path="modelopt_state.pt")
533553
model.load_state_dict(torch.load("model_weights.pt"), ...) # Load the model weights
534554
535555
If you want to restore the model weights and the modelopt state with saved scales, please use
@@ -543,11 +563,21 @@ def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any]
543563
modelopt_state: The modelopt state dict describing the modelopt modifications to the model. The
544564
``modelopt_state`` can be generated via
545565
:meth:`mto.modelopt_state()<modelopt.torch.opt.conversion.modelopt_state>`.
566+
Cannot be used with modelopt_state_path.
567+
modelopt_state_path: The path to the modelopt state file.
568+
Cannot be used with modelopt_state.
546569
547570
Returns:
548571
A modified model architecture based on the restored modifications with the unmodified
549572
weights as stored in the provided ``model`` argument.
550573
"""
574+
assert (modelopt_state is not None) != (modelopt_state_path is not None), (
575+
"Either modelopt_state or modelopt_state_path must be provided, but not both."
576+
)
577+
if modelopt_state_path is not None:
578+
modelopt_state = load_modelopt_state(modelopt_state_path)
579+
assert modelopt_state, "modelopt_state is required!"
580+
551581
# initialize ModelLikeModule if needed.
552582
model = model if isinstance(model, nn.Module) else ModelLikeModule(model)
553583

modelopt/torch/opt/plugins/huggingface.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,8 @@ def new_init_fn(self, *args, **kwargs):
7979
modelopt_state_path = _get_modelopt_state_path(model_path)
8080
_original__init__(self, *args, **kwargs)
8181
if os.path.isfile(modelopt_state_path):
82-
# Security NOTE: weights_only=False is used on ModelOpt-generated state_dict, not on untrusted user input
83-
modelopt_state = torch.load(modelopt_state_path, map_location="cpu", weights_only=False)
8482
with extra_context() if extra_context else nullcontext():
85-
restore_from_modelopt_state(self, modelopt_state)
83+
restore_from_modelopt_state(self, modelopt_state_path=modelopt_state_path)
8684

8785
print_rank_0(f"Restored ModelOpt state from {modelopt_state_path}")
8886

modelopt/torch/opt/plugins/peft.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,7 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs):
7272
assert adapter_name in self.peft_config, (
7373
f"ModelOpt modified model should have adapter_name={adapter_name} in peft_config"
7474
)
75-
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
76-
restore_from_modelopt_state(
77-
self, torch.load(modelopt_state_path, map_location="cpu", weights_only=False)
78-
)
75+
restore_from_modelopt_state(self, modelopt_state_path=modelopt_state_path)
7976

8077
outputs = self._modelopt_cache["load_adapter"](self, model_id, adapter_name, *args, **kwargs)
8178

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import modelopt.torch.opt as mto
2929
import modelopt.torch.quantization as mtq
3030
from modelopt.torch.distill.plugins.huggingface import KDTrainer
31-
from modelopt.torch.opt.conversion import restore_from_modelopt_state
3231
from modelopt.torch.opt.plugins import ModelOptHFTrainer
3332
from modelopt.torch.utils import print_rank_0
3433

@@ -233,10 +232,9 @@ def _save_modelopt_state_with_weights(self):
233232
print_rank_0(f"Saved modelopt state to {self._modelopt_state_path}")
234233

235234
def _restore_modelopt_state_with_weights(self):
236-
# Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
237-
modelopt_state = torch.load(self._modelopt_state_path, weights_only=False)
235+
modelopt_state = mto.load_modelopt_state(self._modelopt_state_path)
238236
modelopt_weights = modelopt_state.pop("modelopt_state_weights", None)
239-
restore_from_modelopt_state(self.model, modelopt_state)
237+
mto.restore_from_modelopt_state(self.model, modelopt_state)
240238
if modelopt_weights is not None:
241239
set_quantizer_state_dict(self.model, modelopt_weights)
242240
print_rank_0("Restored modelopt state with weights.")

0 commit comments

Comments
 (0)