Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
3dc5729
Test IBL extractors tests failing for PI update
alejoe91 Dec 29, 2025
d1a0532
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 6, 2026
79ca022
original commit - good times
m-beau Jan 6, 2026
22501da
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2026
7ca3d35
good times - progress
m-beau Jan 7, 2026
e3f31bf
merge
m-beau Jan 7, 2026
ab0e8dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2026
fe0aaf1
Merge remote-tracking branch 'alessio/select_sorting_periods' into go…
m-beau Jan 7, 2026
7279b67
wip
alejoe91 Jan 7, 2026
13ebb8f
Merge remote-tracking branch 'alessio/select_sorting_periods' into go…
m-beau Jan 7, 2026
1962f21
Fix test for base sorting and propagate to basevector extension
alejoe91 Jan 7, 2026
7fbe160
wip
m-beau Jan 7, 2026
5645ee6
Merge branch 'select_sorting_periods' of https://github.com/alejoe91/…
m-beau Jan 7, 2026
528c82b
Fix tests in quailty metrics
alejoe91 Jan 8, 2026
fccdbe3
finished implementing good periods
m-beau Jan 8, 2026
7adab75
Merge branch 'select_sorting_periods' into goodtimes
alejoe91 Jan 8, 2026
f36c7fc
Some fixes
alejoe91 Jan 8, 2026
775dda7
Fix retrieval of spikevector features
alejoe91 Jan 8, 2026
6f02b7f
Merge branch 'select_sorting_periods' into goodtimes
alejoe91 Jan 8, 2026
15df754
Fix tests, saving and loading
alejoe91 Jan 8, 2026
40e3417
started working on get_data method for good periods
m-beau Jan 8, 2026
cdf7846
Solve conflicts, still wip
alejoe91 Jan 8, 2026
81d745e
done refactoring self.data serializable format and get_data method
m-beau Jan 8, 2026
93a53ca
credits
m-beau Jan 8, 2026
493d215
Make good_periods blazing fast!
alejoe91 Jan 9, 2026
a1fb167
Add credits
alejoe91 Jan 9, 2026
e8518b0
Solve conflicts
alejoe91 Jan 9, 2026
f6752ac
Fix tests
alejoe91 Jan 9, 2026
a251826
oups
alejoe91 Jan 9, 2026
983d255
Sam's review + implement select/merge/split data
alejoe91 Jan 9, 2026
c5dbb93
Rename to valid_unit_periods and wip widgets
alejoe91 Jan 12, 2026
f382f89
Fix imports
alejoe91 Jan 12, 2026
ad50845
Add widget and extend params
alejoe91 Jan 12, 2026
bb46f27
Update src/spikeinterface/core/sorting_tools.py
alejoe91 Jan 13, 2026
121a0b1
Apply suggestion from @chrishalcrow
alejoe91 Jan 13, 2026
cbf3213
refactor presence ratio and drift metrics to use periods properly
alejoe91 Jan 13, 2026
4409aa5
Fix rp_violations
alejoe91 Jan 13, 2026
71f8668
implement firing range and fix drift
alejoe91 Jan 13, 2026
1ea0d68
fix naming issue
alejoe91 Jan 13, 2026
a86c2d3
remove solved todos
alejoe91 Jan 13, 2026
d98ff66
sync with select_sorting_period PR
alejoe91 Jan 13, 2026
84da1a2
wip: test user defined
alejoe91 Jan 13, 2026
c539f58
wip: tests
alejoe91 Jan 13, 2026
d8e1f90
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jan 13, 2026
3f93f97
Implement select_segment_periods in core
alejoe91 Jan 13, 2026
cd85456
remove utils
alejoe91 Jan 13, 2026
7a42fe3
rebase on #4316
alejoe91 Jan 13, 2026
4f754cb
Merge with main
alejoe91 Jan 14, 2026
cbc0986
Fix import
alejoe91 Jan 14, 2026
56b672e
Merge branch 'select_sorting_periods_core' into select_sorting_periods
alejoe91 Jan 14, 2026
cd2ba0b
wip
alejoe91 Jan 14, 2026
046430e
fix import
alejoe91 Jan 14, 2026
bb86253
Add misc_metric changes
alejoe91 Jan 14, 2026
accbc31
Merge branch 'select_sorting_periods' into goodtimes
alejoe91 Jan 14, 2026
807f5c6
Add tests for user defined and combined
alejoe91 Jan 14, 2026
89d563b
Add to built_in extensions
alejoe91 Jan 14, 2026
50f33f0
fix tests
alejoe91 Jan 14, 2026
f2d48ba
Remove debug print
alejoe91 Jan 14, 2026
6b48730
Merge branch 'select_sorting_periods' into goodtimes
alejoe91 Jan 15, 2026
6b1284b
Merge branch 'goodtimes' of github.com:m-beau/spikeinterface into goo…
alejoe91 Jan 15, 2026
e173a63
wip: fix intervals
alejoe91 Jan 15, 2026
80bc50f
Change base_period_dtype order and fix select_sorting_periods array i…
alejoe91 Jan 15, 2026
4c8fa23
fix conflicts
alejoe91 Jan 15, 2026
e1f5bab
Merge metrics implementations
alejoe91 Jan 15, 2026
96e6a53
fix tests
alejoe91 Jan 15, 2026
3198911
Fix generation of bins
alejoe91 Jan 15, 2026
bbc28c5
Refactor generation of subperiods
alejoe91 Jan 15, 2026
9d2ad09
fix conflicts
alejoe91 Jan 15, 2026
8312db2
fix conflicts2
alejoe91 Jan 15, 2026
87fbe9a
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jan 16, 2026
7446a43
Use cached get_spike_vector_to_indices
alejoe91 Jan 16, 2026
873a687
Solve conflicts
alejoe91 Jan 16, 2026
bc91b81
fix conflicts3
alejoe91 Jan 16, 2026
51e906a
Fix error in merging
alejoe91 Jan 16, 2026
88da6fc
Merge branch 'select_sorting_periods' into goodtimes
alejoe91 Jan 16, 2026
ab5a771
fix conflicts
alejoe91 Jan 20, 2026
2209514
Add supports_periods in BaseMetric/Extension
alejoe91 Jan 20, 2026
b23c431
wip: test metrics with periods
alejoe91 Jan 20, 2026
6fb26a4
almost there?
alejoe91 Jan 20, 2026
0fe7f3e
Fix periods arg in MetricExtensions
alejoe91 Jan 20, 2026
f087e08
Make bin edges unique
alejoe91 Jan 20, 2026
e785b64
fix conflicts with selecto_sorting_periods
alejoe91 Jan 20, 2026
173e747
Add support_periods to spike train metrics and tests
alejoe91 Jan 21, 2026
066c378
Force NaN/-1 values for float/int metrics if num_spikes is 0
alejoe91 Jan 21, 2026
65e1848
Fix test_empty_units: -1 is a valid value for ints
alejoe91 Jan 21, 2026
f1c4682
Fix firing range if unit samples < bin samples
alejoe91 Jan 21, 2026
3291638
fix noise_cutoff if empty units
alejoe91 Jan 21, 2026
b5bf3c3
Move warnings at the end of the loop for firing range and drift
alejoe91 Jan 21, 2026
8aeedcc
clean up tests and add get_available_metric_names
alejoe91 Jan 22, 2026
d4db43c
simplify total samples
alejoe91 Jan 22, 2026
d0a1e66
Go back to Pierre's implementation for drifts
alejoe91 Jan 22, 2026
6926532
Merge branch 'select_sorting_periods' into goodtimes
alejoe91 Jan 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import numpy as np
from collections import namedtuple

