Skip to content

Commit ba3812f

Browse files
authored
Improve robustness of cebra model loading (#292)
1 parent 373351e commit ba3812f

File tree

4 files changed

+216
-76
lines changed

4 files changed

+216
-76
lines changed

.github/workflows/build.yml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,28 @@ jobs:
2929
# https://pytorch.org/get-started/previous-versions/
3030
torch-version: ["2.6.0", "2.10.0"]
3131
sklearn-version: ["latest"]
32+
numpy-version: ["latest"]
33+
3234
include:
3335
# windows test with standard config
3436
- os: windows-latest
3537
torch-version: 2.6.0
3638
python-version: "3.12"
3739
sklearn-version: "latest"
40+
numpy-version: "latest"
3841

3942
# legacy sklearn (several API differences)
4043
- os: ubuntu-latest
4144
torch-version: 2.6.0
4245
python-version: "3.12"
4346
sklearn-version: "legacy"
47+
numpy-version: "latest"
48+
49+
- os: ubuntu-latest
50+
torch-version: 2.6.0
51+
python-version: "3.12"
52+
sklearn-version: "latest"
53+
numpy-version: "legacy"
4454

4555
# TODO(stes): latest torch and python
4656
# requires a PyTables release compatible with
@@ -55,6 +65,7 @@ jobs:
5565
torch-version: 2.4.0
5666
python-version: "3.10"
5767
sklearn-version: "legacy"
68+
numpy-version: "latest"
5869

5970
runs-on: ${{ matrix.os }}
6071

@@ -88,6 +99,11 @@ jobs:
8899
run: |
89100
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'
90101
102+
- name: Check numpy legacy version
103+
if: matrix.numpy-version == 'legacy'
104+
run: |
105+
pip install "numpy<2" '.[dev,datasets,integrations]'
106+
91107
- name: Run the formatter
92108
run: |
93109
make format

cebra/integrations/sklearn/cebra.py

Lines changed: 95 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
import importlib.metadata
2525
import itertools
26+
import pickle
27+
import warnings
2628
from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple,
2729
Union)
2830

@@ -50,8 +52,13 @@
5052
# windows (https://github.com/AdaptiveMotorControlLab/CEBRA/pull/281#issuecomment-3764185072)
5153
# on build (windows-latest, torch 2.6.0, python 3.12, latest sklearn)
5254
CEBRA_LOAD_SAFE_GLOBALS = [
53-
cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
54-
np.dtypes.Int32DType, np.dtypes.Float64DType, np.dtypes.Int64DType
55+
cebra.data.Offset,
56+
torch.torch_version.TorchVersion,
57+
np.dtype,
58+
np.dtypes.Int32DType,
59+
np.dtypes.Int64DType,
60+
np.dtypes.Float32DType,
61+
np.dtypes.Float64DType,
5562
]
5663

5764

@@ -62,20 +69,22 @@ def check_version(estimator):
6269
sklearn.__version__) < packaging.version.parse("1.6.dev")
6370

6471

65-
def _safe_torch_load(filename, weights_only, **kwargs):
66-
if weights_only is None:
67-
if packaging.version.parse(
68-
torch.__version__) >= packaging.version.parse("2.6.0"):
69-
weights_only = True
70-
else:
71-
weights_only = False
72+
def _safe_torch_load(filename, weights_only=False, **kwargs):
73+
checkpoint = None
74+
legacy_mode = packaging.version.parse(
75+
torch.__version__) < packaging.version.parse("2.6.0")
7276

73-
if not weights_only:
77+
if legacy_mode:
7478
checkpoint = torch.load(filename, weights_only=False, **kwargs)
7579
else:
76-
# NOTE(stes): This is only supported for torch 2.6+
7780
with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
78-
checkpoint = torch.load(filename, weights_only=True, **kwargs)
81+
checkpoint = torch.load(filename,
82+
weights_only=weights_only,
83+
**kwargs)
84+
85+
if not isinstance(checkpoint, dict):
86+
_check_type_checkpoint(checkpoint)
87+
checkpoint = checkpoint._get_state_dict()
7988

