5151__all__ = [
5252 "ModeloptStateManager" ,
5353 "apply_mode" ,
54+ "load_modelopt_state" ,
5455 "modelopt_state" ,
5556 "restore" ,
5657 "restore_from_modelopt_state" ,
@@ -512,7 +513,29 @@ def save(model: nn.Module, f: str | os.PathLike | BinaryIO, **kwargs) -> None:
512513 torch .save (ckpt_dict , f , ** kwargs )
513514
514515
515- def restore_from_modelopt_state (model : ModelLike , modelopt_state : dict [str , Any ]) -> nn .Module :
516+ def load_modelopt_state (modelopt_state_path : str | os .PathLike , ** kwargs ) -> dict [str , Any ]:
517+ """Load the modelopt state from a file.
518+
519+ Args:
520+ modelopt_state_path: Target file location.
521+ **kwargs: additional args for ``torch.load()``.
522+
523+ Returns:
524+ A modelopt state dictionary describing the modifications to the model.
525+ """
526+ # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
527+ kwargs .setdefault ("weights_only" , False )
528+ kwargs .setdefault ("map_location" , "cpu" )
529+ # TODO: Add some validation to ensure the file is a valid modelopt state file.
530+ modelopt_state = torch .load (modelopt_state_path , ** kwargs )
531+ return modelopt_state
532+
533+
534+ def restore_from_modelopt_state (
535+ model : ModelLike ,
536+ modelopt_state : dict [str , Any ] | None = None ,
537+ modelopt_state_path : str | os .PathLike | None = None ,
538+ ) -> nn .Module :
516539 """Restore the model architecture from the modelopt state dictionary based on the user-provided model.
517540
518541 This method does not restore the model parameters such as weights, biases and quantization scales.
@@ -526,10 +549,7 @@ def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any]
526549 model = ... # Create the model-like object
527550
528551 # Restore the previously saved modelopt state followed by model weights
529- # Security NOTE: weights_only=False is used here on ModelOpt-generated state_dict, not on untrusted user input
530- mto.restore_from_modelopt_state(
531- model, torch.load("modelopt_state.pt", weights_only=False)
532- ) # Restore modelopt state
552+ mto.restore_from_modelopt_state(model, modelopt_state_path="modelopt_state.pt")
533553 model.load_state_dict(torch.load("model_weights.pt"), ...) # Load the model weights
534554
535555 If you want to restore the model weights and the modelopt state with saved scales, please use
@@ -543,11 +563,21 @@ def restore_from_modelopt_state(model: ModelLike, modelopt_state: dict[str, Any]
543563 modelopt_state: The modelopt state dict describing the modelopt modifications to the model. The
544564 ``modelopt_state`` can be generated via
545565 :meth:`mto.modelopt_state()<modelopt.torch.opt.conversion.modelopt_state>`.
566+ Cannot be used with modelopt_state_path.
567+ modelopt_state_path: The path to the modelopt state file.
568+ Cannot be used with modelopt_state.
546569
547570 Returns:
548571 A modified model architecture based on the restored modifications with the unmodified
549572 weights as stored in the provided ``model`` argument.
550573 """
574+ assert (modelopt_state is not None ) != (modelopt_state_path is not None ), (
575+ "Either modelopt_state or modelopt_state_path must be provided, but not both."
576+ )
577+ if modelopt_state_path is not None :
578+ modelopt_state = load_modelopt_state (modelopt_state_path )
579+ assert modelopt_state , "modelopt_state is required!"
580+
551581 # initialize ModelLikeModule if needed.
552582 model = model if isinstance (model , nn .Module ) else ModelLikeModule (model )
553583
0 commit comments