from .numpyextractors import NumpySorting
from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension
from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
from .recording_tools import get_noise_levels
Expand Down Expand Up @@ -823,10 +824,9 @@ class BaseMetric:
metric_columns = {} # column names and their dtypes of the dataframe
metric_descriptions = {} # descriptions of each metric column
needs_recording = False # whether the metric needs recording
needs_tmp_data = (
False # whether the metric needs temporary data comoputed with _prepare_data at the MetricExtension level
)
needs_job_kwargs = False
needs_tmp_data = False # whether the metric needs temporary data computed with MetricExtension._prepare_data
needs_job_kwargs = False # whether the metric needs job_kwargs
supports_periods = False # whether the metric function supports periods
depend_on = [] # extensions the metric depends on

# the metric function must have the signature:
Expand All @@ -839,7 +839,7 @@ class BaseMetric:
metric_function = None # to be defined in subclass

@classmethod
def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs):
def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs, periods=None):
"""Compute the metric.

Parameters
Expand All @@ -854,6 +854,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs
Temporary data to pass to the metric function
job_kwargs : dict
Job keyword arguments to control parallelization
periods : np.ndarray | None
Numpy array of unit periods of unit_period_dtype if supports_periods is True

Returns
-------
Expand All @@ -865,6 +867,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs
args += (tmp_data,)
if cls.needs_job_kwargs:
args += (job_kwargs,)
if cls.supports_periods:
args += (periods,)

results = cls.metric_function(*args, **metric_params)

Expand Down Expand Up @@ -897,6 +901,17 @@ class BaseMetricExtension(AnalyzerExtension):
need_backward_compatibility_on_load = False
metric_list: list[BaseMetric] = None # list of BaseMetric

@classmethod
def get_available_metric_names(cls):
"""Get the available metric names.