8089
return checkpoint
8190

@@ -315,8 +324,9 @@ def _require_arg(key):
315324

316325
def _check_type_checkpoint(checkpoint):
317326
if not isinstance(checkpoint, cebra.CEBRA):
318-
raise RuntimeError("Model loaded from file is not compatible with "
319-
"the current CEBRA version.")
327+
raise RuntimeError(
328+
"Model loaded from file is not compatible with "
329+
f"the current CEBRA version. Got: {type(checkpoint)}")
320330
if not sklearn_utils.check_fitted(checkpoint):
321331
raise ValueError(
322332
"CEBRA model is not fitted. Loading it is not supported.")
@@ -1336,6 +1346,26 @@ def _get_state(self):
13361346
}
13371347
return state
13381348

1349+
def _get_state_dict(self):
1350+
backend = "sklearn"
1351+
return {
1352+
'args': self.get_params(),
1353+
'state': self._get_state(),
1354+
'state_dict': self.solver_.state_dict(),
1355+
'metadata': {
1356+
'backend':
1357+
backend,
1358+
'cebra_version':
1359+
cebra.__version__,
1360+
'torch_version':
1361+
torch.__version__,
1362+
'numpy_version':
1363+
np.__version__,
1364+
'sklearn_version':
1365+
importlib.metadata.distribution("scikit-learn").version
1366+
}
1367+
}
1368+
13391369
def save(self,
13401370
filename: str,
13411371
backend: Literal["torch", "sklearn"] = "sklearn"):
@@ -1384,28 +1414,16 @@ def save(self,
13841414
"""
13851415
if sklearn_utils.check_fitted(self):
13861416
if backend == "torch":
1417+
warnings.warn(
1418+
"Saving with backend='torch' is deprecated and will be removed in a future version. "
1419+
"Please use backend='sklearn' instead.",
1420+
DeprecationWarning,
1421+
stacklevel=2,
1422+
)
13871423
checkpoint = torch.save(self, filename)
13881424

13891425
elif backend == "sklearn":
1390-
checkpoint = torch.save(
1391-
{
1392-
'args': self.get_params(),
1393-
'state': self._get_state(),
1394-
'state_dict': self.solver_.state_dict(),
1395-
'metadata': {
1396-
'backend':
1397-
backend,
1398-
'cebra_version':
1399-
cebra.__version__,
1400-
'torch_version':
1401-
torch.__version__,
1402-
'numpy_version':
1403-
np.__version__,
1404-
'sklearn_version':
1405-
importlib.metadata.distribution("scikit-learn"
1406-
).version
1407-
}
1408-
}, filename)
1426+
checkpoint = torch.save(self._get_state_dict(), filename)
14091427
else:
14101428
raise NotImplementedError(f"Unsupported backend: {backend}")
14111429
else:
@@ -1457,29 +1475,60 @@ def load(cls,
14571475
>>> tmp_file.unlink()
14581476
"""
14591477
supported_backends = ["auto", "sklearn", "torch"]
1478+
14601479
if backend not in supported_backends:
14611480
raise NotImplementedError(
14621481
f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
14631482
)
14641483

1465-
checkpoint = _safe_torch_load(filename, weights_only, **kwargs)
1484+
if backend not in ["auto", "sklearn"]:
1485+
warnings.warn(
1486+
"From CEBRA version 0.6.1 onwards, the 'backend' parameter in cebra.CEBRA.load is deprecated and will be ignored; "
1487+
"the sklearn backend is now always used. Models saved with the torch backend can still be loaded.",
1488+
category=DeprecationWarning,
1489+
stacklevel=2,
1490+
)
14661491

1467-
if backend == "auto":
1468-
backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
1492+
backend = "sklearn"
1493+
1494+
# NOTE(stes): For maximum backwards compatibility, we allow to load legacy checkpoints. From 0.7.0 onwards,
1495+
# the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
1496+
# introduced in torch 2.6.0.
1497+
try:
1498+
checkpoint = _safe_torch_load(filename, **kwargs)
1499+
except pickle.UnpicklingError as e:
1500+
if weights_only is not False:
1501+
if packaging.version.parse(
1502+
cebra.__version__) < packaging.version.parse("0.7"):
1503+
warnings.warn(
1504+
"Failed to unpickle checkpoint with weights_only=True. "
1505+
"Falling back to loading with weights_only=False. "
1506+
"This is unsafe and should only be done if you trust the source of the model file. "
1507+
"In the future, loading these checkpoints will only work if weights_only=False is explicitly passed.",
1508+
category=UserWarning,
1509+
stacklevel=2,
1510+
)
1511+
else:
1512+
raise ValueError(
1513+
"Failed to unpickle checkpoint with weights_only=True. "
1514+
"This may be due to an incompatible model file format. "
1515+
"To attempt loading this checkpoint, please pass weights_only=False to CEBRA.load. "
1516+
"Example: CEBRA.load(filename, weights_only=False)."
1517+
) from e
14691518

1470-
if isinstance(checkpoint, dict) and backend == "torch":
1471-
raise RuntimeError(
1472-
"Cannot use 'torch' backend with a dictionary-based checkpoint. "
1473-
"Please try a different backend.")
1474-
if not isinstance(checkpoint, dict) and backend == "sklearn":
1519+
checkpoint = _safe_torch_load(filename,
1520+
weights_only=False,
1521+
**kwargs)
1522+
1523+
if backend != "sklearn":
1524+
raise ValueError(f"Unsupported backend: {backend}")
1525+
1526+
if not isinstance(checkpoint, dict):
14751527
raise RuntimeError(
14761528
"Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
1477-
"Please try a different backend.")
1529+
f"Please try a different backend. Got: {type(checkpoint)}")
14781530

