diff --git a/doc/how_to/import_kilosort_data.rst b/doc/how_to/import_kilosort_data.rst index dad522334a..ed7bf7b6c0 100644 --- a/doc/how_to/import_kilosort_data.rst +++ b/doc/how_to/import_kilosort_data.rst @@ -49,7 +49,7 @@ If you'd like to store the information you've computed, you can save the analyze ) You now have a fully functional ``SortingAnalyzer`` - congrats! You can now use `spikeinterface-gui `__. to view the results -interactively, or start manually labelling your units to `create an automated curation model `__. +interactively, or start manually labeling your units to `create an automated curation model `__. Note that if you have access to the raw recording, you can attach it to the analyzer, and re-compute extensions from the raw data. E.g. diff --git a/doc/modules/metrics.rst b/doc/modules/metrics.rst index d3f3512670..fe5b631f19 100644 --- a/doc/modules/metrics.rst +++ b/doc/modules/metrics.rst @@ -28,14 +28,32 @@ metric information. For example, you can get the list of available metrics using .. code-block:: Available metric columns: - ['peak_to_valley', 'peak_trough_ratio', 'half_width', 'repolarization_slope', - 'recovery_slope', 'num_positive_peaks', 'num_negative_peaks', 'velocity_above', - 'velocity_below', 'exp_decay', 'spread'] + [ + 'peak_to_trough_duration', + 'half_width', + 'repolarization_slope', + 'recovery_slope', + 'num_positive_peaks', + 'num_negative_peaks', + 'main_to_next_peak_duration', + 'peak_before_to_trough_ratio', + 'peak_after_to_trough_ratio', + 'peak_before_to_peak_after_ratio', + 'main_peak_to_trough_ratio', + 'trough_width', + 'peak_before_width', + 'peak_after_width', + 'waveform_baseline_flatness', + 'velocity_above', + 'velocity_below', + 'exp_decay', + 'spread' + ] .. code-block:: python - metric_descriptions = ComputeTemplateMetrics.get_metric_descriptions() + metric_descriptions = ComputeTemplateMetrics.get_metric_column_descriptions() print("Metric descriptions: ") print(metric_descriptions) @@ -44,21 +62,30 @@ metric information. 
For example, you can get the list of available metrics using Metric descriptions: { - 'peak_to_valley': 'Duration in s between the trough (minimum) and the peak (maximum) of the spike waveform.', - 'peak_trough_ratio': 'Ratio of the amplitude of the peak (maximum) to the trough (minimum) of the spike waveform.', + 'peak_to_trough_duration': 'Duration in seconds between the trough (minimum) and the peak (maximum) of the spike waveform.', 'half_width': 'Duration in s at half the amplitude of the trough (minimum) of the spike waveform.', 'repolarization_slope': 'Slope of the repolarization phase of the spike waveform, between the trough (minimum) and return to baseline in uV/s.', 'recovery_slope': 'Slope of the recovery phase of the spike waveform, after the peak (maximum) returning to baseline in uV/s.', 'num_positive_peaks': 'Number of positive peaks in the template', - 'num_negative_peaks': 'Number of negative peaks in the template', + 'num_negative_peaks': 'Number of negative peaks (troughs) in the template', + 'main_to_next_peak_duration': 'Duration in seconds from main extremum to next extremum.', + 'peak_before_to_trough_ratio': 'Ratio of peak before amplitude to trough amplitude', + 'peak_after_to_trough_ratio': 'Ratio of peak after amplitude to trough amplitude', + 'peak_before_to_peak_after_ratio': 'Ratio of peak before amplitude to peak after amplitude', + 'main_peak_to_trough_ratio': 'Ratio of main peak amplitude to trough amplitude', + 'trough_width': 'Width of the main trough in seconds', + 'peak_before_width': 'Width of the main peak before trough in seconds', + 'peak_after_width': 'Width of the main peak after trough in seconds', + 'waveform_baseline_flatness': 'Ratio of max baseline amplitude to max waveform amplitude. Lower = flatter baseline.', 'velocity_above': 'Velocity of the spike propagation above the max channel in um/ms', 'velocity_below': 'Velocity of the spike propagation below the max channel in um/ms', - 'exp_decay': 'Exponential decay of the template amplitude over distance from the extremum channel (1/um).', + 'exp_decay': 'Spatial decay of the template amplitude over distance from the extremum channel (1/um). Uses exponential or linear fit based on linear_fit parameter.', 'spread': 'Spread of the template amplitude in um, calculated as the distance between channels whose templates exceed the spread_threshold.' } + .. 
toctree:: :caption: Metrics submodules :maxdepth: 1 diff --git a/src/spikeinterface/comparison/comparisontools.py b/src/spikeinterface/comparison/comparisontools.py index dd06e60458..1771ba5c16 100644 --- a/src/spikeinterface/comparison/comparisontools.py +++ b/src/spikeinterface/comparison/comparisontools.py @@ -569,7 +569,7 @@ def make_hungarian_match(agreement_scores, min_score): def do_score_labels(sorting1, sorting2, delta_frames, unit_map12, label_misclassification=False): """ - Makes the labelling at spike level for each spike train: + Makes the labeling at spike level for each spike train: * TP: true positive * CL: classification error * FN: False negative diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 30038bc270..2c551b5342 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -997,6 +997,38 @@ def get_optional_dependencies(cls, **params): depend_on = list(cls.depend_on) + list(metric_depend_on) return depend_on + def get_computed_metric_names(self): + """ + Get the list of already computed metric names. + + Returns + ------- + computed_metric_names : list[str] + List of computed metric names. + """ + if self.data is None or len(self.data) == 0: + return [] + else: + computed_metric_columns = self.data["metrics"].columns.tolist() + computed_metric_names = [] + for m in self.metric_list: + if all(col in computed_metric_columns for col in m.metric_columns.keys()): + computed_metric_names.append(m.metric_name) + return computed_metric_names + + def _cast_metrics(self, metrics_df): + metric_dtypes = {} + for m in self.metric_list: + metric_dtypes.update(m.metric_columns) + + for col in metrics_df.columns: + if col in metric_dtypes: + try: + metrics_df[col] = metrics_df[col].astype(metric_dtypes[col]) + except Exception as e: + print(f"Error casting column {col}: {e}") + return metrics_df + def _set_params( self, metric_names: list[str] | None = None, @@ -1155,6 +1187,13 @@ def _compute_metrics( metric = [m for m in self.metric_list if m.metric_name == metric_name][0] column_names_dtypes.update(metric.metric_columns) + # drop metric that don't map to any metric names + possible_metric_names = [m.metric_name for m in self.metric_list] + wrong_metric_names = [m for m in metric_names if m not in possible_metric_names] + if len(wrong_metric_names) > 0: + warnings.warn(f"The following metric names are not recognized and will be ignored: {wrong_metric_names}") + metric_names = [m for m in metric_names if m in possible_metric_names] + metrics = pd.DataFrame(index=unit_ids, columns=list(column_names_dtypes.keys())) run_times = {} diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index 119ab1d598..d5ba74cfd9 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -189,7 +189,7 @@ def test_aggregation_labeling_for_lists(): assert np.all(user_group_property == [6, 6, 7, 7]) -def test_aggretion_labelling_for_dicts(): +def test_aggretion_labeling_for_dicts(): """Aggregated dicts of recordings get different labels depending on their underlying `property`s""" recording1 = generate_recording(num_channels=4, durations=[20], set_probe=False) diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index b64070e662..56de464cee 
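The new ``get_computed_metric_names`` helper above reports a metric as computed only when every column it produces is already present in the stored DataFrame. A minimal usage sketch, assuming an existing ``analyzer`` (a ``SortingAnalyzer``) and that the ``template_metrics`` extension inherits the helper from the metrics base class shown above; the metric names are only examples:

.. code-block:: python

    analyzer.compute("template_metrics", metric_names=["half_width", "exp_decay"])
    tm = analyzer.get_extension("template_metrics")

    # a metric counts as computed only if all of its output columns exist in the data
    print(tm.get_computed_metric_names())  # e.g. ['half_width', 'exp_decay']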
100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -20,5 +20,12 @@ from .sortingview_curation import apply_sortingview_curation # automated curation +from .bombcell_curation import ( + bombcell_get_default_thresholds, + bombcell_label_units, + get_bombcell_labeling_summary, + save_bombcell_results, +) + from .model_based_curation import auto_label_units, load_model from .train_manual_curation import train_model, get_default_classifier_search_spaces diff --git a/src/spikeinterface/curation/bombcell_curation.py b/src/spikeinterface/curation/bombcell_curation.py new file mode 100644 index 0000000000..84a954ffa6 --- /dev/null +++ b/src/spikeinterface/curation/bombcell_curation.py @@ -0,0 +1,343 @@ +""" +Unit labeling based on quality metrics (bombcell). + +Unit Types: + 0 (NOISE): Failed waveform quality checks + 1 (GOOD): Passed all thresholds + 2 (MUA): Failed spike quality checks + 3 (NON_SOMA): Non-somatic units (axonal) +""" + +from __future__ import annotations + +import numpy as np +from typing import Optional + +NOISE_METRICS = [ + "num_positive_peaks", + "num_negative_peaks", + "peak_to_trough_duration", + "waveform_baseline_flatness", + "peak_after_to_trough_ratio", + "exp_decay", +] + +SPIKE_QUALITY_METRICS = [ + "amplitude_median", + "snr", + "amplitude_cutoff", + "num_spikes", + "rp_contamination", + "presence_ratio", + "drift_ptp", +] + +NON_SOMATIC_METRICS = [ + "peak_before_to_trough_ratio", + "peak_before_width", + "trough_width", + "peak_before_to_peak_after_ratio", + "main_peak_to_trough_ratio", +] + + +def bombcell_get_default_thresholds() -> dict: + """ + bombcell - Returns default thresholds for unit labeling. + + Each metric has 'min' and 'max' values. Use None to disable a threshold (e.g. to ignore a metric completely + or to only have a min or a max threshold) + """ + # bombcell + return { + # Waveform quality (failures -> NOISE) + "num_positive_peaks": {"min": None, "max": 2}, + "num_negative_peaks": {"min": None, "max": 1}, + "peak_to_trough_duration": {"min": 0.0001, "max": 0.00115}, # seconds + "waveform_baseline_flatness": {"min": None, "max": 0.5}, + "peak_after_to_trough_ratio": {"min": None, "max": 0.8}, + "exp_decay": {"min": 0.01, "max": 0.1}, + # Spike quality (failures -> MUA) + "amplitude_median": {"min": 40, "max": None}, # uV + "snr_baseline": {"min": 5, "max": None}, + "amplitude_cutoff": {"min": None, "max": 0.2}, + "num_spikes": {"min": 300, "max": None}, + "rp_contamination": {"min": None, "max": 0.1}, + "presence_ratio": {"min": 0.7, "max": None}, + "drift_ptp": {"min": None, "max": 100}, # um + # Non-somatic detection + "peak_before_to_trough_ratio": {"min": None, "max": 3}, + "peak_before_width": {"min": 0.00015, "max": None}, # seconds + "trough_width": {"min": 0.0002, "max": None}, # seconds + "peak_before_to_peak_after_ratio": {"min": None, "max": 3}, + "main_peak_to_trough_ratio": {"min": None, "max": 0.8}, + } + + +def _is_threshold_disabled(value): + """Check if a threshold value is disabled (None or np.nan).""" + if value is None: + return True + if isinstance(value, float) and np.isnan(value): + return True + return False + + +def bombcell_label_units( + sorting_analyzer=None, + thresholds: Optional[dict] = None, + label_non_somatic: bool = True, + split_non_somatic_good_mua: bool = False, + external_metrics: Optional["pd.DataFrame | list[pd.DataFrame]"] = None, +) -> tuple[np.ndarray, np.ndarray]: + """ + bombcell - label units based on quality metrics and thresholds. 
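The default thresholds above can be copied and adjusted before labeling. A minimal sketch with arbitrary illustrative values (not recommendations); the import path follows the ``__init__`` exports added above:

.. code-block:: python

    from spikeinterface.curation import bombcell_get_default_thresholds

    thresholds = bombcell_get_default_thresholds()
    thresholds["num_spikes"]["min"] = 500   # require more spikes before a unit can be "good"
    thresholds["drift_ptp"]["max"] = None   # disable the drift check entirely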
+ + Parameters + ---------- + sorting_analyzer : SortingAnalyzer, optional + SortingAnalyzer with computed quality_metrics and/or template_metrics extensions. + If provided, metrics are extracted automatically using get_metrics_extension_data(). + thresholds : dict or None + Threshold dict: {"metric": {"min": val, "max": val}}. Use None to disable. + label_non_somatic : bool + If True, detect non-somatic (axonal) units. + split_non_somatic_good_mua : bool + If True, split non-somatic into NON_SOMA_GOOD (3) and NON_SOMA_MUA (4). + external_metrics: Optional[pd.DataFrame | list[pd.DataFrame]] = None + External metrics DataFrame(s) (index = unit_ids) to use instead of those from SortingAnalyzer. + + Returns + ------- + unit_type : np.ndarray + Numeric: 0=NOISE, 1=GOOD, 2=MUA, 3=NON_SOMA + unit_type_string : np.ndarray + String labels. + """ + import pandas as pd + + if sorting_analyzer is not None: + combined_metrics = sorting_analyzer.get_metrics_extension_data() + if combined_metrics.empty: + raise ValueError( + "SortingAnalyzer has no metrics extensions computed. " + "Compute quality_metrics and/or template_metrics first." + ) + else: + if external_metrics is None: + raise ValueError("Either sorting_analyzer or external_metrics must be provided") + if isinstance(external_metrics, list): + assert all( + isinstance(df, pd.DataFrame) for df in external_metrics + ), "All items in external_metrics must be DataFrames" + combined_metrics = pd.concat(external_metrics, axis=1) + else: + combined_metrics = external_metrics + + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + + n_units = len(combined_metrics) + unit_type = np.full(n_units, np.nan) + absolute_value_metrics = ["amplitude_median"] + + # NOISE: waveform failures + noise_mask = np.zeros(n_units, dtype=bool) + for metric_name in NOISE_METRICS: + if metric_name not in combined_metrics.columns or metric_name not in thresholds: + continue + values = combined_metrics[metric_name].values + if metric_name in absolute_value_metrics: + values = np.abs(values) + thresh = thresholds[metric_name] + noise_mask |= np.isnan(values) + if not _is_threshold_disabled(thresh["min"]): + noise_mask |= values < thresh["min"] + if not _is_threshold_disabled(thresh["max"]): + noise_mask |= values > thresh["max"] + unit_type[noise_mask] = 0 + + # MUA: spike quality failures + mua_mask = np.zeros(n_units, dtype=bool) + for metric_name in SPIKE_QUALITY_METRICS: + if metric_name not in combined_metrics.columns or metric_name not in thresholds: + continue + values = combined_metrics[metric_name].values + if metric_name in absolute_value_metrics: + values = np.abs(values) + thresh = thresholds[metric_name] + valid_mask = np.isnan(unit_type) + if not _is_threshold_disabled(thresh["min"]): + mua_mask |= valid_mask & ~np.isnan(values) & (values < thresh["min"]) + if not _is_threshold_disabled(thresh["max"]): + mua_mask |= valid_mask & ~np.isnan(values) & (values > thresh["max"]) + unit_type[mua_mask & np.isnan(unit_type)] = 2 + + # GOOD: passed all checks + unit_type[np.isnan(unit_type)] = 1 + + # NON-SOMATIC + if label_non_somatic: + + def get_metric(name): + if name in combined_metrics.columns: + return combined_metrics[name].values + return np.full(n_units, np.nan) + + peak_before_width = get_metric("peak_before_width") + trough_width = get_metric("trough_width") + width_thresh_peak = thresholds.get("peak_before_width", {}).get("min", None) + width_thresh_trough = thresholds.get("trough_width", {}).get("min", None) + + narrow_peak = ( + 
~np.isnan(peak_before_width) & (peak_before_width < width_thresh_peak) + if not _is_threshold_disabled(width_thresh_peak) + else np.zeros(n_units, dtype=bool) + ) + narrow_trough = ( + ~np.isnan(trough_width) & (trough_width < width_thresh_trough) + if not _is_threshold_disabled(width_thresh_trough) + else np.zeros(n_units, dtype=bool) + ) + width_conditions = narrow_peak & narrow_trough + + peak_before_to_trough = get_metric("peak_before_to_trough_ratio") + peak_before_to_peak_after = get_metric("peak_before_to_peak_after_ratio") + main_peak_to_trough = get_metric("main_peak_to_trough_ratio") + + ratio_thresh_pbt = thresholds.get("peak_before_to_trough_ratio", {}).get("max", None) + ratio_thresh_pbpa = thresholds.get("peak_before_to_peak_after_ratio", {}).get("max", None) + ratio_thresh_mpt = thresholds.get("main_peak_to_trough_ratio", {}).get("max", None) + + large_initial_peak = ( + ~np.isnan(peak_before_to_trough) & (peak_before_to_trough > ratio_thresh_pbt) + if not _is_threshold_disabled(ratio_thresh_pbt) + else np.zeros(n_units, dtype=bool) + ) + large_peak_ratio = ( + ~np.isnan(peak_before_to_peak_after) & (peak_before_to_peak_after > ratio_thresh_pbpa) + if not _is_threshold_disabled(ratio_thresh_pbpa) + else np.zeros(n_units, dtype=bool) + ) + large_main_peak = ( + ~np.isnan(main_peak_to_trough) & (main_peak_to_trough > ratio_thresh_mpt) + if not _is_threshold_disabled(ratio_thresh_mpt) + else np.zeros(n_units, dtype=bool) + ) + + # (ratio AND width) OR standalone main_peak_to_trough + ratio_conditions = large_initial_peak | large_peak_ratio + is_non_somatic = (ratio_conditions & width_conditions) | large_main_peak + + if split_non_somatic_good_mua: + unit_type[(unit_type == 1) & is_non_somatic] = 3 + unit_type[(unit_type == 2) & is_non_somatic] = 4 + else: + unit_type[(unit_type != 0) & is_non_somatic] = 3 + + # String labels + if split_non_somatic_good_mua: + labels = {0: "NOISE", 1: "good", 2: "mua", 3: "non_soma_good", 4: "non_soma_mua"} + else: + labels = {0: "noise", 1: "good", 2: "mua", 3: "non_soma"} + + unit_type_string = np.array([labels.get(int(t), "unknown") for t in unit_type], dtype=object) + return unit_type.astype(int), unit_type_string + + +def get_bombcell_labeling_summary(unit_type: np.ndarray, unit_type_string: np.ndarray) -> dict: + """Get counts and percentages for each unit type.""" + n_total = len(unit_type) + unique_types, counts = np.unique(unit_type, return_counts=True) + + summary = {"total_units": n_total, "counts": {}, "percentages": {}} + for utype, count in zip(unique_types, counts): + label = unit_type_string[unit_type == utype][0] + summary["counts"][label] = int(count) + summary["percentages"][label] = round(100 * count / n_total, 1) + + return summary + + +def save_bombcell_results( + quality_metrics: "pd.DataFrame", + unit_type: np.ndarray, + unit_type_string: np.ndarray, + thresholds: dict, + folder, + save_narrow: bool = True, + save_wide: bool = True, +) -> None: + """ + Save labeling results to CSV files. + + Parameters + ---------- + quality_metrics : pd.DataFrame + DataFrame with quality metrics (index = unit_ids). + unit_type : np.ndarray + Numeric unit type codes. + unit_type_string : np.ndarray + String labels for each unit. + thresholds : dict + Threshold dictionary used for labeling. + folder : str or Path + Folder to save the CSV files. + save_narrow : bool, default: True + Save narrow/tidy format (one row per unit-metric). + save_wide : bool, default: True + Save wide format (one row per unit, metrics as columns). 
+ """ + from pathlib import Path + import pandas as pd + + folder = Path(folder) + folder.mkdir(parents=True, exist_ok=True) + + unit_ids = quality_metrics.index.values + + # Wide format: one row per unit + if save_wide: + wide_df = quality_metrics.copy() + wide_df.insert(0, "label", unit_type_string) + wide_df.insert(1, "label_code", unit_type) + wide_df.to_csv(folder / "labeling_results_wide.csv") + + # Narrow format: one row per unit-metric combination + if save_narrow: + rows = [] + for i, unit_id in enumerate(unit_ids): + label = unit_type_string[i] + label_code = unit_type[i] + for metric_name in quality_metrics.columns: + if metric_name not in thresholds: + continue + value = quality_metrics.loc[unit_id, metric_name] + thresh = thresholds[metric_name] + thresh_min = thresh.get("min", None) + thresh_max = thresh.get("max", None) + + # Determine pass/fail + passed = True + if np.isnan(value): + passed = False + elif not _is_threshold_disabled(thresh_min) and value < thresh_min: + passed = False + elif not _is_threshold_disabled(thresh_max) and value > thresh_max: + passed = False + + rows.append( + { + "unit_id": unit_id, + "label": label, + "label_code": label_code, + "metric_name": metric_name, + "value": value, + "threshold_min": thresh_min, + "threshold_max": thresh_max, + "passed": passed, + } + ) + + narrow_df = pd.DataFrame(rows) + narrow_df.to_csv(folder / "labeling_results_narrow.csv", index=False) diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py index e779e13182..99822d64f6 100644 --- a/src/spikeinterface/curation/model_based_curation.py +++ b/src/spikeinterface/curation/model_based_curation.py @@ -3,11 +3,10 @@ import json import warnings import re +from packaging.version import parse from spikeinterface.core import SortingAnalyzer from spikeinterface.curation.train_manual_curation import ( - try_to_get_metrics_from_analyzer, - _get_computed_metrics, _format_metric_dataframe, ) from copy import deepcopy @@ -81,11 +80,12 @@ def predict_labels( # Get metrics DataFrame for classification if input_data is None: - input_data = _get_computed_metrics(self.sorting_analyzer) + input_data = self.sorting_analyzer.get_metrics_extension_data() else: if not isinstance(input_data, pd.DataFrame): raise ValueError("Input data must be a pandas DataFrame") + input_data = self.handle_backwards_compatibility_in_metrics(input_data, model_info=model_info) input_data = self._check_required_metrics_are_present(input_data) if model_info is not None: @@ -127,8 +127,23 @@ def predict_labels( return classified_units - def _check_required_metrics_are_present(self, calculated_metrics): + def handle_backwards_compatibility_in_metrics(self, calculated_metrics, model_info): + si_version = model_info["requirements"].get("spikeinterface", None) + if si_version is not None and parse(si_version) < parse("0.103.2"): + # if the model was trained with SI version < 0.103.2, we need to rename some metrics + calculated_metrics = calculated_metrics.copy() + # peak_to_trough_duration was named peak_to_valley + if "peak_to_trough_duration" in calculated_metrics.columns: + calculated_metrics = calculated_metrics.rename(columns={"peak_to_trough_duration": "peak_to_valley"}) + # main_peak_to_trough_ratio was named peak_trough_ratio and had inverted sign + if "main_peak_to_trough_ratio" in calculated_metrics.columns: + calculated_metrics = calculated_metrics.rename( + columns={"main_peak_to_trough_ratio": "peak_trough_ratio"} + ) + 
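Putting the bombcell helpers above together, a hedged end-to-end sketch; it assumes an ``analyzer`` with ``quality_metrics`` and ``template_metrics`` already computed, and the output folder name is arbitrary:

.. code-block:: python

    from spikeinterface.curation import (
        bombcell_get_default_thresholds,
        bombcell_label_units,
        get_bombcell_labeling_summary,
        save_bombcell_results,
    )

    thresholds = bombcell_get_default_thresholds()
    unit_type, unit_type_string = bombcell_label_units(sorting_analyzer=analyzer, thresholds=thresholds)
    print(get_bombcell_labeling_summary(unit_type, unit_type_string)["percentages"])

    qm = analyzer.get_extension("quality_metrics").get_data()
    save_bombcell_results(qm, unit_type, unit_type_string, thresholds, folder="bombcell_results")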
calculated_metrics["peak_trough_ratio"] = -1 * calculated_metrics["peak_trough_ratio"] + return calculated_metrics + def _check_required_metrics_are_present(self, calculated_metrics): # Check all the required metrics have been calculated required_metrics = set(self.required_metrics) if required_metrics.issubset(set(calculated_metrics)): diff --git a/src/spikeinterface/curation/tests/test_bombcell_curation.py b/src/spikeinterface/curation/tests/test_bombcell_curation.py new file mode 100644 index 0000000000..a867453064 --- /dev/null +++ b/src/spikeinterface/curation/tests/test_bombcell_curation.py @@ -0,0 +1,17 @@ +import pytest +from pathlib import Path +from spikeinterface.curation.tests.common import sorting_analyzer_for_curation, trained_pipeline_path +from spikeinterface.curation.bombcell_curation import bombcell_label_units + + +def test_bombcell_label_units(sorting_analyzer_for_curation): + """Test bombcell_label_units function on a sorting_analyzer with computed quality metrics.""" + + sorting_analyzer = sorting_analyzer_for_curation + sorting_analyzer.compute("quality_metrics") + sorting_analyzer.compute("template_metrics") + + unit_type, unit_type_string = bombcell_label_units(sorting_analyzer=sorting_analyzer) + + assert len(unit_type) == sorting_analyzer.unit_ids.size + assert set(unit_type_string).issubset({"somatic", "non-somatic", "good", "mua", "noise"}) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 198e98037c..c5f29ac329 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -146,7 +146,7 @@ class PresenceRatio(BaseMetric): def compute_snrs( sorting_analyzer, unit_ids=None, - peak_sign: str = "neg", + peak_sign: str = "both", peak_mode: str = "extremum", ): """ @@ -207,6 +207,119 @@ class SNR(BaseMetric): depend_on = ["noise_levels", "templates"] +def compute_snrs_versus_baseline( + sorting_analyzer, + unit_ids=None, + peak_sign: str = "neg", + baseline_window_ms: float = 0.5, +): + """ + Compute signal to noise ratio versus baseline. + + This differs from the standard SNR by using: + - Signal: Max absolute value of the median waveform on peak channel + - Noise: MAD (Median Absolute Deviation) of baseline samples from waveforms + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + unit_ids : list or None + The list of unit ids to compute the SNR. If None, all units are used. + peak_sign : "neg" | "pos" | "both", default: "neg" + The sign of the template to compute best channels. + baseline_window_ms : float, default: 0.5 + Duration in ms at the start of the waveform to use as baseline for noise calculation. + + Returns + ------- + snrs : dict + Computed signal to noise ratio for each unit. + + Notes + ----- + This implementation follows the bombcell methodology [1]: + - Signal is the maximum absolute amplitude of the median waveform on the peak channel + - Noise is computed as MAD of baseline samples (first N samples of each waveform) + + Requires the "waveforms" extension to be computed. + + References + ---------- + [1] https://github.com/Julie-Fabre/bombcell + """ + if not sorting_analyzer.has_extension("waveforms"): + raise ValueError( + "The 'waveforms' extension is required for compute_snrs_versus_baseline. 
" + "Please compute it first with: analyzer.compute('waveforms')" + ) + + if unit_ids is None: + unit_ids = sorting_analyzer.unit_ids + + waveforms_ext = sorting_analyzer.get_extension("waveforms") + nbefore = waveforms_ext.nbefore + sampling_frequency = sorting_analyzer.sampling_frequency + + # Calculate baseline samples from ms + baseline_samples = int(baseline_window_ms / 1000 * sampling_frequency) + baseline_samples = min(baseline_samples, nbefore) # Can't exceed nbefore + + # Get peak channel for each unit from templates + extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) + + snrs = {} + for unit_id in unit_ids: + # Get waveforms for this unit (num_spikes, num_samples, num_channels) + waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) + + if waveforms is None or len(waveforms) == 0: + snrs[unit_id] = np.nan + continue + + # Get peak channel index + peak_chan_id = extremum_channels_ids[unit_id] + if sorting_analyzer.is_sparse(): + chan_ids = sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] + if peak_chan_id not in chan_ids: + snrs[unit_id] = np.nan + continue + peak_chan_idx = np.where(chan_ids == peak_chan_id)[0][0] + else: + peak_chan_idx = sorting_analyzer.channel_ids_to_indices([peak_chan_id])[0] + + # Extract waveforms on peak channel + waveforms_peak = waveforms[:, :, peak_chan_idx] # (num_spikes, num_samples) + + # Signal: max absolute value of the median waveform + median_waveform = np.median(waveforms_peak, axis=0) # median across spikes + signal = np.max(np.abs(median_waveform)) + + # Noise: MAD of baseline samples (first N samples of each waveform) + baseline_samples_all = waveforms_peak[:, :baseline_samples].flatten() + median_baseline = np.median(baseline_samples_all) + noise = np.median(np.abs(baseline_samples_all - median_baseline)) + + # Calculate SNR (avoid division by zero) + if noise > 0: + snrs[unit_id] = signal / noise + else: + snrs[unit_id] = np.nan + + return snrs + + +class SNRBaseline(BaseMetric): + metric_name = "snr_baseline" + metric_function = compute_snrs_versus_baseline + metric_params = {"peak_sign": "neg", "baseline_window_ms": 0.5} + metric_columns = {"snr_baseline": float} + metric_descriptions = { + "snr_baseline": "Signal to noise ratio versus baseline (median waveform max / baseline MAD). Based on bombcell." + } + depend_on = ["waveforms", "templates"] + + def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_threshold_ms=1.5, min_isi_ms=0): """ Calculate Inter-Spike Interval (ISI) violations. 
@@ -890,7 +1003,10 @@ def compute_amplitude_cutoffs( if invert_amplitudes: amplitudes = -amplitudes all_fraction_missing[unit_id] = amplitude_cutoff( - amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio + amplitudes, + num_histogram_bins, + histogram_smoothing_value, + amplitudes_bins_min_ratio, ) if np.any(np.isnan(list(all_fraction_missing.values()))): @@ -1392,6 +1508,7 @@ class SDRatio(BaseMetric): FiringRate, PresenceRatio, SNR, + # SNRBaseline, ISIViolation, RPViolation, SlidingRPViolation, @@ -1516,7 +1633,12 @@ def isi_violations(spike_trains, total_duration_s, isi_threshold_s=0.0015, min_i return isi_violations_ratio, isi_violations_rate, isi_violations_count -def amplitude_cutoff(amplitudes, num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5): +def amplitude_cutoff( + amplitudes, + num_histogram_bins=500, + histogram_smoothing_value=3, + amplitudes_bins_min_ratio=5, +): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py index a1af1de348..fc22f09006 100644 --- a/src/spikeinterface/metrics/template/metrics.py +++ b/src/spikeinterface/metrics/template/metrics.py @@ -1,66 +1,445 @@ from __future__ import annotations +import warnings import numpy as np -from collections import namedtuple - from spikeinterface.core.analyzer_extension_core import BaseMetric -def get_trough_and_peak_idx(template): +def get_trough_and_peak_idx( + template, min_thresh_detect_peaks_troughs=0.4, smooth=True, smooth_window_frac=0.1, smooth_polyorder=3 +): """ - Return the indices into the input template of the detected trough - (minimum of template) and peak (maximum of template, after trough). - Assumes negative trough and positive peak. + Detect troughs and peaks in a template waveform and return detailed information + about each detected feature. 
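A standalone sketch of the detection strategy described above (Savitzky-Golay smoothing, then prominence-based ``scipy.signal.find_peaks``), on a synthetic waveform. It keeps only the essential steps, uses made-up amplitudes, picks the main trough by amplitude rather than prominence for brevity, and omits the lower-prominence fallback:

.. code-block:: python

    import numpy as np
    from scipy.signal import find_peaks, savgol_filter

    x = np.arange(90)
    template = (20 * np.exp(-((x - 25) ** 2) / 20)     # small initial peak
                - 100 * np.exp(-((x - 40) ** 2) / 20)  # main trough
                + 50 * np.exp(-((x - 60) ** 2) / 60))  # repolarization peak

    smoothed = savgol_filter(template, window_length=9, polyorder=3)
    min_prominence = 0.4 * np.nanmax(np.abs(smoothed))

    trough_locs, _ = find_peaks(-smoothed, prominence=min_prominence, width=0)
    main_trough = trough_locs[np.argmin(smoothed[trough_locs])]
    peaks_before, _ = find_peaks(smoothed[:main_trough], prominence=min_prominence, width=0)
    peaks_after, _ = find_peaks(smoothed[main_trough:], prominence=min_prominence, width=0)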
Parameters ---------- - template: numpy.ndarray + template : numpy.ndarray The 1D template waveform + min_thresh_detect_peaks_troughs : float, default: 0.4 + Minimum prominence threshold as a fraction of the template's absolute max value + smooth : bool, default: True + Whether to apply smoothing before peak detection + smooth_window_frac : float, default: 0.1 + Smoothing window length as a fraction of template length (0.05-0.2 recommended) + smooth_polyorder : int, default: 3 + Polynomial order for Savitzky-Golay filter (must be < window_length) Returns ------- - trough_idx: int - The index of the trough - peak_idx: int - The index of the peak + troughs : dict + Dictionary containing: + - "indices": array of all trough indices + - "values": array of all trough values + - "prominences": array of all trough prominences + - "widths": array of all trough widths + - "main_idx": index of the main trough (most prominent) + - "main_loc": location (sample index) of the main trough in template + peaks_before : dict + Dictionary containing peaks detected before the main trough (initial peaks): + - "indices": array of all peak indices (in original template coordinates) + - "values": array of all peak values + - "prominences": array of all peak prominences + - "widths": array of all peak widths + - "main_idx": index of the main peak (most prominent) + - "main_loc": location (sample index) of the main peak in template + peaks_after : dict + Dictionary containing peaks detected after the main trough (repolarization peaks): + - "indices": array of all peak indices (in original template coordinates) + - "values": array of all peak values + - "prominences": array of all peak prominences + - "widths": array of all peak widths + - "main_idx": index of the main peak (most prominent) + - "main_loc": location (sample index) of the main peak in template """ + from scipy.signal import find_peaks, savgol_filter + assert template.ndim == 1 - trough_idx = np.argmin(template) - peak_idx = trough_idx + np.argmax(template[trough_idx:]) - return trough_idx, peak_idx + # Smooth template to reduce noise while preserving peaks using Savitzky-Golay filter + if smooth: + window_length = int(len(template) * smooth_window_frac) // 2 * 2 + 1 + window_length = max(smooth_polyorder + 2, window_length) # Must be > polyorder + template = savgol_filter(template, window_length=window_length, polyorder=smooth_polyorder) + + # Initialize empty result dictionaries + empty_dict = { + "indices": np.array([], dtype=int), + "values": np.array([]), + "prominences": np.array([]), + "widths": np.array([]), + "main_idx": None, + "main_loc": None, + } -######################################################################################### -# Single-channel metrics -def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: + # Get min prominence to detect peaks and troughs relative to template abs max value + min_prominence = min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) + + # --- Find troughs (by inverting waveform and using find_peaks) --- + trough_locs, trough_props = find_peaks(-template, prominence=min_prominence, width=0) + + if len(trough_locs) == 0: + # Fallback: use global minimum + trough_locs = np.array([np.nanargmin(template)]) + trough_props = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])} + + # Determine main trough (most prominent, or first if no valid prominences) + trough_prominences = trough_props.get("prominences", np.array([])) + if 
len(trough_prominences) > 0 and not np.all(np.isnan(trough_prominences)): + main_trough_idx = np.nanargmax(trough_prominences) + else: + main_trough_idx = 0 + + main_trough_loc = trough_locs[main_trough_idx] + + troughs = { + "indices": trough_locs, + "values": template[trough_locs], + "prominences": trough_props.get("prominences", np.full(len(trough_locs), np.nan)), + "widths": trough_props.get("widths", np.full(len(trough_locs), np.nan)), + "main_idx": main_trough_idx, + "main_loc": main_trough_loc, + } + + # --- Find peaks before the main trough --- + if main_trough_loc > 3: + template_before = template[:main_trough_loc] + + # Try with original prominence + peak_locs_before, peak_props_before = find_peaks(template_before, prominence=min_prominence, width=0) + + # If no peaks found, try with lower prominence (keep only max peak) + if len(peak_locs_before) == 0: + lower_prominence = 0.075 * min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) + peak_locs_before, peak_props_before = find_peaks(template_before, prominence=lower_prominence, width=0) + # Keep only the most prominent peak when using lower threshold + if len(peak_locs_before) > 1: + prominences = peak_props_before.get("prominences", np.array([])) + if len(prominences) > 0 and not np.all(np.isnan(prominences)): + max_idx = np.nanargmax(prominences) + peak_locs_before = np.array([peak_locs_before[max_idx]]) + peak_props_before = { + "prominences": np.array([prominences[max_idx]]), + "widths": np.array([peak_props_before.get("widths", np.array([np.nan]))[max_idx]]), + } + + # If still no peaks found, fall back to argmax + if len(peak_locs_before) == 0: + peak_locs_before = np.array([np.nanargmax(template_before)]) + peak_props_before = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])} + + peak_prominences_before = peak_props_before.get("prominences", np.array([])) + if len(peak_prominences_before) > 0 and not np.all(np.isnan(peak_prominences_before)): + main_peak_before_idx = np.nanargmax(peak_prominences_before) + else: + main_peak_before_idx = 0 + + peaks_before = { + "indices": peak_locs_before, + "values": template[peak_locs_before], + "prominences": peak_props_before.get("prominences", np.full(len(peak_locs_before), np.nan)), + "widths": peak_props_before.get("widths", np.full(len(peak_locs_before), np.nan)), + "main_idx": main_peak_before_idx, + "main_loc": peak_locs_before[main_peak_before_idx], + } + else: + peaks_before = empty_dict.copy() + + # --- Find peaks after the main trough (repolarization peaks) --- + if main_trough_loc < len(template) - 3: + template_after = template[main_trough_loc:] + + # Try with original prominence + peak_locs_after, peak_props_after = find_peaks(template_after, prominence=min_prominence, width=0) + + # If no peaks found, try with lower prominence (keep only max peak) + if len(peak_locs_after) == 0: + lower_prominence = 0.075 * min_thresh_detect_peaks_troughs * np.nanmax(np.abs(template)) + peak_locs_after, peak_props_after = find_peaks(template_after, prominence=lower_prominence, width=0) + # Keep only the most prominent peak when using lower threshold + if len(peak_locs_after) > 1: + prominences = peak_props_after.get("prominences", np.array([])) + if len(prominences) > 0 and not np.all(np.isnan(prominences)): + max_idx = np.nanargmax(prominences) + peak_locs_after = np.array([peak_locs_after[max_idx]]) + peak_props_after = { + "prominences": np.array([prominences[max_idx]]), + "widths": np.array([peak_props_after.get("widths", np.array([np.nan]))[max_idx]]), + } + 
+ # If still no peaks found, fall back to argmax + if len(peak_locs_after) == 0: + peak_locs_after = np.array([np.nanargmax(template_after)]) + peak_props_after = {"prominences": np.array([np.nan]), "widths": np.array([np.nan])} + + # Convert to original template coordinates + peak_locs_after_abs = peak_locs_after + main_trough_loc + + peak_prominences_after = peak_props_after.get("prominences", np.array([])) + if len(peak_prominences_after) > 0 and not np.all(np.isnan(peak_prominences_after)): + main_peak_after_idx = np.nanargmax(peak_prominences_after) + else: + main_peak_after_idx = 0 + + peaks_after = { + "indices": peak_locs_after_abs, + "values": template[peak_locs_after_abs], + "prominences": peak_props_after.get("prominences", np.full(len(peak_locs_after), np.nan)), + "widths": peak_props_after.get("widths", np.full(len(peak_locs_after), np.nan)), + "main_idx": main_peak_after_idx, + "main_loc": peak_locs_after_abs[main_peak_after_idx], + } + else: + peaks_after = empty_dict.copy() + + return troughs, peaks_before, peaks_after + + +def get_main_to_next_peak_duration(template, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs): """ - Return the peak to valley duration in seconds of input waveforms. + Calculate duration from the main extremum to the next extremum. + + The duration is measured from the largest absolute feature (main trough or main peak) + to the next extremum. For typical negative-first waveforms, this is trough-to-peak. + For positive-first waveforms, this is peak-to-trough. Parameters ---------- - template_single: numpy.ndarray + template : numpy.ndarray The 1D template waveform sampling_frequency : float - The sampling frequency of the template - trough_idx: int, default: None - The index of the trough - peak_idx: int, default: None - The index of the peak + The sampling frequency in Hz + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx Returns ------- - ptv: float - The peak to valley duration in seconds + main_to_next_peak_duration : float + Duration in seconds from main extremum to next extremum """ - if trough_idx is None or peak_idx is None: - trough_idx, peak_idx = get_trough_and_peak_idx(template_single) - ptv = (peak_idx - trough_idx) / sampling_frequency - return ptv + + # Get main locations and values + trough_loc = troughs["main_loc"] + trough_val = template[trough_loc] if trough_loc is not None else None + + peak_before_loc = peaks_before["main_loc"] + peak_before_val = template[peak_before_loc] if peak_before_loc is not None else None + + peak_after_loc = peaks_after["main_loc"] + peak_after_val = template[peak_after_loc] if peak_after_loc is not None else None + + # Find the main extremum (largest absolute value) + candidates = [] + if trough_loc is not None and trough_val is not None: + candidates.append(("trough", trough_loc, abs(trough_val))) + if peak_before_loc is not None and peak_before_val is not None: + candidates.append(("peak_before", peak_before_loc, abs(peak_before_val))) + if peak_after_loc is not None and peak_after_val is not None: + candidates.append(("peak_after", peak_after_loc, abs(peak_after_val))) + + if len(candidates) == 0: + return np.nan + + # Sort by absolute value to find main extremum + candidates.sort(key=lambda x: x[2], reverse=True) + main_type, main_loc, _ = candidates[0] + + # Find the next extremum after the main one + if 
main_type == "trough": + # Main is trough, next is peak_after + if peak_after_loc is not None: + duration_samples = abs(peak_after_loc - main_loc) + elif peak_before_loc is not None: + duration_samples = abs(main_loc - peak_before_loc) + else: + return np.nan + elif main_type == "peak_before": + # Main is peak before, next is trough + if trough_loc is not None: + duration_samples = abs(trough_loc - main_loc) + else: + return np.nan + else: # peak_after + # Main is peak after, previous is trough + if trough_loc is not None: + duration_samples = abs(main_loc - trough_loc) + else: + return np.nan + + # Convert to seconds + main_to_next_peak_duration = duration_samples / sampling_frequency + + return main_to_next_peak_duration + + +def get_waveform_ratios(template, troughs, peaks_before, peaks_after, **kwargs): + """ + Calculate various waveform amplitude ratios. + + Parameters + ---------- + template : numpy.ndarray + The 1D template waveform + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx + + Returns + ------- + ratios : dict + Dictionary containing: + - "peak_before_to_trough_ratio": ratio of peak before to trough amplitude + - "peak_after_to_trough_ratio": ratio of peak after to trough amplitude + - "peak_before_to_peak_after_ratio": ratio of peak before to peak after amplitude + - "main_peak_to_trough_ratio": ratio of larger peak to trough amplitude + """ + # Get absolute amplitudes + trough_amp = abs(template[troughs["main_loc"]]) if troughs["main_loc"] is not None else np.nan + peak_before_amp = abs(template[peaks_before["main_loc"]]) if peaks_before["main_loc"] is not None else np.nan + peak_after_amp = abs(template[peaks_after["main_loc"]]) if peaks_after["main_loc"] is not None else np.nan + + def safe_ratio(a, b): + if np.isnan(a) or np.isnan(b) or b == 0: + return np.nan + return a / b + + ratios = { + "peak_before_to_trough_ratio": safe_ratio(peak_before_amp, trough_amp), + "peak_after_to_trough_ratio": safe_ratio(peak_after_amp, trough_amp), + "peak_before_to_peak_after_ratio": safe_ratio(peak_before_amp, peak_after_amp), + "main_peak_to_trough_ratio": safe_ratio( + ( + max(peak_before_amp, peak_after_amp) + if not (np.isnan(peak_before_amp) and np.isnan(peak_after_amp)) + else np.nan + ), + trough_amp, + ), + } + + return ratios -def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=None, peak_idx=None, **kwargs) -> float: +def get_waveform_baseline_flatness(template, sampling_frequency, **kwargs): """ - Return the peak to trough ratio of input waveforms. + Compute the baseline flatness of the waveform. + + This metric measures the ratio of the max absolute amplitude in the baseline + window to the max absolute amplitude of the whole waveform. A lower value + indicates a flat baseline (expected for good units). + + Parameters + ---------- + template : numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency in Hz + **kwargs : Required kwargs: + - baseline_window_ms : tuple of (start_ms, end_ms) defining the baseline window + relative to waveform start. Default is (0, 0.5) for first 0.5ms. + + Returns + ------- + baseline_flatness : float + Ratio of max(abs(baseline)) / max(abs(waveform)). Lower = flatter baseline. 
+ """ + baseline_window_ms = kwargs.get("baseline_window_ms", (0.0, 0.5)) + + if baseline_window_ms is None: + return np.nan + + start_ms, end_ms = baseline_window_ms + start_idx = int(start_ms / 1000 * sampling_frequency) + end_idx = int(end_ms / 1000 * sampling_frequency) + + # Clamp to valid range + start_idx = max(0, start_idx) + end_idx = min(len(template), end_idx) + + if end_idx <= start_idx: + return np.nan + + baseline_segment = template[start_idx:end_idx] + + if len(baseline_segment) == 0: + return np.nan + + max_baseline = np.nanmax(np.abs(baseline_segment)) + max_waveform = np.nanmax(np.abs(template)) + + if max_waveform == 0 or np.isnan(max_waveform): + return np.nan + + baseline_flatness = max_baseline / max_waveform + + return baseline_flatness + + +def get_waveform_widths(template, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs): + """ + Get the widths of the main trough and peaks in seconds. + + Parameters + ---------- + template : numpy.ndarray + The 1D template waveform + sampling_frequency : float + The sampling frequency in Hz + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx + + Returns + ------- + widths : dict + Dictionary containing: + - "trough_width": width of main trough in seconds + - "peak_before_width": width of main peak before trough in seconds + - "peak_after_width": width of main peak after trough in seconds + """ + + def get_main_width(feature_dict): + if feature_dict["main_idx"] is None: + return np.nan + widths = feature_dict.get("widths", np.array([])) + if len(widths) == 0: + return np.nan + main_idx = feature_dict["main_idx"] + if main_idx < len(widths): + return widths[main_idx] + return np.nan + + # Convert from samples to seconds + samples_to_seconds = 1.0 / sampling_frequency + + trough_width = get_main_width(troughs) + peak_before_width = get_main_width(peaks_before) + peak_after_width = get_main_width(peaks_after) + + widths = { + "trough_width": trough_width * samples_to_seconds if not np.isnan(trough_width) else np.nan, + "peak_before_width": peak_before_width * samples_to_seconds if not np.isnan(peak_before_width) else np.nan, + "peak_after_width": peak_after_width * samples_to_seconds if not np.isnan(peak_after_width) else np.nan, + } + + return widths + + +######################################################################################### +# Single-channel metrics +def get_peak_to_valley(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: + """ + Return the peak to valley duration in seconds of input waveforms. 
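A short worked example of the baseline-flatness ratio defined above, on a synthetic template; the 30 kHz sampling rate, the (0, 0.5) ms window, and all amplitudes are illustrative:

.. code-block:: python

    import numpy as np

    sampling_frequency = 30_000.0
    rng = np.random.default_rng(1)
    template = np.zeros(90)
    template[40:45] = [-10.0, -60.0, -90.0, -50.0, -15.0]  # spike trough
    template[:15] += rng.normal(0, 2.0, 15)                # slightly noisy baseline

    start_idx = int(0.0 / 1000 * sampling_frequency)
    end_idx = int(0.5 / 1000 * sampling_frequency)          # first 0.5 ms = 15 samples
    baseline_flatness = np.max(np.abs(template[start_idx:end_idx])) / np.max(np.abs(template))
    # close to 0 for a flat baseline, approaching 1 for a noisy or clipped one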
Parameters ---------- @@ -75,13 +454,17 @@ def get_peak_trough_ratio(template_single, sampling_frequency=None, trough_idx=N Returns ------- - ptratio: float - The peak to trough ratio + ptv: float + The peak to valley duration in seconds """ if trough_idx is None or peak_idx is None: - trough_idx, peak_idx = get_trough_and_peak_idx(template_single) - ptratio = template_single[peak_idx] / template_single[trough_idx] - return ptratio + troughs, _, peaks_after = get_trough_and_peak_idx(template_single) + trough_idx = troughs["main_loc"] + peak_idx = peaks_after["main_loc"] + if trough_idx is None or peak_idx is None: + return np.nan + ptv = (peak_idx - trough_idx) / sampling_frequency + return ptv def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_idx=None, **kwargs) -> float: @@ -105,9 +488,11 @@ def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_id The half width in seconds """ if trough_idx is None or peak_idx is None: - trough_idx, peak_idx = get_trough_and_peak_idx(template_single) + troughs, _, peaks_after = get_trough_and_peak_idx(template_single) + trough_idx = troughs["main_loc"] + peak_idx = peaks_after["main_loc"] - if peak_idx == 0: + if peak_idx is None or peak_idx == 0: return np.nan trough_val = template_single[trough_idx] @@ -156,11 +541,12 @@ def get_repolarization_slope(template_single, sampling_frequency, trough_idx=Non The repolarization slope """ if trough_idx is None: - trough_idx, _ = get_trough_and_peak_idx(template_single) + troughs, _, _ = get_trough_and_peak_idx(template_single) + trough_idx = troughs["main_loc"] times = np.arange(template_single.shape[0]) / sampling_frequency - if trough_idx == 0: + if trough_idx is None or trough_idx == 0: return np.nan (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0) @@ -209,11 +595,12 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa assert "recovery_window_ms" in kwargs, "recovery_window_ms must be given as kwarg" recovery_window_ms = kwargs["recovery_window_ms"] if peak_idx is None: - _, peak_idx = get_trough_and_peak_idx(template_single) + _, _, peaks_after = get_trough_and_peak_idx(template_single) + peak_idx = peaks_after["main_loc"] times = np.arange(template_single.shape[0]) / sampling_frequency - if peak_idx == 0: + if peak_idx is None or peak_idx == 0: return np.nan max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency)) max_idx = np.min([max_idx, template_single.shape[0]]) @@ -222,9 +609,12 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa return res.slope -def get_number_of_peaks(template_single, sampling_frequency, **kwargs): +def get_number_of_peaks(template_single, sampling_frequency, troughs, peaks_before, peaks_after, **kwargs): """ - Count the total number of peaks (positive + negative) in the template. + Count the total number of peaks (positive) and troughs (negative) in the template. + + Uses the pre-computed peak/trough detection from get_trough_and_peak_idx which + applies smoothing for more robust detection. 
Parameters ---------- @@ -232,28 +622,28 @@ def get_number_of_peaks(template_single, sampling_frequency, **kwargs): The 1D template waveform sampling_frequency : float The sampling frequency of the template - **kwargs: Required kwargs: - - peak_relative_threshold: the relative threshold to detect positive and negative peaks - - peak_width_ms: the width in samples to detect peaks + troughs : dict + Trough detection results from get_trough_and_peak_idx + peaks_before : dict + Peak before trough results from get_trough_and_peak_idx + peaks_after : dict + Peak after trough results from get_trough_and_peak_idx Returns ------- - number_of_peaks: int - the total number of peaks (positive + negative) - """ - from scipy.signal import find_peaks - - assert "peak_relative_threshold" in kwargs, "peak_relative_threshold must be given as kwarg" - assert "peak_width_ms" in kwargs, "peak_width_ms must be given as kwarg" - peak_relative_threshold = kwargs["peak_relative_threshold"] - peak_width_ms = kwargs["peak_width_ms"] - max_value = np.max(np.abs(template_single)) - peak_width_samples = int(peak_width_ms / 1000 * sampling_frequency) - - pos_peaks = find_peaks(template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) - neg_peaks = find_peaks(-template_single, height=peak_relative_threshold * max_value, width=peak_width_samples) - num_positive = len(pos_peaks[0]) - num_negative = len(neg_peaks[0]) + num_positive_peaks : int + The number of positive peaks (peaks_before + peaks_after) + num_negative_peaks : int + The number of negative peaks (troughs) + """ + # Count peaks (positive) from peaks_before and peaks_after + num_peaks_before = len(peaks_before["indices"]) + num_peaks_after = len(peaks_after["indices"]) + num_positive = num_peaks_before + num_peaks_after + + # Count troughs (negative) + num_negative = len(troughs["indices"]) + return num_positive, num_negative @@ -293,8 +683,10 @@ def fit_velocity(peak_times, channel_dist): from sklearn.linear_model import TheilSenRegressor - theil = TheilSenRegressor() - theil.fit(peak_times.reshape(-1, 1), channel_dist) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=".*Maximum number of iterations.*") + theil = TheilSenRegressor(max_iter=1000) + theil.fit(peak_times.reshape(-1, 1), channel_dist) slope = theil.coef_[0] intercept = theil.intercept_ score = theil.score(peak_times.reshape(-1, 1), channel_dist) @@ -376,7 +768,11 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs): """ - Compute the exponential decay of the template amplitude over distance in units um/s. + Compute the spatial decay of the template amplitude over distance. + + Can fit either an exponential decay (with offset) or a linear decay model. Channels are first + filtered by x-distance tolerance from the max channel, then the closest channels + in y-distance are used for fitting. 
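As a standalone illustration of the exponential-with-offset model used below, here is a small ``curve_fit`` sketch on synthetic distances and amplitudes (not real probe data; the decay constant and offset are invented):

.. code-block:: python

    import numpy as np
    from scipy.optimize import curve_fit

    def exp_decay(x, decay, amp0, offset):
        return amp0 * np.exp(-decay * x) + offset

    distances = np.array([0.0, 20.0, 40.0, 60.0, 80.0, 100.0])  # um from the max channel
    amplitudes = 120.0 * np.exp(-0.03 * distances) + 5.0         # synthetic ptp amplitudes

    popt, _ = curve_fit(exp_decay, distances, amplitudes,
                        p0=[0.01, amplitudes[0], amplitudes.min()])
    print(popt[0])  # recovered decay constant, ~0.03 1/um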
Parameters ---------- @@ -387,13 +783,18 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs sampling_frequency : float The sampling frequency of the template **kwargs: Required kwargs: - - peak_function: the function to use to compute the peak amplitude for the exp decay ("ptp" or "min") - - min_r2: the minimum r2 to accept the exp decay fit + - peak_function: the function to use to compute the peak amplitude ("ptp" or "min") + - min_r2: the minimum r2 to accept the fit + - linear_fit: bool, if True use linear fit, otherwise exponential fit + - channel_tolerance: max x-distance (um) from max channel to include channels + - min_channels_for_fit: minimum number of valid channels required for fitting + - num_channels_for_fit: number of closest channels to use for fitting + - normalize_decay: bool, if True normalize amplitudes to max before fitting Returns ------- exp_decay_value : float - The exponential decay of the template amplitude + The spatial decay slope (decay constant for exp fit, negative slope for linear fit) """ from scipy.optimize import curve_fit from sklearn.metrics import r2_score @@ -401,41 +802,117 @@ def get_exp_decay(template, channel_locations, sampling_frequency=None, **kwargs def exp_decay(x, decay, amp0, offset): return amp0 * np.exp(-decay * x) + offset + def linear_fit_func(x, a, b): + return a * x + b + + # Extract parameters assert "peak_function" in kwargs, "peak_function must be given as kwarg" peak_function = kwargs["peak_function"] assert "min_r2" in kwargs, "min_r2 must be given as kwarg" min_r2 = kwargs["min_r2"] - # exp decay fit + + use_linear_fit = kwargs.get("linear_fit", False) + channel_tolerance = kwargs.get("channel_tolerance", None) + normalize_decay = kwargs.get("normalize_decay", False) + + # Set defaults based on fit type if not specified + min_channels_for_fit = kwargs.get("min_channels_for_fit") + if min_channels_for_fit is None: + min_channels_for_fit = 5 if use_linear_fit else 8 + + num_channels_for_fit = kwargs.get("num_channels_for_fit") + if num_channels_for_fit is None: + num_channels_for_fit = 6 if use_linear_fit else 10 + + # Compute peak amplitudes per channel if peak_function == "ptp": fun = np.ptp elif peak_function == "min": fun = np.min + else: + fun = np.ptp + peak_amplitudes = np.abs(fun(template, axis=0)) - max_channel_location = channel_locations[np.argmax(peak_amplitudes)] - channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) - distances_sort_indices = np.argsort(channel_distances) + max_channel_idx = np.argmax(peak_amplitudes) + max_channel_location = channel_locations[max_channel_idx] - # longdouble is float128 when the platform supports it, otherwise it is float64 - channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.longdouble) - peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.longdouble) + # Channel selection based on tolerance (new bombcell-style) or use all channels (old style) + if channel_tolerance is not None: + # Calculate x-distances from max channel + x_dist = np.abs(channel_locations[:, 0] - max_channel_location[0]) - try: - amp0 = peak_amplitudes_sorted[0] - offset0 = np.min(peak_amplitudes_sorted) - - popt, _ = curve_fit( - exp_decay, - channel_distances_sorted, - peak_amplitudes_sorted, - bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), - p0=[1e-3, peak_amplitudes_sorted[0], offset0], + # Find channels within x-distance tolerance + valid_x_channels = 
np.argwhere(x_dist <= channel_tolerance).flatten() + + if len(valid_x_channels) < min_channels_for_fit: + return np.nan + + # Calculate y-distances for channel selection + y_dist = np.abs(channel_locations[:, 1] - max_channel_location[1]) + + # Set y distances to max for channels outside x tolerance (so they won't be selected) + y_dist_masked = y_dist.copy() + y_dist_masked[~np.isin(np.arange(len(y_dist)), valid_x_channels)] = y_dist.max() + 1 + + # Select the closest channels in y-distance + use_these_channels = np.argsort(y_dist_masked)[:num_channels_for_fit] + + # Calculate distances from max channel for selected channels + channel_distances = np.sqrt( + np.sum(np.square(channel_locations[use_these_channels] - max_channel_location), axis=1) ) - r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) - exp_decay_value = popt[0] + + # Get amplitudes for selected channels + spatial_decay_points = np.max(np.abs(template[:, use_these_channels]), axis=0) + + # Sort by distance + sort_idx = np.argsort(channel_distances) + channel_distances_sorted = channel_distances[sort_idx] + peak_amplitudes_sorted = spatial_decay_points[sort_idx] + + # Normalize if requested + if normalize_decay: + peak_amplitudes_sorted = peak_amplitudes_sorted / np.max(peak_amplitudes_sorted) + + # Ensure float64 for numerical stability + channel_distances_sorted = np.float64(channel_distances_sorted) + peak_amplitudes_sorted = np.float64(peak_amplitudes_sorted) + + else: + # Old style: use all channels sorted by distance + channel_distances = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations]) + distances_sort_indices = np.argsort(channel_distances) + + # longdouble is float128 when the platform supports it, otherwise it is float64 + channel_distances_sorted = channel_distances[distances_sort_indices].astype(np.longdouble) + peak_amplitudes_sorted = peak_amplitudes[distances_sort_indices].astype(np.longdouble) + + try: + if use_linear_fit: + # Linear fit: y = a*x + b + popt, _ = curve_fit(linear_fit_func, channel_distances_sorted, peak_amplitudes_sorted) + predicted = linear_fit_func(channel_distances_sorted, *popt) + r2 = r2_score(peak_amplitudes_sorted, predicted) + exp_decay_value = -popt[0] # Negative of slope + else: + # Exponential fit with offset: y = amp0 * exp(-decay * x) + offset + amp0 = peak_amplitudes_sorted[0] + offset0 = np.min(peak_amplitudes_sorted) + + popt, _ = curve_fit( + exp_decay, + channel_distances_sorted, + peak_amplitudes_sorted, + bounds=([1e-5, amp0 - 0.5 * amp0, 0], [2, amp0 + 0.5 * amp0, 2 * offset0]), + p0=[1e-3, peak_amplitudes_sorted[0], offset0], + ) + r2 = r2_score(peak_amplitudes_sorted, exp_decay(channel_distances_sorted, *popt)) + exp_decay_value = popt[0] if r2 < min_r2: exp_decay_value = np.nan - except: + + except Exception: exp_decay_value = np.nan return exp_decay_value @@ -512,17 +989,17 @@ def single_channel_metric(unit_function, sorting_analyzer, unit_ids, tmp_data, * return result -class PeakToValley(BaseMetric): - metric_name = "peak_to_valley" +class PeakToTroughDuration(BaseMetric): + metric_name = "peak_to_trough_duration" metric_params = {} - metric_columns = {"peak_to_valley": float} + metric_columns = {"peak_to_trough_duration": float} metric_descriptions = { - "peak_to_valley": "Duration in s between the trough (minimum) and the peak (maximum) of the spike waveform." + "peak_to_trough_duration": "Duration in seconds between the trough (minimum) and the peak (maximum) of the spike waveform." 
} needs_tmp_data = True @staticmethod - def _peak_to_valley_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + def _peak_to_trough_duration_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): return single_channel_metric( unit_function=get_peak_to_valley, sorting_analyzer=sorting_analyzer, @@ -531,29 +1008,7 @@ def _peak_to_valley_metric_function(sorting_analyzer, unit_ids, tmp_data, **metr **metric_params, ) - metric_function = _peak_to_valley_metric_function - - -class PeakToTroughRatio(BaseMetric): - metric_name = "peak_trough_ratio" - metric_params = {} - metric_columns = {"peak_trough_ratio": float} - metric_descriptions = { - "peak_trough_ratio": "Ratio of the amplitude of the peak (maximum) to the trough (minimum) of the spike waveform." - } - needs_tmp_data = True - - @staticmethod - def _peak_to_trough_ratio_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): - return single_channel_metric( - unit_function=get_peak_trough_ratio, - sorting_analyzer=sorting_analyzer, - unit_ids=unit_ids, - tmp_data=tmp_data, - **metric_params, - ) - - metric_function = _peak_to_trough_ratio_metric_function + metric_function = _peak_to_trough_duration_metric_function class HalfWidth(BaseMetric): @@ -623,14 +1078,26 @@ def _recovery_slope_metric_function(sorting_analyzer, unit_ids, tmp_data, **metr def _number_of_peaks_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + from collections import namedtuple + num_peaks_result = namedtuple("NumberOfPeaksResult", ["num_positive_peaks", "num_negative_peaks"]) num_positive_peaks_dict = {} num_negative_peaks_dict = {} - sampling_frequency = sorting_analyzer.sampling_frequency + sampling_frequency = tmp_data["sampling_frequency"] templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] for unit_index, unit_id in enumerate(unit_ids): template_single = templates_single[unit_index] - num_positive, num_negative = get_number_of_peaks(template_single, sampling_frequency, **metric_params) + num_positive, num_negative = get_number_of_peaks( + template_single, + sampling_frequency, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, + ) num_positive_peaks_dict[unit_id] = num_positive num_negative_peaks_dict[unit_id] = num_negative return num_peaks_result(num_positive_peaks=num_positive_peaks_dict, num_negative_peaks=num_negative_peaks_dict) @@ -639,26 +1106,192 @@ def _number_of_peaks_metric_function(sorting_analyzer, unit_ids, tmp_data, **met class NumberOfPeaks(BaseMetric): metric_name = "number_of_peaks" metric_function = _number_of_peaks_metric_function - metric_params = {"peak_relative_threshold": 0.2, "peak_width_ms": 0.1} + metric_params = {} metric_columns = {"num_positive_peaks": int, "num_negative_peaks": int} metric_descriptions = { "num_positive_peaks": "Number of positive peaks in the template", - "num_negative_peaks": "Number of negative peaks in the template", + "num_negative_peaks": "Number of negative peaks (troughs) in the template", + } + needs_tmp_data = True + + +def _main_to_next_peak_duration_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + result = {} + templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] + 
sampling_frequency = tmp_data["sampling_frequency"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + value = get_main_to_next_peak_duration( + template_single, + sampling_frequency, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, + ) + result[unit_id] = value + return result + + +class MainToNextPeakDuration(BaseMetric): + metric_name = "main_to_next_peak_duration" + metric_function = _main_to_next_peak_duration_metric_function + metric_params = {} + metric_columns = {"main_to_next_peak_duration": float} + metric_descriptions = {"main_to_next_peak_duration": "Duration in seconds from main extremum to next extremum."} + needs_tmp_data = True + + +def _waveform_ratios_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + from collections import namedtuple + + waveform_ratios_result = namedtuple( + "WaveformRatiosResult", + [ + "peak_before_to_trough_ratio", + "peak_after_to_trough_ratio", + "peak_before_to_peak_after_ratio", + "main_peak_to_trough_ratio", + ], + ) + peak_before_to_trough = {} + peak_after_to_trough = {} + peak_before_to_peak_after = {} + main_peak_to_trough = {} + templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + ratios = get_waveform_ratios( + template_single, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, + ) + peak_before_to_trough[unit_id] = ratios["peak_before_to_trough_ratio"] + peak_after_to_trough[unit_id] = ratios["peak_after_to_trough_ratio"] + peak_before_to_peak_after[unit_id] = ratios["peak_before_to_peak_after_ratio"] + main_peak_to_trough[unit_id] = ratios["main_peak_to_trough_ratio"] + return waveform_ratios_result( + peak_before_to_trough_ratio=peak_before_to_trough, + peak_after_to_trough_ratio=peak_after_to_trough, + peak_before_to_peak_after_ratio=peak_before_to_peak_after, + main_peak_to_trough_ratio=main_peak_to_trough, + ) + + +class WaveformRatios(BaseMetric): + metric_name = "waveform_ratios" + metric_function = _waveform_ratios_metric_function + metric_params = {} + metric_columns = { + "peak_before_to_trough_ratio": float, + "peak_after_to_trough_ratio": float, + "peak_before_to_peak_after_ratio": float, + "main_peak_to_trough_ratio": float, + } + metric_descriptions = { + "peak_before_to_trough_ratio": "Ratio of peak before amplitude to trough amplitude", + "peak_after_to_trough_ratio": "Ratio of peak after amplitude to trough amplitude", + "peak_before_to_peak_after_ratio": "Ratio of peak before amplitude to peak after amplitude", + "main_peak_to_trough_ratio": "Ratio of main peak amplitude to trough amplitude", + } + needs_tmp_data = True + + +def _waveform_widths_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + from collections import namedtuple + + waveform_widths_result = namedtuple( + "WaveformWidthsResult", ["trough_width", "peak_before_width", "peak_after_width"] + ) + trough_width_dict = {} + peak_before_width_dict = {} + peak_after_width_dict = {} + templates_single = tmp_data["templates_single"] + troughs_info = tmp_data["troughs_info"] + peaks_before_info = tmp_data["peaks_before_info"] + peaks_after_info = tmp_data["peaks_after_info"] + sampling_frequency = tmp_data["sampling_frequency"] + 
for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + widths = get_waveform_widths( + template_single, + sampling_frequency, + troughs_info[unit_id], + peaks_before_info[unit_id], + peaks_after_info[unit_id], + **metric_params, + ) + trough_width_dict[unit_id] = widths["trough_width"] + peak_before_width_dict[unit_id] = widths["peak_before_width"] + peak_after_width_dict[unit_id] = widths["peak_after_width"] + return waveform_widths_result( + trough_width=trough_width_dict, peak_before_width=peak_before_width_dict, peak_after_width=peak_after_width_dict + ) + + +class WaveformWidths(BaseMetric): + metric_name = "waveform_widths" + metric_function = _waveform_widths_metric_function + metric_params = {} + metric_columns = { + "trough_width": float, + "peak_before_width": float, + "peak_after_width": float, + } + metric_descriptions = { + "trough_width": "Width of the main trough in seconds", + "peak_before_width": "Width of the main peak before trough in seconds", + "peak_after_width": "Width of the main peak after trough in seconds", + } + needs_tmp_data = True + + +def _waveform_baseline_flatness_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + result = {} + templates_single = tmp_data["templates_single"] + sampling_frequency = tmp_data["sampling_frequency"] + for unit_index, unit_id in enumerate(unit_ids): + template_single = templates_single[unit_index] + value = get_waveform_baseline_flatness(template_single, sampling_frequency, **metric_params) + result[unit_id] = value + return result + + +class WaveformBaselineFlatness(BaseMetric): + metric_name = "waveform_baseline_flatness" + metric_function = _waveform_baseline_flatness_metric_function + metric_params = {"baseline_window_ms": (0.0, 0.5)} + metric_columns = {"waveform_baseline_flatness": float} + metric_descriptions = { + "waveform_baseline_flatness": "Ratio of max baseline amplitude to max waveform amplitude. Lower = flatter baseline." } needs_tmp_data = True single_channel_metrics = [ - PeakToValley, - PeakToTroughRatio, + PeakToTroughDuration, HalfWidth, RepolarizationSlope, RecoverySlope, NumberOfPeaks, + MainToNextPeakDuration, + WaveformRatios, + WaveformWidths, + WaveformBaselineFlatness, ] def _get_velocity_fits_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_params): + from collections import namedtuple + velocity_above_result = namedtuple("Velocities", ["velocity_above", "velocity_below"]) velocity_above_dict = {} velocity_below_dict = {} @@ -707,10 +1340,21 @@ def multi_channel_metric(unit_function, sorting_analyzer, unit_ids, tmp_data, ** class ExpDecay(BaseMetric): metric_name = "exp_decay" - metric_params = {"peak_function": "ptp", "min_r2": 0.2} + metric_params = { + "peak_function": "ptp", + "min_r2": 0.2, + "linear_fit": False, + "channel_tolerance": None, # None uses old style (all channels), set to e.g. 33 for bombcell-style + "min_channels_for_fit": None, # None means use default based on linear_fit (5 for linear, 8 for exp) + "num_channels_for_fit": None, # None means use default based on linear_fit (6 for linear, 10 for exp) + "normalize_decay": False, + } metric_columns = {"exp_decay": float} metric_descriptions = { - "exp_decay": ("Exponential decay of the template amplitude over distance from the extremum channel (1/um).") + "exp_decay": ( + "Spatial decay of the template amplitude over distance from the extremum channel (1/um). " + "Uses exponential or linear fit based on linear_fit parameter." 
+ ) } needs_tmp_data = True @@ -729,7 +1373,7 @@ def _exp_decay_metric_function(sorting_analyzer, unit_ids, tmp_data, **metric_pa class Spread(BaseMetric): metric_name = "spread" - metric_params = {"spread_threshold": 0.5, "spread_smooth_um": 20, "column_range": None} + metric_params = {"spread_threshold": 0.2, "spread_smooth_um": 20, "column_range": None} metric_columns = {"spread": float} metric_descriptions = { "spread": ( diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 85ef9e22cb..212369ca12 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -6,9 +6,8 @@ from __future__ import annotations -import numpy as np import warnings -from copy import deepcopy +import numpy as np from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.analyzer_extension_core import BaseMetricExtension @@ -33,6 +32,8 @@ def get_template_metric_list(): def get_template_metric_names(): + import warnings + warnings.warn( "get_template_metric_names is deprecated and will be removed in a version 0.105.0. " "Please use get_template_metric_list instead.", @@ -45,8 +46,8 @@ def get_template_metric_names(): class ComputeTemplateMetrics(BaseMetricExtension): """ Compute template metrics including: - * peak_to_valley - * peak_trough_ratio + * peak_to_trough_duration + * peak_to_trough_ratio * halfwidth * repolarization_slope * recovery_slope @@ -95,6 +96,8 @@ class ComputeTemplateMetrics(BaseMetricExtension): metric_list = single_channel_metrics + multi_channel_metrics def _handle_backward_compatibility_on_load(self): + from copy import deepcopy + # For backwards compatibility - this reformats metrics_kwargs as metric_params if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: @@ -106,24 +109,52 @@ def _handle_backward_compatibility_on_load(self): del self.params["metrics_kwargs"] # handle metric names change: - # num_positive_peaks/num_negative_peaks merged into number_of_peaks if "num_positive_peaks" in self.params["metric_names"]: self.params["metric_names"].remove("num_positive_peaks") if "number_of_peaks" not in self.params["metric_names"]: self.params["metric_names"].append("number_of_peaks") + if "num_positive_peaks" in self.params["metric_params"]: + del self.params["metric_params"]["num_positive_peaks"] if "num_negative_peaks" in self.params["metric_names"]: self.params["metric_names"].remove("num_negative_peaks") if "number_of_peaks" not in self.params["metric_names"]: self.params["metric_names"].append("number_of_peaks") + if "num_negative_peaks" in self.params["metric_params"]: + del self.params["metric_params"]["num_negative_peaks"] # velocity_above/velocity_below merged into velocity_fits if "velocity_above" in self.params["metric_names"]: self.params["metric_names"].remove("velocity_above") if "velocity_fits" not in self.params["metric_names"]: self.params["metric_names"].append("velocity_fits") + self.params["metric_params"]["velocity_fits"] = self.params["metric_params"]["velocity_above"] + self.params["metric_params"]["velocity_fits"]["min_channels"] = self.params["metric_params"][ + "velocity_above" + ]["min_channels_for_velocity"] + self.params["metric_params"]["velocity_fits"]["min_r2"] = self.params["metric_params"]["velocity_above"][ + "min_r2_velocity" + ] + del self.params["metric_params"]["velocity_above"] if "velocity_below" in self.params["metric_names"]: 
self.params["metric_names"].remove("velocity_below") if "velocity_fits" not in self.params["metric_names"]: self.params["metric_names"].append("velocity_fits") + # parameters are already updated from velocity_above + if "velocity_below" in self.params["metric_params"]: + del self.params["metric_params"]["velocity_below"] + # peak to valley -> peak_to_trough_duration + if "peak_to_valley" in self.params["metric_names"]: + self.params["metric_names"].remove("peak_to_valley") + if "peak_to_trough_duration" not in self.params["metric_names"]: + self.params["metric_names"].append("peak_to_trough_duration") + # peak_trough ratio -> main peak to trough ratio + # note that the new implementation correctly uses the absolute peak values, + # which is different from the old implementation. + # we make a flag to invert the polarity of old values if needed + if "peak_trough_ratio" in self.params["metric_names"]: + self.params["metric_names"].remove("peak_trough_ratio") + if "waveform_ratios" not in self.params["metric_names"]: + self.params["metric_names"].append("waveform_ratios") + self.params["metric_params"]["invert_peak_to_trough"] = True def _set_params( self, @@ -137,6 +168,10 @@ def _set_params( upsampling_factor=10, include_multi_channel_metrics=False, depth_direction="y", + min_thresh_detect_peaks_troughs=0.4, + smooth=True, + smooth_window_frac=0.1, + smooth_polyorder=3, ): # Auto-detect if multi-channel metrics should be included based on number of channels num_channels = self.sorting_analyzer.get_num_channels() @@ -166,9 +201,15 @@ def _set_params( upsampling_factor=upsampling_factor, include_multi_channel_metrics=include_multi_channel_metrics, depth_direction=depth_direction, + min_thresh_detect_peaks_troughs=min_thresh_detect_peaks_troughs, + smooth=smooth, + smooth_window_frac=smooth_window_frac, + smooth_polyorder=smooth_polyorder, ) def _prepare_data(self, sorting_analyzer, unit_ids): + import warnings + from scipy.signal import resample_poly # compute templates_single and templates_multi (if include_multi_channel_metrics is True) @@ -197,6 +238,9 @@ def _prepare_data(self, sorting_analyzer, unit_ids): templates_single = [] troughs = {} peaks = {} + troughs_info = {} + peaks_before_info = {} + peaks_after_info = {} templates_multi = [] channel_locations_multi = [] for unit_id in unit_ids: @@ -210,11 +254,22 @@ def _prepare_data(self, sorting_analyzer, unit_ids): else: template_upsampled = template_single sampling_frequency_up = sampling_frequency - trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled) + troughs_dict, peaks_before_dict, peaks_after_dict = get_trough_and_peak_idx( + template_upsampled, + min_thresh_detect_peaks_troughs=self.params["min_thresh_detect_peaks_troughs"], + smooth=self.params["smooth"], + smooth_window_frac=self.params["smooth_window_frac"], + smooth_polyorder=self.params["smooth_polyorder"], + ) templates_single.append(template_upsampled) - troughs[unit_id] = trough_idx - peaks[unit_id] = peak_idx + # Store main locations for backward compatibility + troughs[unit_id] = troughs_dict["main_loc"] + peaks[unit_id] = peaks_after_dict["main_loc"] + # Store full dicts for new metrics + troughs_info[unit_id] = troughs_dict + peaks_before_info[unit_id] = peaks_before_dict + peaks_after_info[unit_id] = peaks_after_dict if include_multi_channel_metrics: if sorting_analyzer.is_sparse(): @@ -239,6 +294,9 @@ def _prepare_data(self, sorting_analyzer, unit_ids): tmp_data["troughs"] = troughs tmp_data["peaks"] = peaks + tmp_data["troughs_info"] = troughs_info + 
tmp_data["peaks_before_info"] = peaks_before_info + tmp_data["peaks_after_info"] = peaks_after_info tmp_data["templates_single"] = np.array(templates_single) if include_multi_channel_metrics: @@ -249,6 +307,18 @@ def _prepare_data(self, sorting_analyzer, unit_ids): return tmp_data + def get_data(self, *args, **kwargs): + """Override to handle deprecated polarity of 'peak_trough_ratio' metric.""" + metrics = super().get_data(*args, **kwargs) + if self.params["metric_params"].get("invert_peak_to_trough", False): + if "peak_trough_ratio" in metrics.columns: + warnings.warn( + "The 'peak_trough_ratio' metric has been deprecated and replaced by 'main_peak_to_trough_ratio'. " + "The values have been inverted to maintain consistency with previous versions." + ) + metrics["peak_trough_ratio"] = -metrics["peak_trough_ratio"] + return metrics + register_result_extension(ComputeTemplateMetrics) compute_template_metrics = ComputeTemplateMetrics.function_factory() @@ -273,6 +343,8 @@ def get_default_tm_params(metric_names=None): metric_params : dict Dictionary with default parameters for template metrics. """ + import warnings + warnings.warn( "get_default_tm_params is deprecated and will be removed in a version 0.105.0. " "Please use get_default_template_metrics_params instead.", diff --git a/src/spikeinterface/metrics/template/tests/test_template_metrics.py b/src/spikeinterface/metrics/template/tests/test_template_metrics.py index 8f1bf05b85..66437b156e 100644 --- a/src/spikeinterface/metrics/template/tests/test_template_metrics.py +++ b/src/spikeinterface/metrics/template/tests/test_template_metrics.py @@ -83,7 +83,7 @@ def test_metric_names_in_same_order(small_sorting_analyzer): """ Computes sepecified template metrics and checks order is propagated. """ - specified_metric_names = ["peak_trough_ratio", "half_width", "peak_to_valley"] + specified_metric_names = ["main_peak_to_trough_ratio", "half_width", "peak_to_valley"] small_sorting_analyzer.compute( "template_metrics", metric_names=specified_metric_names, delete_existing_metrics=True ) diff --git a/src/spikeinterface/widgets/bombcell_curation.py b/src/spikeinterface/widgets/bombcell_curation.py new file mode 100644 index 0000000000..030cb67327 --- /dev/null +++ b/src/spikeinterface/widgets/bombcell_curation.py @@ -0,0 +1,389 @@ +"""Widgets for visualizing unit labeling results.""" + +from __future__ import annotations + +import numpy as np +from typing import Optional + +from .base import BaseWidget, to_attr + +from .unit_labels import WaveformOverlayByLabelWidget + + +def _is_threshold_disabled(value): + """Check if a threshold value is disabled (None or np.nan).""" + if value is None: + return True + if isinstance(value, float) and np.isnan(value): + return True + return False + + +class LabelingHistogramsWidget(BaseWidget): + """Plot histograms of quality metrics with threshold lines.""" + + def __init__( + self, + sorting_analyzer, + thresholds: Optional[dict] = None, + metrics_to_plot: Optional[list] = None, + backend=None, + **backend_kwargs, + ): + from spikeinterface.curation import bombcell_get_default_thresholds + + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + combined_metrics = sorting_analyzer.get_metrics_extension_data() + if combined_metrics.empty: + raise ValueError( + "SortingAnalyzer has no metrics extensions computed. " + "Compute quality_metrics and/or template_metrics first." 
+ ) + + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + if metrics_to_plot is None: + metrics_to_plot = [m for m in thresholds.keys() if m in combined_metrics.columns] + + plot_data = dict( + quality_metrics=combined_metrics, + thresholds=thresholds, + metrics_to_plot=metrics_to_plot, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + from .utils_matplotlib import make_mpl_figure + import matplotlib.pyplot as plt + + dp = to_attr(data_plot) + quality_metrics = dp.quality_metrics + thresholds = dp.thresholds + metrics_to_plot = dp.metrics_to_plot + + n_metrics = len(metrics_to_plot) + if n_metrics == 0: + print("No metrics to plot") + return + + n_cols = min(4, n_metrics) + n_rows = int(np.ceil(n_metrics / n_cols)) + backend_kwargs["ncols"] = n_cols + if "figsize" not in backend_kwargs: + backend_kwargs["figsize"] = (4 * n_cols, 3 * n_rows) + self.figure, self.ax, self.axes = make_mpl_figure(n_rows, n_cols, **backend_kwargs) + + colors = plt.cm.tab10(np.linspace(0, 1, 10)) + absolute_value_metrics = ["amplitude_median"] + + axes = self.axes + for idx, metric_name in enumerate(metrics_to_plot): + row, col = idx // n_cols, idx % n_cols + ax = axes[row, col] + + values = quality_metrics[metric_name].values + if metric_name in absolute_value_metrics: + values = np.abs(values) + values = values[~np.isnan(values) & ~np.isinf(values)] + + if len(values) == 0: + ax.set_title(f"{metric_name}\n(no valid data)") + continue + + ax.hist(values, bins=30, color=colors[idx % 10], alpha=0.7, edgecolor="black", density=True) + + thresh = thresholds.get(metric_name, {}) + has_thresh = False + if not _is_threshold_disabled(thresh.get("min", None)): + ax.axvline(thresh["min"], color="red", ls="--", lw=2, label=f"min={thresh['min']:.2g}") + has_thresh = True + if not _is_threshold_disabled(thresh.get("max", None)): + ax.axvline(thresh["max"], color="blue", ls="--", lw=2, label=f"max={thresh['max']:.2g}") + has_thresh = True + + ax.set_xlabel(metric_name) + ax.set_ylabel("Density") + if has_thresh: + ax.legend(fontsize=8, loc="upper right") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + for idx in range(len(metrics_to_plot), n_rows * n_cols): + axes[idx // n_cols, idx % n_cols].set_visible(False) + + +class UpsetPlotWidget(BaseWidget): + """ + Plot UpSet plots showing which metrics fail together for each unit type. + + Requires `upsetplot` package. Each unit type shows relevant metrics: + NOISE -> waveform metrics, MUA -> spike quality metrics, NON_SOMA -> non-somatic metrics. + """ + + def __init__( + self, + sorting_analyzer, + unit_type: np.ndarray, + unit_type_string: np.ndarray, + thresholds: Optional[dict] = None, + unit_types_to_plot: Optional[list] = None, + split_non_somatic: bool = False, + min_subset_size: int = 1, + backend=None, + **backend_kwargs, + ): + from spikeinterface.curation import bombcell_get_default_thresholds + + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + combined_metrics = sorting_analyzer.get_metrics_extension_data() + if combined_metrics.empty: + raise ValueError( + "SortingAnalyzer has no metrics extensions computed. " + "Compute quality_metrics and/or template_metrics first." 
+ ) + + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + if unit_types_to_plot is None: + if split_non_somatic: + unit_types_to_plot = ["noise", "mua", "non_soma_good", "non_soma_mua"] + else: + unit_types_to_plot = ["noise", "mua", "non_soma"] + + plot_data = dict( + quality_metrics=combined_metrics, + unit_type=unit_type, + unit_type_string=unit_type_string, + thresholds=thresholds, + unit_types_to_plot=unit_types_to_plot, + min_subset_size=min_subset_size, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def _get_metrics_for_unit_type(self, unit_type_label): + from spikeinterface.curation.bombcell_curation import ( + NOISE_METRICS, + SPIKE_QUALITY_METRICS, + NON_SOMATIC_METRICS, + ) + + if unit_type_label == "noise": + return NOISE_METRICS + elif unit_type_label == "mua": + return SPIKE_QUALITY_METRICS + elif unit_type_label in ("non_soma", "non_soma_good", "non_soma_mua"): + return NON_SOMATIC_METRICS + return None + + def plot_matplotlib(self, data_plot, **backend_kwargs): + from .utils_matplotlib import make_mpl_figure + import warnings + import matplotlib.pyplot as plt + import pandas as pd + + dp = to_attr(data_plot) + quality_metrics = dp.quality_metrics + unit_type_string = dp.unit_type_string + thresholds = dp.thresholds + unit_types_to_plot = dp.unit_types_to_plot + min_subset_size = dp.min_subset_size + + try: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning, module="upsetplot") + from upsetplot import UpSet, from_memberships + except ImportError: + fig, ax = plt.subplots(1, 1, figsize=(10, 6)) + ax.text( + 0.5, + 0.5, + "UpSet plots require 'upsetplot' package.\n\npip install upsetplot", + ha="center", + va="center", + fontsize=14, + family="monospace", + bbox=dict(boxstyle="round", facecolor="lightyellow", edgecolor="orange"), + ) + ax.axis("off") + ax.set_title("UpSet Plot - Package Not Installed", fontsize=16) + self.figure = fig + self.axes = ax + self.figures = [fig] + return + + failure_table = self._build_failure_table(quality_metrics, thresholds) + figures = [] + axes_list = [] + + for unit_type_label in unit_types_to_plot: + mask = unit_type_string == unit_type_label + n_units = np.sum(mask) + if n_units == 0: + continue + + relevant_metrics = self._get_metrics_for_unit_type(unit_type_label) + if relevant_metrics is not None: + available_metrics = [m for m in relevant_metrics if m in failure_table.columns] + if len(available_metrics) == 0: + continue + unit_failure_table = failure_table[available_metrics] + else: + unit_failure_table = failure_table + + unit_failures = unit_failure_table.loc[mask] + memberships = [] + for idx in unit_failures.index: + failed = unit_failures.columns[unit_failures.loc[idx]].tolist() + if failed: + memberships.append(failed) + + if not memberships: + continue + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning, module="upsetplot") + upset_data = from_memberships(memberships) + upset_data = upset_data[upset_data >= min_subset_size] + if len(upset_data) == 0: + continue + + fig = plt.figure(figsize=(12, 6)) + UpSet( + upset_data, + subset_size="count", + show_counts=True, + sort_by="cardinality", + sort_categories_by="cardinality", + ).plot(fig=fig) + fig.suptitle(f"{unit_type_label} (n={n_units})", fontsize=14, y=1.02) + figures.append(fig) + axes_list.append(fig.axes) + + if not figures: + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.text(0.5, 0.5, "No units found or no metric failures 
detected.", ha="center", va="center", fontsize=12) + ax.axis("off") + figures = [fig] + axes_list = [ax] + + self.figures = figures + self.figure = figures[0] if figures else None + self.axes = axes_list + + def _build_failure_table(self, quality_metrics, thresholds): + import pandas as pd + + absolute_value_metrics = ["amplitude_median"] + failure_data = {} + + for metric_name, thresh in thresholds.items(): + if metric_name not in quality_metrics.columns: + continue + values = quality_metrics[metric_name].values.copy() + if metric_name in absolute_value_metrics: + values = np.abs(values) + + failed = np.isnan(values) + if not _is_threshold_disabled(thresh.get("min", None)): + failed |= values < thresh["min"] + if not _is_threshold_disabled(thresh.get("max", None)): + failed |= values > thresh["max"] + failure_data[metric_name] = failed + + return pd.DataFrame(failure_data, index=quality_metrics.index) + + +def plot_unit_labeling_all( + sorting_analyzer, + unit_type: np.ndarray, + unit_type_string: np.ndarray, + thresholds: Optional[dict] = None, + split_non_somatic: bool = False, + include_upset: bool = True, + save_folder=None, + backend=None, + **kwargs, +): + """ + Generate all unit labeling plots and optionally save to folder. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object with computed metrics extensions. + unit_type : np.ndarray + Array of unit type codes (0=noise, 1=good, 2=mua, 3=non_soma, etc.). + unit_type_string : np.ndarray + Array of unit type labels as strings. + thresholds : dict, optional + Threshold dictionary. If None, uses default thresholds. + split_non_somatic : bool, default: False + Whether to split "non_soma" into "non_soma_good" and "non_soma_mua". + include_upset : bool, default: True + Whether to include UpSet plots (requires upsetplot package). + save_folder : str or Path, optional + If provided, saves all plots and CSV results to this folder. + backend : str, optional + Plotting backend. + **kwargs + Additional arguments passed to plot functions. + + Returns + ------- + dict + Dictionary with keys 'histograms', 'waveforms', 'upset' containing widget objects. 
+ """ + from pathlib import Path + from spikeinterface.curation import bombcell_get_default_thresholds, save_bombcell_results + + if thresholds is None: + thresholds = bombcell_get_default_thresholds() + + combined_metrics = sorting_analyzer.get_metrics_extension_data() + has_metrics = not combined_metrics.empty + + results = {} + + # Histograms + if has_metrics: + results["histograms"] = LabelingHistogramsWidget( + sorting_analyzer, + thresholds=thresholds, + backend=backend, + **kwargs, + ) + + # Waveform overlay + results["waveforms"] = WaveformOverlayByLabelWidget(sorting_analyzer, unit_type_string, backend=backend, **kwargs) + + # UpSet plots + if include_upset and has_metrics: + results["upset"] = UpsetPlotWidget( + sorting_analyzer, + unit_type, + unit_type_string, + thresholds=thresholds, + split_non_somatic=split_non_somatic, + backend=backend, + **kwargs, + ) + + # Save to folder if requested + if save_folder is not None: + save_folder = Path(save_folder) + save_folder.mkdir(parents=True, exist_ok=True) + + # Save plots + if "histograms" in results and results["histograms"].figure is not None: + results["histograms"].figure.savefig(save_folder / "labeling_histograms.png", dpi=150, bbox_inches="tight") + if "waveforms" in results and results["waveforms"].figure is not None: + results["waveforms"].figure.savefig(save_folder / "waveform_overlay.png", dpi=150, bbox_inches="tight") + if "upset" in results and hasattr(results["upset"], "figures"): + for i, fig in enumerate(results["upset"].figures): + fig.savefig(save_folder / f"upset_plot_{i}.png", dpi=150, bbox_inches="tight") + + # Save CSV results + if has_metrics: + save_bombcell_results(combined_metrics, unit_type, unit_type_string, thresholds, save_folder) + + return results diff --git a/src/spikeinterface/widgets/unit_labels.py b/src/spikeinterface/widgets/unit_labels.py new file mode 100644 index 0000000000..03ee0b2391 --- /dev/null +++ b/src/spikeinterface/widgets/unit_labels.py @@ -0,0 +1,114 @@ +"""Widgets for visualizing unit labeling results.""" + +from __future__ import annotations + +import numpy as np + +from .base import BaseWidget, to_attr + + +class WaveformOverlayByLabelWidget(BaseWidget): + """Plot overlaid waveforms grouped by unit label type. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object with 'templates' extension computed. + unit_labels : np.ndarray + Array of unit type labels corresponding to each unit in the sorting. + labels_order : list, optional + List specifying the order of labels to display. If None, unique labels in unit_labels are + used in the order they appear. + max_columns : int, default: 3 + Maximum number of columns in the plot grid. 
+ """ + + def __init__( + self, + sorting_analyzer, + unit_labels: np.ndarray, + labels_order=None, + max_columns: int = 3, + backend=None, + **backend_kwargs, + ): + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "templates") + if labels_order is not None: + assert len(labels_order) == len(np.unique(unit_labels)), "labels_order length must match unique unit types" + assert all( + [label in np.unique(unit_labels) for label in labels_order] + ), "All labels in labels_order must be present in unit_labels" + else: + labels_order = np.unique(unit_labels) + plot_data = dict( + sorting_analyzer=sorting_analyzer, + labels_order=labels_order, + unit_labels=unit_labels, + max_columns=max_columns, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + sorting_analyzer = dp.sorting_analyzer + unit_labels = dp.unit_labels + labels_order = dp.labels_order + + if not sorting_analyzer.has_extension("templates"): + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.text( + 0.5, + 0.5, + "Templates extension not computed.\nRun: analyzer.compute('templates')", + ha="center", + va="center", + fontsize=12, + ) + ax.axis("off") + self.figure = fig + self.axes = ax + return + + templates_ext = sorting_analyzer.get_extension("templates") + templates = templates_ext.get_templates(operator="average") + + backend_kwargs["num_axes"] = len(labels_order) + if len(labels_order) <= dp.max_columns: + ncols = len(labels_order) + else: + ncols = int(np.ceil(len(labels_order) / 2)) + nrows = int(np.ceil(len(labels_order) / ncols)) + backend_kwargs["ncols"] = ncols + if "figsize" not in backend_kwargs: + backend_kwargs["figsize"] = (5 * ncols, 4 * nrows) + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + axes_flat = self.axes.flatten() + for index, label in enumerate(labels_order): + ax = axes_flat[index] + mask = unit_labels == label + n_units = np.sum(mask) + + if n_units > 0: + unit_indices = np.where(mask)[0] + alpha = max(0.05, min(0.3, 10 / n_units)) + for unit_idx in unit_indices: + template = templates[unit_idx] + best_chan = np.argmax(np.max(np.abs(template), axis=0)) + ax.plot(template[:, best_chan], color="black", alpha=alpha, linewidth=0.5) + ax.set_title(f"{label} (n={n_units})") + else: + ax.set_title(f"{label} (n=0)") + ax.text(0.5, 0.5, "No units", ha="center", va="center", transform=ax.transAxes) + + for spine in ax.spines.values(): + spine.set_visible(False) + ax.set_xticks([]) + ax.set_yticks([]) + + for idx in range(len(labels_order), len(axes_flat)): + axes_flat[idx].set_visible(False) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 6edba67c96..fe191d4450 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -37,12 +37,19 @@ from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget +from .unit_labels import WaveformOverlayByLabelWidget +from .bombcell_curation import ( + LabelingHistogramsWidget, + UpsetPlotWidget, + plot_unit_labeling_all, +) widget_list = [ AgreementMatrixWidget, 
AllAmplitudesDistributionsWidget, AmplitudesWidget, AutoCorrelogramsWidget, + LabelingHistogramsWidget, ConfusionMatrixWidget, ComparisonCollisionBySimilarityWidget, CrossCorrelogramsWidget, @@ -75,6 +82,8 @@ UnitTemplatesWidget, UnitWaveformDensityMapWidget, UnitWaveformsWidget, + UpsetPlotWidget, + WaveformOverlayByLabelWidget, StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, @@ -148,6 +157,9 @@ plot_template_similarity = TemplateSimilarityWidget plot_traces = TracesWidget plot_unit_depths = UnitDepthsWidget +plot_unit_labels = WaveformOverlayByLabelWidget +plot_unit_labeling_upset = UpsetPlotWidget +plot_unit_labeling_histograms = LabelingHistogramsWidget plot_unit_locations = UnitLocationsWidget plot_unit_presence = UnitPresenceWidget plot_unit_probe_map = UnitProbeMapWidget
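
A minimal end-to-end sketch of the API introduced above, for reference while reviewing. It assumes the patch is applied, that the ``plot_*`` aliases added to ``widget_list.py`` are re-exported from ``spikeinterface.widgets`` as usual, and that per-metric parameters are passed to ``compute`` through a ``metric_params`` dict keyed by metric name (the diff only shows the defaults stored on ``ExpDecay.metric_params``, so that keyword is an assumption). The unit labels passed to the overlay widget are hand-built placeholders; the bombcell-style labeling step that produces them in practice is not part of this section.

.. code-block:: python

    import numpy as np
    from spikeinterface import load_sorting_analyzer
    from spikeinterface.widgets import plot_unit_labeling_histograms, plot_unit_labels

    # An existing SortingAnalyzer with the "templates" extension already computed
    analyzer = load_sorting_analyzer("analyzer_folder")

    # Compute the renamed / new template metrics, overriding the bombcell-style
    # exp_decay parameters (channel_tolerance=None would keep the old all-channel fit).
    analyzer.compute(
        "template_metrics",
        metric_names=["peak_to_trough_duration", "waveform_ratios", "waveform_widths", "exp_decay"],
        metric_params={  # assumed keyword: per-metric params keyed by metric name
            "exp_decay": {"peak_function": "ptp", "min_r2": 0.2, "linear_fit": True, "channel_tolerance": 33},
        },
        delete_existing_metrics=True,
    )
    tm = analyzer.get_extension("template_metrics").get_data()
    print(tm[["peak_to_trough_duration", "main_peak_to_trough_ratio", "exp_decay"]].head())

    # Histograms of metric distributions against the default bombcell thresholds
    plot_unit_labeling_histograms(analyzer)

    # Waveform overlays grouped by label; placeholder labels used here for illustration only
    n_units = len(analyzer.unit_ids)
    unit_type_string = np.array(["good"] * n_units)
    plot_unit_labels(analyzer, unit_type_string)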