Returns
-------
available_metric_names : list[str]
List of available metric names.
"""
return [m.metric_name for m in cls.metric_list]

@classmethod
def get_default_metric_params(cls):
"""Get the default metric parameters.
Expand Down Expand Up @@ -988,6 +1003,7 @@ def _set_params(
metric_params: dict | None = None,
delete_existing_metrics: bool = False,
metrics_to_compute: list[str] | None = None,
periods: np.ndarray | None = None,
**other_params,
):
"""
Expand All @@ -1004,6 +1020,8 @@ def _set_params(
If True, existing metrics in the extension will be deleted before computing new ones.
metrics_to_compute : list[str] | None
List of metric names to compute. If None, all metrics in `metric_names` are computed.
periods : np.ndarray | None
Numpy array of unit_period_dtype defining periods to compute metrics over.
other_params : dict
Additional parameters for metric computation.

Expand Down Expand Up @@ -1079,6 +1097,7 @@ def _set_params(
metrics_to_compute=metrics_to_compute,
delete_existing_metrics=delete_existing_metrics,
metric_params=metric_params,
periods=periods,
**other_params,
)
return params
Expand Down Expand Up @@ -1129,6 +1148,8 @@ def _compute_metrics(
if metric_names is None:
metric_names = self.params["metric_names"]

periods = self.params.get("periods", None)

column_names_dtypes = {}
for metric_name in metric_names:
metric = [m for m in self.metric_list if m.metric_name == metric_name][0]
Expand All @@ -1153,6 +1174,7 @@ def _compute_metrics(
metric_params=metric_params,
tmp_data=tmp_data,
job_kwargs=job_kwargs,
periods=periods,
)
except Exception as e:
warnings.warn(f"Error computing metric {metric_name}: {e}")
Expand All @@ -1179,6 +1201,7 @@ def _run(self, **job_kwargs):

metrics_to_compute = self.params["metrics_to_compute"]
delete_existing_metrics = self.params["delete_existing_metrics"]
periods = self.params.get("periods", None)

_, job_kwargs = split_job_kwargs(job_kwargs)
job_kwargs = fix_job_kwargs(job_kwargs)
Expand Down Expand Up @@ -1445,13 +1468,22 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"

all_data = self.data[return_data_name]
keep_mask = None
if periods is not None:
keep_mask = select_sorting_periods_mask(
self.sorting_analyzer.sorting,
periods,
)
all_data = all_data[keep_mask]
# since we have the mask already, we can use it directly to avoid double computation
spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=True)
sliced_spike_vector = spike_vector[keep_mask]
sorting = NumpySorting(
sliced_spike_vector,
sampling_frequency=self.sorting_analyzer.sampling_frequency,
unit_ids=self.sorting_analyzer.unit_ids,
)
else:
sorting = self.sorting_analyzer.sorting

if outputs == "numpy":
if copy:
Expand All @@ -1460,16 +1492,16 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
return all_data
elif outputs == "by_unit":
unit_ids = self.sorting_analyzer.unit_ids
spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)

if keep_mask is not None:
# since we are filtering spikes, we need to recompute the spike indices
spike_vector = spike_vector[keep_mask]
spike_vector = sorting.to_spike_vector(concatenated=False)
spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
else:
# use the cache of indices
spike_indices = self.sorting_analyzer.sorting.get_spike_vector_to_indices()
data_by_units = {}
for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
for segment_index in range(self.sorting_analyzer.get_num_segments()):
data_by_units[segment_index] = {}
for unit_id in unit_ids:
inds = spike_indices[segment_index][unit_id]
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2702,6 +2702,7 @@ def _save_data(self):
extension_group.create_dataset(
name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON()
)
extension_group[ext_data_name].attrs["dict"] = True
elif isinstance(ext_data, np.ndarray):
extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options)
elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame):
Expand Down Expand Up @@ -2884,6 +2885,7 @@ def set_data(self, ext_data_name, ext_data):
"spike_locations": "spikeinterface.postprocessing",
"template_similarity": "spikeinterface.postprocessing",
"unit_locations": "spikeinterface.postprocessing",
"valid_unit_periods": "spikeinterface.postprocessing",
# from metrics
"quality_metrics": "spikeinterface.metrics",
"template_metrics": "spikeinterface.metrics",
Expand Down
81 changes: 79 additions & 2 deletions src/spikeinterface/metrics/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,85 @@
import pytest

from spikeinterface.postprocessing.tests.conftest import _small_sorting_analyzer
from spikeinterface.core import (
generate_ground_truth_recording,
create_sorting_analyzer,
)

job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")


def make_small_analyzer():
recording, sorting = generate_ground_truth_recording(
durations=[10.0],
num_units=10,
seed=1205,
)

channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"])

sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

extensions_to_compute = {
"random_spikes": {"seed": 1205},
"noise_levels": {"seed": 1205},
"waveforms": {},
"templates": {"operators": ["average", "median"]},
"spike_amplitudes": {},
"spike_locations": {},
"principal_components": {},
}

sorting_analyzer.compute(extensions_to_compute)

return sorting_analyzer


@pytest.fixture(scope="module")
def small_sorting_analyzer():
return _small_sorting_analyzer()
return make_small_analyzer()


@pytest.fixture(scope="module")
def sorting_analyzer_simple():
# we need high firing rate for amplitude_cutoff
recording, sorting = generate_ground_truth_recording(
durations=[
120.0,
],
sampling_frequency=30_000.0,
num_channels=6,
num_units=10,
generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
generate_unit_locations_kwargs=dict(
margin_um=5.0,
minimum_z=5.0,
maximum_z=20.0,
),
generate_templates_kwargs=dict(
unit_params=dict(
alpha=(200.0, 500.0),
)
),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
seed=1205,
)

channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)

sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205)
sorting_analyzer.compute("noise_levels")
sorting_analyzer.compute("waveforms", **job_kwargs)
sorting_analyzer.compute("templates")
sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs)

return sorting_analyzer
1 change: 1 addition & 0 deletions src/spikeinterface/metrics/quality/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
compute_sliding_rp_violations,
compute_sd_ratio,
compute_synchrony_metrics,
compute_refrac_period_violations,
)
Loading
Loading