1479-
if backend == "sklearn":
1480-
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
1481-
else:
1482-
cebra_ = _check_type_checkpoint(checkpoint)
1531+
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
14831532

14841533
n_features = cebra_.n_features_
14851534
cebra_.solver_.n_features = ([

cebra/registry.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from __future__ import annotations
4747

4848
import fnmatch
49+
import functools
4950
import itertools
5051
import sys
5152
import textwrap
@@ -214,14 +215,29 @@ def _zip_dict(d):
214215
yield dict(zip(keys, combination))
215216

216217
def _create_class(cls, **default_kwargs):
218+
class_name = pattern.format(**default_kwargs)
217219

218-
@register(pattern.format(**default_kwargs), base=pattern)
220+
@register(class_name, base=pattern)
219221
class _ParametrizedClass(cls):
220222

221223
def __init__(self, *args, **kwargs):
222224
default_kwargs.update(kwargs)
223225
super().__init__(*args, **default_kwargs)
224226

227+
# Make the class pickleable by copying metadata from the base class
228+
# and registering it in the module namespace
229+
functools.update_wrapper(_ParametrizedClass, cls, updated=[])
230+
231+
# Set a unique qualname so pickle finds this class, not the base class
232+
unique_name = f"{cls.__qualname__}_{class_name.replace('-', '_')}"
233+
_ParametrizedClass.__qualname__ = unique_name
234+
_ParametrizedClass.__name__ = unique_name
235+
236+
# Register in module namespace so pickle can find it via getattr
237+
parent_module = sys.modules.get(cls.__module__)
238+
if parent_module is not None:
239+
setattr(parent_module, unique_name, _ParametrizedClass)
240+
225241
def _parametrize(cls):
226242
for _default_kwargs in kwargs:
227243
_create_class(cls, **_default_kwargs)

0 commit comments

Comments
 (0)