2323
2424import importlib .metadata
2525import itertools
26+ import pickle
27+ import warnings
2628from typing import (Callable , Dict , Iterable , List , Literal , Optional , Tuple ,
2729 Union )
2830
5052# windows (https://github.com/AdaptiveMotorControlLab/CEBRA/pull/281#issuecomment-3764185072)
5153# on build (windows-latest, torch 2.6.0, python 3.12, latest sklearn)
5254CEBRA_LOAD_SAFE_GLOBALS = [
53- cebra .data .Offset , torch .torch_version .TorchVersion , np .dtype ,
54- np .dtypes .Int32DType , np .dtypes .Float64DType , np .dtypes .Int64DType
55+ cebra .data .Offset ,
56+ torch .torch_version .TorchVersion ,
57+ np .dtype ,
58+ np .dtypes .Int32DType ,
59+ np .dtypes .Int64DType ,
60+ np .dtypes .Float32DType ,
61+ np .dtypes .Float64DType ,
5562]
5663
5764
@@ -62,20 +69,22 @@ def check_version(estimator):
6269 sklearn .__version__ ) < packaging .version .parse ("1.6.dev" )
6370
6471
65- def _safe_torch_load (filename , weights_only , ** kwargs ):
66- if weights_only is None :
67- if packaging .version .parse (
68- torch .__version__ ) >= packaging .version .parse ("2.6.0" ):
69- weights_only = True
70- else :
71- weights_only = False
72+ def _safe_torch_load (filename , weights_only = False , ** kwargs ):
73+ checkpoint = None
74+ legacy_mode = packaging .version .parse (
75+ torch .__version__ ) < packaging .version .parse ("2.6.0" )
7276
73- if not weights_only :
77+ if legacy_mode :
7478 checkpoint = torch .load (filename , weights_only = False , ** kwargs )
7579 else :
76- # NOTE(stes): This is only supported for torch 2.6+
7780 with torch .serialization .safe_globals (CEBRA_LOAD_SAFE_GLOBALS ):
78- checkpoint = torch .load (filename , weights_only = True , ** kwargs )
81+ checkpoint = torch .load (filename ,
82+ weights_only = weights_only ,
83+ ** kwargs )
84+
85+ if not isinstance (checkpoint , dict ):
86+ _check_type_checkpoint (checkpoint )
87+ checkpoint = checkpoint ._get_state_dict ()
7988
8089 return checkpoint
8190
@@ -315,8 +324,9 @@ def _require_arg(key):
315324
316325def _check_type_checkpoint (checkpoint ):
317326 if not isinstance (checkpoint , cebra .CEBRA ):
318- raise RuntimeError ("Model loaded from file is not compatible with "
319- "the current CEBRA version." )
327+ raise RuntimeError (
328+ "Model loaded from file is not compatible with "
329+ f"the current CEBRA version. Got: { type (checkpoint )} " )
320330 if not sklearn_utils .check_fitted (checkpoint ):
321331 raise ValueError (
322332 "CEBRA model is not fitted. Loading it is not supported." )
@@ -1336,6 +1346,26 @@ def _get_state(self):
13361346 }
13371347 return state
13381348
1349+ def _get_state_dict (self ):
1350+ backend = "sklearn"
1351+ return {
1352+ 'args' : self .get_params (),
1353+ 'state' : self ._get_state (),
1354+ 'state_dict' : self .solver_ .state_dict (),
1355+ 'metadata' : {
1356+ 'backend' :
1357+ backend ,
1358+ 'cebra_version' :
1359+ cebra .__version__ ,
1360+ 'torch_version' :
1361+ torch .__version__ ,
1362+ 'numpy_version' :
1363+ np .__version__ ,
1364+ 'sklearn_version' :
1365+ importlib .metadata .distribution ("scikit-learn" ).version
1366+ }
1367+ }
1368+
13391369 def save (self ,
13401370 filename : str ,
13411371 backend : Literal ["torch" , "sklearn" ] = "sklearn" ):
@@ -1384,28 +1414,16 @@ def save(self,
13841414 """
13851415 if sklearn_utils .check_fitted (self ):
13861416 if backend == "torch" :
1417+ warnings .warn (
1418+ "Saving with backend='torch' is deprecated and will be removed in a future version. "
1419+ "Please use backend='sklearn' instead." ,
1420+ DeprecationWarning ,
1421+ stacklevel = 2 ,
1422+ )
13871423 checkpoint = torch .save (self , filename )
13881424
13891425 elif backend == "sklearn" :
1390- checkpoint = torch .save (
1391- {
1392- 'args' : self .get_params (),
1393- 'state' : self ._get_state (),
1394- 'state_dict' : self .solver_ .state_dict (),
1395- 'metadata' : {
1396- 'backend' :
1397- backend ,
1398- 'cebra_version' :
1399- cebra .__version__ ,
1400- 'torch_version' :
1401- torch .__version__ ,
1402- 'numpy_version' :
1403- np .__version__ ,
1404- 'sklearn_version' :
1405- importlib .metadata .distribution ("scikit-learn"
1406- ).version
1407- }
1408- }, filename )
1426+ checkpoint = torch .save (self ._get_state_dict (), filename )
14091427 else :
14101428 raise NotImplementedError (f"Unsupported backend: { backend } " )
14111429 else :
@@ -1457,29 +1475,60 @@ def load(cls,
14571475 >>> tmp_file.unlink()
14581476 """
14591477 supported_backends = ["auto" , "sklearn" , "torch" ]
1478+
14601479 if backend not in supported_backends :
14611480 raise NotImplementedError (
14621481 f"Unsupported backend: '{ backend } '. Supported backends are: { ', ' .join (supported_backends )} "
14631482 )
14641483
1465- checkpoint = _safe_torch_load (filename , weights_only , ** kwargs )
1484+ if backend not in ["auto" , "sklearn" ]:
1485+ warnings .warn (
1486+ "From CEBRA version 0.6.1 onwards, the 'backend' parameter in cebra.CEBRA.load is deprecated and will be ignored; "
1487+ "the sklearn backend is now always used. Models saved with the torch backend can still be loaded." ,
1488+ category = DeprecationWarning ,
1489+ stacklevel = 2 ,
1490+ )
14661491
1467- if backend == "auto" :
1468- backend = "sklearn" if isinstance (checkpoint , dict ) else "torch"
1492+ backend = "sklearn"
1493+
1494+ # NOTE(stes): For maximum backwards compatibility, we allow to load legacy checkpoints. From 0.7.0 onwards,
1495+ # the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
1496+ # introduced in torch 2.6.0.
1497+ try :
1498+ checkpoint = _safe_torch_load (filename , ** kwargs )
1499+ except pickle .UnpicklingError as e :
1500+ if weights_only is not False :
1501+ if packaging .version .parse (
1502+ cebra .__version__ ) < packaging .version .parse ("0.7" ):
1503+ warnings .warn (
1504+ "Failed to unpickle checkpoint with weights_only=True. "
1505+ "Falling back to loading with weights_only=False. "
1506+ "This is unsafe and should only be done if you trust the source of the model file. "
1507+ "In the future, loading these checkpoints will only work if weights_only=False is explicitly passed." ,
1508+ category = UserWarning ,
1509+ stacklevel = 2 ,
1510+ )
1511+ else :
1512+ raise ValueError (
1513+ "Failed to unpickle checkpoint with weights_only=True. "
1514+ "This may be due to an incompatible model file format. "
1515+ "To attempt loading this checkpoint, please pass weights_only=False to CEBRA.load. "
1516+ "Example: CEBRA.load(filename, weights_only=False)."
1517+ ) from e
14691518
1470- if isinstance (checkpoint , dict ) and backend == "torch" :
1471- raise RuntimeError (
1472- "Cannot use 'torch' backend with a dictionary-based checkpoint. "
1473- "Please try a different backend." )
1474- if not isinstance (checkpoint , dict ) and backend == "sklearn" :
1519+ checkpoint = _safe_torch_load (filename ,
1520+ weights_only = False ,
1521+ ** kwargs )
1522+
1523+ if backend != "sklearn" :
1524+ raise ValueError (f"Unsupported backend: { backend } " )
1525+
1526+ if not isinstance (checkpoint , dict ):
14751527 raise RuntimeError (
14761528 "Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1477- "Please try a different backend." )
1529+ f "Please try a different backend. Got: { type ( checkpoint ) } " )
14781530
1479- if backend == "sklearn" :
1480- cebra_ = _load_cebra_with_sklearn_backend (checkpoint )
1481- else :
1482- cebra_ = _check_type_checkpoint (checkpoint )
1531+ cebra_ = _load_cebra_with_sklearn_backend (checkpoint )
14831532
14841533 n_features = cebra_ .n_features_
14851534 cebra_ .solver_ .n_features = ([
0 commit comments