Skip to content

Commit 373351e

Browse files
authored
Add compatibility fixes for pandas 3 and torch 2.10 (#289)
1 parent bc2d925 commit 373351e

File tree

6 files changed

+53
-2
lines changed

6 files changed

+53
-2
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
# We aim to support the versions on pytorch.org
2828
# as well as selected previous versions on
2929
# https://pytorch.org/get-started/previous-versions/
30-
torch-version: ["2.6.0", "2.9.1"]
30+
torch-version: ["2.6.0", "2.10.0"]
3131
sklearn-version: ["latest"]
3232
include:
3333
# windows test with standard config

cebra/integrations/deeplabcut.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ def load_data(self, pcutoff: float = 0.6) -> npt.NDArray:
169169
pred_xy = []
170170
for i, _ in enumerate(self.dlc_df.index):
171171
data = (self.dlc_df.iloc[i].loc[self.scorer].loc[
172-
self.keypoints_list].to_numpy().reshape(-1, len(dlc_df_coords)))
172+
self.keypoints_list].to_numpy().copy().reshape(
173+
-1, len(dlc_df_coords)))
173174

174175
# Handles nan values with interpolation
175176
if i > 0 and i < len(self.dlc_df) - 1:

cebra/integrations/sklearn/cebra.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,8 @@ def transform(self,
12531253

12541254
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
12551255

1256+
X = cebra_sklearn_dataset._ensure_writable(X)
1257+
12561258
if isinstance(X, np.ndarray):
12571259
X = torch.from_numpy(X)
12581260

cebra/integrations/sklearn/dataset.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#
2222
"""Datasets to be used as part of the sklearn framework."""
2323

24+
import traceback
25+
import warnings
2426
from typing import Iterable, Optional
2527

2628
import numpy as np
@@ -34,6 +36,28 @@
3436
import cebra.solver
3537

3638

39+
def _ensure_writable(array: npt.NDArray) -> npt.NDArray:
40+
if not array.flags.writeable:
41+
stack = traceback.extract_stack()[-5:-1]
42+
stack_str = ''.join(traceback.format_list(stack[-4:]))
43+
44+
warnings.warn(
45+
("You passed a non-writable Numpy array to CEBRA. Pytorch does currently "
46+
"not support non-writable tensors. As a result, CEBRA needs to copy the "
47+
"contents of the array, which might yield unnecessary memory overhead. "
48+
"Ideally, adapt the code such that the array you pass to CEBRA is writable "
49+
"to make your code memory efficient. "
50+
"You can find more context and the rationale for this fix here: "
51+
"https://github.com/AdaptiveMotorControlLab/CEBRA/pull/289."
52+
"\n\n"
53+
"Trace:\n" + stack_str),
54+
UserWarning,
55+
stacklevel=2,
56+
)
57+
array = array.copy()
58+
return array
59+
60+
3761
class SklearnDataset(cebra.data.SingleSessionDataset):
3862
"""Dataset for wrapping array-like input/index pairs.
3963
@@ -110,6 +134,7 @@ def _parse_data(self, X: npt.NDArray):
110134
# one sample is a conservative default here to ensure that sklearn tests
111135
# passes with the correct error messages.
112136
X = cebra_sklearn_utils.check_input_array(X, min_samples=2)
137+
X = _ensure_writable(X)
113138
self.neural = torch.from_numpy(X).float().to(self.device)
114139

115140
def _parse_labels(self, labels: Optional[tuple]):
@@ -143,6 +168,8 @@ def _parse_labels(self, labels: Optional[tuple]):
143168
f"or lists that can be converted to arrays, but got {type(y)}"
144169
)
145170

171+
y = _ensure_writable(y)
172+
146173
# Define the index as either continuous or discrete indices, depending
147174
# on the dtype in the index array.
148175
if cebra.helper._is_floating(y):

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ markers = [
3939
"cuda",
4040
]
4141
addopts = "--ignore=cebra/integrations/threejs --ignore=cebra/integrations/streamlit.py --ignore=cebra/datasets"
42+
# NOTE(stes): See https://github.com/AdaptiveMotorControlLab/CEBRA/pull/289.
43+
filterwarnings = [
44+
"error:The given NumPy array is not writable.*PyTorch does not support non-writable tensors:UserWarning",
45+
]
4246

4347

4448

tests/test_sklearn.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,3 +1544,20 @@ def test_last_incomplete_batch_smaller_than_offset():
15441544
model.fit(train.neural, train.continuous)
15451545

15461546
_ = model.transform(train.neural, batch_size=300)
1547+
1548+
1549+
def test_non_writable_array():
1550+
X = np.random.randn(100, 10)
1551+
y = np.random.randn(100, 2)
1552+
X.setflags(write=False)
1553+
y.setflags(write=False)
1554+
with pytest.raises(ValueError, match="assignment destination is read-only"):
1555+
X[:] = 0
1556+
y[:] = 0
1557+
1558+
cebra_model = cebra.CEBRA(max_iterations=2, batch_size=32, device="cpu")
1559+
1560+
cebra_model.fit(X, y)
1561+
embedding = cebra_model.transform(X)
1562+
assert isinstance(embedding, np.ndarray)
1563+
assert embedding.shape[0] == X.shape[0]

0 commit comments

Comments
 (0)