From 3dc57290dbde0aeaa5048f2301ee75015a93fe26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Dec 2025 15:43:44 +0100 Subject: [PATCH 01/13] Test IBL extractors tests failing for PI update --- src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 972a8e7bb0..56d01e38cf 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -76,8 +76,8 @@ def test_offsets(self): def test_probe_representation(self): probe = self.recording.get_probe() - expected_probe_representation = "Probe - 384ch - 1shanks" - assert repr(probe) == expected_probe_representation + expected_probe_representation = "Probe - 384ch" + assert expected_probe_representation in repr(probe) def test_property_keys(self): expected_property_keys = [ From 61c317aba92608d9f096a3a374bc3d43e27faaba Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Mar 2026 10:09:46 -0800 Subject: [PATCH 02/13] Fix OpenEphys tests --- .../extractors/neoextractors/openephys.py | 20 ++++++++++++------- .../extractors/tests/test_neoextractors.py | 3 +++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 1c39a1b97c..1d16df534b 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -351,13 +351,19 @@ def __init__( # Ensure device channel index corresponds to channel_ids probe_channel_names = probe.contact_annotations.get("channel_name", None) if probe_channel_names is not None and not np.array_equal(probe_channel_names, self.channel_ids): - device_channel_indices = [] - probe_channel_names = list(probe_channel_names) - 
device_channel_indices = np.zeros(len(self.channel_ids), dtype=int) - for i, ch in enumerate(self.channel_ids): - index_in_probe = probe_channel_names.index(ch) - device_channel_indices[index_in_probe] = i - probe.set_device_channel_indices(device_channel_indices) + if set(probe_channel_names) == set(self.channel_ids): + device_channel_indices = [] + probe_channel_names = list(probe_channel_names) + device_channel_indices = np.zeros(len(self.channel_ids), dtype=int) + for i, ch in enumerate(self.channel_ids): + index_in_probe = probe_channel_names.index(ch) + device_channel_indices[index_in_probe] = i + probe.set_device_channel_indices(device_channel_indices) + else: + warnings.warn( + "Channel names in the probe do not match the channel ids from Neo. " + "Cannot set device channel indices, but this might lead to incorrect probe geometries" + ) if probe.shank_ids is not None: self.set_probe(probe, in_place=True, group_mode="by_shank") diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py index f80f62ebf0..f40b4d05ab 100644 --- a/src/spikeinterface/extractors/tests/test_neoextractors.py +++ b/src/spikeinterface/extractors/tests/test_neoextractors.py @@ -121,6 +121,9 @@ class OpenEphysBinaryRecordingTest(RecordingCommonTestSuite, unittest.TestCase): ("openephysbinary/v0.5.x_two_nodes", {"stream_id": "0"}), ("openephysbinary/v0.5.x_two_nodes", {"stream_id": "1"}), ("openephysbinary/v0.6.x_neuropixels_multiexp_multistream", {"stream_id": "0", "block_index": 0}), + # TODO: block_indices 1/2 of v0.6.x_neuropixels_multiexp_multistream have a mismatch in the channel names between + # the settings files (starting with CH0) and structure.oebin (starting at CH1). 
+ # Currently, the extractor will skip remapping to match order in oebin and settings file, raising a warning ("openephysbinary/v0.6.x_neuropixels_multiexp_multistream", {"stream_id": "1", "block_index": 1}), ( "openephysbinary/v0.6.x_neuropixels_multiexp_multistream", From 49c51dadf9f802e772e83f7ee23a5f33be66a2ac Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 20 Apr 2026 17:34:19 +0200 Subject: [PATCH 03/13] feat: implement DetectAndRemoveArtifacts and signed saturation --- .../preprocessing/detect_artifacts.py | 325 +++++++++++++----- .../preprocessing/preprocessing_classes.py | 3 + .../tests/test_detect_artifacts.py | 114 ++++++ 3 files changed, 357 insertions(+), 85 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 2f2c740616..4a4dbf17f3 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -4,19 +4,24 @@ from spikeinterface.core import BaseRecording from spikeinterface.core.base import base_period_dtype +from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.recording_tools import get_random_data_chunks +from spikeinterface.core.node_pipeline import PeakDetector, run_node_pipeline, PipelineNode from spikeinterface.preprocessing.rectify import RectifyRecording from spikeinterface.preprocessing.common_reference import CommonReferenceRecording from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording -from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels, get_random_data_chunks -from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype, run_node_pipeline, PipelineNode +from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording artifact_dtype = 
base_period_dtype -# this will be extend with channel boundaries if needed -# extended_artifact_dtype = artifact_dtype + [ -# # TODO -# ] +signed_artifact_dtype = np.dtype(artifact_dtype + [("sign", "U8")]) + + +def _indent_docstring(docstring: str, indent: int = 4) -> str: + indent_str = " " * indent + indented_lines = [(indent_str + line) if line.strip() else line for line in docstring.splitlines()] + return "\n".join(indented_lines) def _collapse_events(events: np.ndarray) -> np.ndarray: @@ -77,6 +82,7 @@ def __init__( saturation_threshold_uV: float, diff_threshold_uV: float | None, proportion: float, + signed: bool, ) -> None: """ Parameters @@ -91,6 +97,11 @@ def __init__( proportion : float Fraction of channels that must exceed the threshold for a sample to be labelled as saturated (0 < proportion < 1). + signed : bool + If ``True``, the sign of the saturation is returned as an additional field in the output array, with values + ``"positive"`` for positive saturation and ``"negative"`` for negative saturation. If ``False``, + the output array has the standard ``artifact_dtype`` with fields ``"start_sample_index"``, ``"end_sample_index"``, + and ``"segment_index"``. 
""" PipelineNode.__init__(self, recording, return_output=True) @@ -102,6 +113,7 @@ def __init__( # slightly lower than the documented saturation point of the probe self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion + self.signed = signed self._dtype = np.dtype(artifact_dtype) self.gain = recording.get_channel_gains() self.offset = recording.get_channel_offsets() @@ -122,6 +134,27 @@ def get_dtype(self) -> np.dtype: """Return the NumPy dtype of the output array produced by :meth:`compute`.""" return self._dtype + def detect_in_chunk(self, traces, saturation_threshold, diff_threshold, proportion) -> np.ndarray: + saturation = np.mean(traces > saturation_threshold, axis=1) + detected_by_value = saturation > proportion + + if diff_threshold is not None: + # then compute the derivative of the voltage saturation + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= diff_threshold, axis=1) + + # Note this means the velocity is not checked for the last sample in the + # check because we are taking the forward derivative + n_diff_saturated = np.r_[n_diff_saturated, 0] + + # if either of those reaches more than the proportion of channels labels the sample as saturated + detected_by_diff = n_diff_saturated > proportion + saturation = np.logical_or(detected_by_value, detected_by_diff) + else: + saturation = detected_by_value + + intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) + return intervals + def compute( self, traces: np.ndarray, @@ -165,33 +198,63 @@ def compute( # cast to float32 to prevent overflow when applying thresholds in unscaled ADC units traces = traces.astype("float32") - saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) - detected_by_value = saturation > self.proportion - - if self.diff_threshold_unscaled is not None: - # then compute the derivative of the voltage saturation - n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= 
self.diff_threshold_unscaled, axis=1) - - # Note this means the velocity is not checked for the last sample in the - # check because we are taking the forward derivative - n_diff_saturated = np.r_[n_diff_saturated, 0] - - # if either of those reaches more than the proportion of channels labels the sample as saturated - detected_by_diff = n_diff_saturated > self.proportion - saturation = np.logical_or(detected_by_value, detected_by_diff) + if not self.signed: + traces = np.abs(traces) + intervals = self.detect_in_chunk( + traces, self.saturation_threshold_unscaled, self.diff_threshold_unscaled, self.proportion + ) + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=artifact_dtype) + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["end_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index else: - saturation = detected_by_value + all_events = [] + for sign in (1, -1): + traces_signed = sign * traces + intervals = self.detect_in_chunk( + traces_signed, self.saturation_threshold_unscaled, self.diff_threshold_unscaled, self.proportion + ) + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=signed_artifact_dtype) + events["sign"] = "positive" if sign == 1 else "negative" + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["end_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index + all_events.append(events) + all_events = np.concatenate(all_events) + # sort by start sample index + order = np.argsort(all_events["start_sample_index"]) + all_events = all_events[order] + return (all_events,) - intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) - n_events = len(intervals) // 2 # Number of saturation periods - events = 
np.zeros(n_events, dtype=artifact_dtype) + return (events,) - for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): - events[i]["start_sample_index"] = start + start_frame - events[i]["end_sample_index"] = stop + start_frame - events[i]["segment_index"] = segment_index - return (events,) +_detect_saturation_periods_params = """saturation_threshold_uV : float | None, default: None + Voltage saturation threshold in μV. The appropriate value depends on + the probe and amplifier gain settings; for Neuropixels 1.0 probes IBL + recommend **1200 μV**. NP2 probes are harder to saturate than NP1. + If ``None``, the value is read from the ``"saturation_threshold_uV"`` + annotation of ``recording``. +diff_threshold_uV : float | None, default: None + First-derivative threshold in μV/sample. Periods where the + sample-to-sample voltage change exceeds this value in the required + fraction of channels are flagged as saturation. Pass ``None`` to + disable derivative-based detection and rely solely on + ``saturation_threshold_uV``. IBL use **300 μV/sample** for NP1 probes. +proportion : float, default: 0.2 + Fraction of channels (0 < proportion < 1) that must exceed the + threshold for a sample to be considered saturated. +signed : bool, default: False + If ``True``, the sign of the saturation is returned as an additional field in the output array, with values + ``"positive"`` for positive saturation and ``"negative"`` for negative saturation. 
If ``False``, + the output array has the standard ``artifact_dtype`` with fields ``"start_sample_index"``, ``"end_sample_index"``, + and ``"segment_index"``.""" def detect_saturation_periods( @@ -199,6 +262,7 @@ def detect_saturation_periods( saturation_threshold_uV: float | None = None, diff_threshold_uV: float | None = None, proportion: float = 0.2, + signed: bool = False, job_kwargs: dict | None = None, ) -> np.ndarray: """ @@ -224,21 +288,7 @@ def detect_saturation_periods( ---------- recording : BaseRecording The recording on which to detect saturation events. - saturation_threshold_uV : float | None, default: None - Voltage saturation threshold in μV. The appropriate value depends on - the probe and amplifier gain settings; for Neuropixels 1.0 probes IBL - recommend **1200 μV**. NP2 probes are harder to saturate than NP1. - If ``None``, the value is read from the ``"saturation_threshold_uV"`` - annotation of ``recording``. - diff_threshold_uV : float | None, default: None - First-derivative threshold in μV/sample. Periods where the - sample-to-sample voltage change exceeds this value in the required - fraction of channels are flagged as saturation. Pass ``None`` to - disable derivative-based detection and rely solely on - ``saturation_threshold_uV``. IBL use **300 μV/sample** for NP1 probes. - proportion : float, default: 0.2 - Fraction of channels (0 < proportion < 1) that must exceed the - threshold for a sample to be considered saturated. + {} job_kwargs : dict | None, default: None Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. ``n_jobs``, ``chunk_duration``). 
@@ -270,6 +320,7 @@ def detect_saturation_periods( saturation_threshold_uV=saturation_threshold_uV, diff_threshold_uV=diff_threshold_uV, proportion=proportion, + signed=signed, ) saturation_periods = run_node_pipeline( @@ -278,6 +329,9 @@ def detect_saturation_periods( return _collapse_events(saturation_periods) +detect_saturation_periods.__doc__ = detect_saturation_periods.__doc__.format(_detect_saturation_periods_params) + + ## detect_artifact_periods_by_envelope zone class _DetectThresholdCrossing(PeakDetector): """ @@ -382,6 +436,23 @@ def compute( return (threshold_crossings,) +_detect_artifacts_by_envelope_params = """detect_threshold : float, default: 5 + Detection threshold as a multiple of the estimated per-channel noise + level of the envelope. +freq_max : float, default: 20.0 + Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the + rectified signal when building the envelope. +seed : int | None, default: None + Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. + If ``None``, ``get_noise_levels`` uses ``seed=0``. +job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). +random_slices_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the ``random_slices_kwargs`` + argument of :func:`~spikeinterface.core.get_noise_levels`.""" + + def detect_artifact_periods_by_envelope( recording: BaseRecording, detect_threshold: float = 5, @@ -410,21 +481,7 @@ def detect_artifact_periods_by_envelope( ---------- recording : BaseRecording The recording extractor from which to detect artefact periods. - detect_threshold : float, default: 5 - Detection threshold as a multiple of the estimated per-channel noise - level of the envelope. - freq_max : float, default: 20.0 - Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the - rectified signal when building the envelope. 
- seed : int | None, default: None - Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. - If ``None``, ``get_noise_levels`` uses ``seed=0``. - job_kwargs : dict | None, default: None - Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. - ``n_jobs``, ``chunk_duration``). - random_slices_kwargs : dict | None, default: None - Additional keyword arguments forwarded to the ``random_slices_kwargs`` - argument of :func:`~spikeinterface.core.get_noise_levels`. + {} return_envelope : bool, default: False If ``True``, also return the intermediate envelope recording so that it can be inspected or plotted. @@ -475,7 +532,6 @@ def detect_artifact_periods_by_envelope( artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] artifacts = _collapse_events(artifacts) if return_envelope: @@ -484,6 +540,11 @@ def detect_artifact_periods_by_envelope( return artifacts +detect_artifact_periods_by_envelope.__doc__ = detect_artifact_periods_by_envelope.__doc__.format( + _detect_artifacts_by_envelope_params +) + + def _transform_internal_dtype_to_artifact_dtype( artifacts: np.ndarray, recording: BaseRecording, @@ -561,35 +622,41 @@ def detect_artifact_periods( job_kwargs: dict | None = None, ) -> np.ndarray: """ - Detect artifact periods using one of several available methods. + Detect artifact periods using one of several available methods. - Available methods: + Available methods: - * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified - channel envelope. - * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. + * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified + channel envelope. 
+ * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. - See the documentation of each sub-function for a full description of their - parameters, which can be forwarded via ``method_kwargs``. + Parameters + ---------- + recording : BaseRecording + The recording on which to detect artifact periods. + method : "envelope" | "saturation", default: "envelope" + Detection method to use. + method_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the selected detection + function. Pass ``None`` to use that function's defaults. - Parameters - ---------- - recording : BaseRecording - The recording on which to detect artifact periods. - method : {"envelope", "saturation"}, default: "envelope" - Detection method to use. - method_kwargs : dict | None, default: None - Additional keyword arguments forwarded to the selected detection - function. Pass ``None`` to use that function's defaults. - job_kwargs : dict | None, default: None - Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. - ``n_jobs``, ``chunk_duration``). + Method-specific parameters include: - Returns - ------- - np.ndarray - Array with dtype ``artifact_dtype`` describing each detected artifact - period. + - ``"envelope"`` + + {artifacts_by_envelope_params} + - ``"saturation"``, see :func:`detect_saturation_periods` + + {saturation_periods_params} + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + + Returns + ------- + np.ndarray + Array with dtype ``artifact_dtype`` describing each detected artifact + period. 
""" assert ( method in _method_to_function @@ -600,3 +667,91 @@ def detect_artifact_periods( artifact_periods = _method_to_function[method](recording, job_kwargs=job_kwargs, **method_kwargs) return artifact_periods + + +detect_artifact_periods.__doc__ = detect_artifact_periods.__doc__.format( + artifacts_by_envelope_params=_indent_docstring(_detect_artifacts_by_envelope_params, 12), + saturation_periods_params=_indent_docstring(_detect_saturation_periods_params, 12), +) + + +class DetectAndRemoveArtifactsRecording(SilencedPeriodsRecording): + """ + Detect and remove artifact periods using one of several available methods. + + Available methods: + + * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified + channel envelope. + * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. + + + Parameters + ---------- + recording : BaseRecording + The recording on which to detect artifact periods. + method : "envelope" | "saturation", default: "envelope" + Detection method to use. + method_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the selected detection + function. Pass ``None`` to use that function's defaults. + + Method-specific parameters include: + + - ``"envelope"`` + + {artifacts_by_envelope_params} + - ``"saturation"``, see :func:`detect_saturation_periods` + + {saturation_periods_params} + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). 
+ """ + + def __init__( + self, + recording: BaseRecording, + recording_to_detect: BaseRecording | None = None, + method: Literal["envelope", "saturation"] = "envelope", + method_kwargs: dict | None = None, + job_kwargs: dict | None = None, + mode: Literal["zeros", "noise"] = "zeros", + noise_levels_kwargs: dict | None = None, + seed: int | None = None, + artifact_periods=None, + ) -> None: + if artifact_periods is not None: + artifact_periods = artifact_periods + else: + if recording_to_detect is None: + recording_to_detect = recording + artifact_periods = detect_artifact_periods( + recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs + ) + super().__init__( + recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed + ) + + self._kwargs = dict( + recording=recording, + recording_to_detect=recording_to_detect, + method=method, + method_kwargs=method_kwargs, + job_kwargs=job_kwargs, + mode=mode, + noise_levels_kwargs=noise_levels_kwargs, + seed=seed, + artifact_periods=artifact_periods, + ) + + +# function for API +detect_and_remove_artifacts = define_function_handling_dict_from_class( + source_class=DetectAndRemoveArtifactsRecording, name="detect_and_remove_artifacts" +) + +detect_and_remove_artifacts.__doc__ = detect_and_remove_artifacts.__doc__.format( + artifacts_by_envelope_params=_indent_docstring(_detect_artifacts_by_envelope_params, 12), + saturation_periods_params=_indent_docstring(_detect_saturation_periods_params, 12), +) diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index 47e3c0906b..6ada71de44 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -48,6 +48,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import 
UnsignedToSignedRecording, unsigned_to_signed +from .detect_artifacts import DetectAndRemoveArtifactsRecording, detect_and_remove_artifacts # from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts @@ -72,6 +73,8 @@ # bad channel detection/interpolation DetectAndRemoveBadChannelsRecording: detect_and_remove_bad_channels, DetectAndInterpolateBadChannelsRecording: detect_and_interpolate_bad_channels, + # artifact/saturation handling + DetectAndRemoveArtifactsRecording: detect_and_remove_artifacts, # misc RectifyRecording: rectify, ClipRecording: clip, diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index 1d99206bd2..76347f56b8 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -5,6 +5,7 @@ detect_artifact_periods, detect_saturation_periods, detect_artifact_periods_by_envelope, + detect_and_remove_artifacts, ) @@ -238,6 +239,119 @@ def test_detect_saturation_periods(debug_plots): assert np.array_equal(periods, periods_entry_with_annotation) +def test_detect_saturation_signed(): + import scipy.signal + + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 + job_kwargs = {"chunk_size": chunk_size} + + sat_value = 1200 + noise_level = 10 + rng = np.random.default_rng(0) + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(90000, num_chans)) * 10 + + sos = scipy.signal.butter(N=3, Wn=8000 / (sampling_frequency / 2), btype="low", output="sos") + data = scipy.signal.sosfiltfilt(sos, data, axis=0) + + # Inject positive saturation in first third, negative in second third + pos_start, pos_stop = 15000, 15500 + neg_start, neg_stop = 45000, 45500 + data[pos_start:pos_stop, :] = sat_value + data[neg_start:neg_stop, :] = -sat_value + + gain = 2.34 + offset = 0 + data_int16 = np.clip(np.rint((data - offset) / gain), -32768, 32767).astype(np.int16) + 
+ recording = NumpyRecording(data_int16, sampling_frequency) + recording.set_channel_gains(gain) + recording.set_channel_offsets([offset] * num_chans) + + periods = detect_saturation_periods( + recording, + saturation_threshold_uV=sat_value * 0.98, + signed=True, + job_kwargs=job_kwargs, + ) + + # Output dtype must include the "sign" field + assert "sign" in periods.dtype.names + + pos_periods = periods[periods["sign"] == "positive"] + neg_periods = periods[periods["sign"] == "negative"] + assert len(pos_periods) > 0, "No positive saturation periods detected" + assert len(neg_periods) > 0, "No negative saturation periods detected" + + # Positive period should be near the injected positive saturation + tolerance = 1 + assert np.any(np.abs(pos_periods["start_sample_index"] - pos_start) <= tolerance) + assert np.any(np.abs(pos_periods["end_sample_index"] - pos_stop) <= tolerance) + + # Negative period should be near the injected negative saturation + assert np.any(np.abs(neg_periods["start_sample_index"] - neg_start) <= tolerance) + assert np.any(np.abs(neg_periods["end_sample_index"] - neg_stop) <= tolerance) + + # Positive periods must not contain any sample indices from the negative injection + for p in pos_periods: + assert not (p["start_sample_index"] < neg_stop and p["end_sample_index"] > neg_start) + + # Negative periods must not contain any sample indices from the positive injection + for p in neg_periods: + assert not (p["start_sample_index"] < pos_stop and p["end_sample_index"] > pos_start) + + +def test_detect_and_remove_artifacts(): + import scipy.signal + + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 + job_kwargs = {"chunk_size": chunk_size} + + sat_value = 1200 + noise_level = 10 + rng = np.random.default_rng(0) + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(90000, num_chans)) * 10 + + sos = scipy.signal.butter(N=3, Wn=8000 / (sampling_frequency / 2), btype="low", output="sos") + data = scipy.signal.sosfiltfilt(sos, 
data, axis=0) + + sat_start, sat_stop = 15000, 15500 + data[sat_start:sat_stop, :] = sat_value + + gain = 2.34 + offset = 0 + data_int16 = np.clip(np.rint((data - offset) / gain), -32768, 32767).astype(np.int16) + + recording = NumpyRecording(data_int16, sampling_frequency) + recording.set_channel_gains(gain) + recording.set_channel_offsets([offset] * num_chans) + + # Basic usage: detect and zero out saturation in one step + cleaned = detect_and_remove_artifacts( + recording, + method="saturation", + method_kwargs=dict(saturation_threshold_uV=sat_value * 0.98), + job_kwargs=job_kwargs, + ) + traces = cleaned.get_traces(segment_index=0) + assert traces[sat_start + 100, 0] == 0, "Saturated samples should be zeroed" + assert traces[0, 0] != 0, "Non-saturated samples should not be zeroed" + + # recording_to_detect: detect on raw recording, silence a separate (processed) recording + # We use the same recording here just to exercise the code path + cleaned_with_detect = detect_and_remove_artifacts( + recording, + recording_to_detect=recording, + method="saturation", + method_kwargs=dict(saturation_threshold_uV=sat_value * 0.98), + job_kwargs=job_kwargs, + ) + assert np.array_equal(cleaned.get_traces(), cleaned_with_detect.get_traces()) + + if __name__ == "__main__": # test_detect_artifact_by_envelope(True) test_detect_saturation_periods(False) From 99afff64336c3b246e0da49ee11e99f4ff5ba100 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 20 Apr 2026 18:10:27 +0200 Subject: [PATCH 04/13] Apply suggestion from @alejoe91 --- src/spikeinterface/preprocessing/detect_artifacts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 4a4dbf17f3..a03082747c 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -229,8 +229,7 @@ def compute( all_events = 
np.concatenate(all_events) # sort by start sample index order = np.argsort(all_events["start_sample_index"]) - all_events = all_events[order] - return (all_events,) + events = all_events[order] return (events,) From 9c25668a211317100e65ab48ec9f61071d39738e Mon Sep 17 00:00:00 2001 From: Olivier Winter Date: Wed, 22 Apr 2026 21:28:46 +0100 Subject: [PATCH 05/13] saturation application with apodization --- .../preprocessing/detect_artifacts.py | 4 +++- .../preprocessing/silence_periods.py | 21 +++++++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index a03082747c..a68824ca72 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -717,6 +717,7 @@ def __init__( job_kwargs: dict | None = None, mode: Literal["zeros", "noise"] = "zeros", noise_levels_kwargs: dict | None = None, + apodization: int = 7, seed: int | None = None, artifact_periods=None, ) -> None: @@ -729,7 +730,7 @@ def __init__( recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs ) super().__init__( - recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed + recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed, apodization=apodization ) self._kwargs = dict( @@ -742,6 +743,7 @@ def __init__( noise_levels_kwargs=noise_levels_kwargs, seed=seed, artifact_periods=artifact_periods, + apodization=apodization, ) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 393c712919..c77e8cd43a 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -1,4 +1,5 @@ import numpy as np +import scipy.signal from spikeinterface.core.core_tools import 
define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -48,14 +49,15 @@ def __init__( self, recording, periods=None, - # this is keep for backward compatibility + # this is kept for backward compatibility list_periods=None, mode="zeros", noise_levels=None, + apodization=7, seed=None, **noise_levels_kwargs, ): - available_modes = ("zeros", "noise") + available_modes = ("zeros", "noise", "apodization") num_seg = recording.get_num_segments() # handle backward compatibility with previous version @@ -108,11 +110,11 @@ def __init__( i1 = seg_limits[seg_index + 1] periods_in_seg = periods[i0:i1] rec_segment = SilencedPeriodsRecordingSegment( - parent_segment, periods_in_seg, mode, noise_generator, seg_index + parent_segment, periods_in_seg, mode, noise_generator, seg_index, apodization=apodization, ) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels) + self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels, apodization=apodization) def _all_period_list_to_periods_vec(list_periods, num_seg): @@ -154,12 +156,13 @@ def _check_periods(periods, num_seg): class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index): + def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization=7): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.periods = periods self.mode = mode self.seg_index = seg_index self.noise_generator = noise_generator + self.apodization = apodization def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) @@ -185,7 +188,13 @@ def get_traces(self, start_frame, end_frame, channel_indices): :, channel_indices ] 
traces[onset:offset, :] = noise[onset:offset] - + elif self.mode == "apodization": + # apply a cosine taper to the saturation to create a mute function + mute = np.zeros(traces.shape[0], dtype=np.float32) + mute[onset:offset] = 1 + win = scipy.signal.windows.cosine(self.apodization) + mute = np.maximum(0, 1 - scipy.signal.convolve(mute, win, mode="same")) + traces = (traces.astype(np.float32) * mute[:, np.newaxis]).astype(traces.dtype) return traces From 16368173cb3ed65bb39055ffc8ca294fbc849e7b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Apr 2026 13:26:43 +0000 Subject: [PATCH 06/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../preprocessing/detect_artifacts.py | 7 ++++++- .../preprocessing/silence_periods.py | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index a68824ca72..bb183ccbf4 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -730,7 +730,12 @@ def __init__( recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs ) super().__init__( - recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed, apodization=apodization + recording, + periods=artifact_periods, + mode=mode, + noise_levels_kwargs=noise_levels_kwargs, + seed=seed, + apodization=apodization, ) self._kwargs = dict( diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index c77e8cd43a..db984b6572 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -110,11 +110,23 @@ def __init__( i1 = seg_limits[seg_index + 1] periods_in_seg = 
periods[i0:i1] rec_segment = SilencedPeriodsRecordingSegment( - parent_segment, periods_in_seg, mode, noise_generator, seg_index, apodization=apodization, + parent_segment, + periods_in_seg, + mode, + noise_generator, + seg_index, + apodization=apodization, ) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels, apodization=apodization) + self._kwargs = dict( + recording=recording, + periods=periods, + mode=mode, + seed=seed, + noise_levels=noise_levels, + apodization=apodization, + ) def _all_period_list_to_periods_vec(list_periods, num_seg): From 230507fb0dae638ccad123cd11c3ef15547c4b23 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 23 Apr 2026 16:20:20 +0200 Subject: [PATCH 07/13] Add docstrings for apodization and tests --- .../preprocessing/detect_artifacts.py | 29 +++++++++++- .../preprocessing/silence_periods.py | 45 +++++++++++++------ .../tests/test_silence_periods.py | 21 ++++++++- 3 files changed, 78 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index a68824ca72..18cbd2f67d 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -706,6 +706,26 @@ class DetectAndRemoveArtifactsRecording(SilencedPeriodsRecording): job_kwargs : dict | None, default: None Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. ``n_jobs``, ``chunk_duration``). + noise_levels_kwargs : dict | None, default: None + Keyword arguments for `spikeinterface.core.get_noise_levels()` function. + + If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. + mode : "zeros" | "noise" | "apodization", default: "zeros" + Determines what periods are replaced by. 
Can be one of the following: + + - "zeros": Artifacts are replaced by zeros. + + - "noise": The periods are filled with a gaussion noise that has the + same variance that the one in the recordings, on a per channel + basis + - "apodization": The periods zeroed, but are apodized with a cosine taper (using `apodization_factor`) + apodization_factor : int, default: 7 + The factor used for the cosine taper when mode is "apodization". Higher values create a wider taper. + seed : int | None, default: None + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. + artifact_periods : np.ndarray | None, default: None + Optionally, pre-computed artifact periods can be passed directly to the constructor to skip the + detection step. If ``None``, artifact periods are detected on the fly using the specified method """ def __init__( @@ -715,7 +735,7 @@ def __init__( method: Literal["envelope", "saturation"] = "envelope", method_kwargs: dict | None = None, job_kwargs: dict | None = None, - mode: Literal["zeros", "noise"] = "zeros", + mode: Literal["zeros", "noise", "apodization"] = "zeros", noise_levels_kwargs: dict | None = None, apodization: int = 7, seed: int | None = None, @@ -730,7 +750,12 @@ def __init__( recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs ) super().__init__( - recording, periods=artifact_periods, mode=mode, noise_levels_kwargs=noise_levels_kwargs, seed=seed, apodization=apodization + recording, + periods=artifact_periods, + mode=mode, + noise_levels_kwargs=noise_levels_kwargs, + seed=seed, + apodization=apodization, ) self._kwargs = dict( diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index c77e8cd43a..5fadc1969c 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -22,14 +22,11 @@ class SilencedPeriodsRecording(BasePreprocessor): ---------- recording : 
RecordingExtractor The recording extractor to silance periods - list_periods : list of lists/arrays - One list per segment of tuples (start_frame, end_frame) to silence - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. - If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. - mode : "zeros" | "noise, default: "zeros" + periods : np.array + A numpy array with dtype `base_period_dtype` and fields + "segment_index", "start_sample_index", "end_sample_index". + Each row corresponds to a period to silence. + mode : "zeros" | "noise" | "apodization", default: "zeros" Determines what periods are replaced by. Can be one of the following: - "zeros": Artifacts are replaced by zeros. @@ -37,6 +34,14 @@ class SilencedPeriodsRecording(BasePreprocessor): - "noise": The periods are filled with a gaussion noise that has the same variance that the one in the recordings, on a per channel basis + - "apodization": The periods zeroed, but are apodized with a cosine taper (using `apodization_factor`) + apodization_factor : int, default: 7 + The factor used for the cosine taper when mode is "apodization". Higher values create a wider taper. + noise_levels : array + Noise levels if already computed + seed : int | None, default: None + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. + If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. 
**noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function Returns @@ -52,8 +57,8 @@ def __init__( # this is kept for backward compatibility list_periods=None, mode="zeros", + apodization_factor=7, noise_levels=None, - apodization=7, seed=None, **noise_levels_kwargs, ): @@ -110,11 +115,23 @@ def __init__( i1 = seg_limits[seg_index + 1] periods_in_seg = periods[i0:i1] rec_segment = SilencedPeriodsRecordingSegment( - parent_segment, periods_in_seg, mode, noise_generator, seg_index, apodization=apodization, + parent_segment, + periods_in_seg, + mode, + noise_generator, + seg_index, + apodization_factor=apodization_factor, ) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels, apodization=apodization) + self._kwargs = dict( + recording=recording, + periods=periods, + mode=mode, + seed=seed, + noise_levels=noise_levels, + apodization_factor=apodization_factor, + ) def _all_period_list_to_periods_vec(list_periods, num_seg): @@ -156,13 +173,13 @@ def _check_periods(periods, num_seg): class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization=7): + def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization_factor=7): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.periods = periods self.mode = mode self.seg_index = seg_index self.noise_generator = noise_generator - self.apodization = apodization + self.apodization_factor = apodization_factor def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) @@ -192,7 +209,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # apply a cosine taper to the saturation to create a mute function mute = np.zeros(traces.shape[0], 
dtype=np.float32) mute[onset:offset] = 1 - win = scipy.signal.windows.cosine(self.apodization) + win = scipy.signal.windows.cosine(self.apodization_factor) mute = np.maximum(0, 1 - scipy.signal.convolve(mute, win, mode="same")) traces = (traces.astype(np.float32) * mute[:, np.newaxis]).astype(traces.dtype) return traces diff --git a/src/spikeinterface/preprocessing/tests/test_silence_periods.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py index 44bd205f1b..4eee8646dc 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_periods.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -39,7 +39,6 @@ def test_silence(create_cache_folder): data1 = rec.get_traces(0, 400, 600) data2 = rec.get_traces(0, 500, 700) assert np.all(data1[100:] == data2[:100]) - traces_mix = rec0.get_traces(segment_index=0, start_frame=900, end_frame=5100) traces_original = rec.get_traces(segment_index=0, start_frame=900, end_frame=5100) assert np.all(traces_original[100:-100] == traces_mix[100:-100]) @@ -48,6 +47,26 @@ def test_silence(create_cache_folder): assert not np.all(traces_mix[:200] == 0) assert not np.all(traces_mix[:-200] == 0) + # test that the apodization creates a taper + apodization_factor = 10 + rec2 = silence_periods(rec, periods=periods, mode="apodization", apodization_factor=apodization_factor) + rec2 = rec2.save(format="memory", verbose=False, overwrite=True) + traces_in0 = rec2.get_traces(segment_index=0, start_frame=0, end_frame=1000) + traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000, end_frame=6000) + # all apodized traces + assert np.all(traces_in0 == 0) + assert np.all(traces_in1 == 0) + + # at margins, traces should not be all zero, but should be apodized + apodized_traces_in0 = rec2.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_factor) + apodized_traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000 - apodization_factor, end_frame=5000) + traces_raw_in0 = 
rec.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_factor) + traces_raw_in1 = rec.get_traces(segment_index=0, start_frame=5000 - apodization_factor, end_frame=5000) + # the apodized traces should be less than the raw traces in absolute value, + # since they are multiplied by a cosine taper between 0 and 1 + assert np.all(np.abs(apodized_traces_in0) <= np.abs(traces_raw_in0)) + assert np.all(np.abs(apodized_traces_in1) <= np.abs(traces_raw_in1)) + if __name__ == "__main__": cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" From cbacc4c5ba567a760a84ea49e509a88326019ad0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 24 Apr 2026 16:11:16 +0200 Subject: [PATCH 08/13] fix: apodization_factor -> apodization_samples and move scipy import --- .../preprocessing/detect_artifacts.py | 4 ++-- .../preprocessing/silence_periods.py | 19 ++++++++++--------- .../tests/test_silence_periods.py | 12 ++++++------ 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 18cbd2f67d..72acfd22cc 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -718,8 +718,8 @@ class DetectAndRemoveArtifactsRecording(SilencedPeriodsRecording): - "noise": The periods are filled with a gaussion noise that has the same variance that the one in the recordings, on a per channel basis - - "apodization": The periods zeroed, but are apodized with a cosine taper (using `apodization_factor`) - apodization_factor : int, default: 7 + - "apodization": The periods zeroed, but are apodized with a cosine taper (using `apodization_samples`) + apodization_samples : int, default: 7 The factor used for the cosine taper when mode is "apodization". Higher values create a wider taper. 
seed : int | None, default: None Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 5fadc1969c..ab8734e096 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -1,5 +1,4 @@ import numpy as np -import scipy.signal from spikeinterface.core.core_tools import define_function_handling_dict_from_class from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment @@ -34,8 +33,8 @@ class SilencedPeriodsRecording(BasePreprocessor): - "noise": The periods are filled with a gaussion noise that has the same variance that the one in the recordings, on a per channel basis - - "apodization": The periods zeroed, but are apodized with a cosine taper (using `apodization_factor`) - apodization_factor : int, default: 7 + - "apodization": The periods zeroed, but are apodized with a cosine taper (using `apodization_samples`) + apodization_samples : int, default: 7 The factor used for the cosine taper when mode is "apodization". Higher values create a wider taper. 
noise_levels : array Noise levels if already computed @@ -57,7 +56,7 @@ def __init__( # this is kept for backward compatibility list_periods=None, mode="zeros", - apodization_factor=7, + apodization_samples=7, noise_levels=None, seed=None, **noise_levels_kwargs, @@ -120,7 +119,7 @@ def __init__( mode, noise_generator, seg_index, - apodization_factor=apodization_factor, + apodization_samples=apodization_samples, ) self.add_recording_segment(rec_segment) @@ -130,7 +129,7 @@ def __init__( mode=mode, seed=seed, noise_levels=noise_levels, - apodization_factor=apodization_factor, + apodization_samples=apodization_samples, ) @@ -173,13 +172,13 @@ def _check_periods(periods, num_seg): class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization_factor=7): + def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization_samples=7): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.periods = periods self.mode = mode self.seg_index = seg_index self.noise_generator = noise_generator - self.apodization_factor = apodization_factor + self.apodization_samples = apodization_samples def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) @@ -206,10 +205,12 @@ def get_traces(self, start_frame, end_frame, channel_indices): ] traces[onset:offset, :] = noise[onset:offset] elif self.mode == "apodization": + import scipy.signal + # apply a cosine taper to the saturation to create a mute function mute = np.zeros(traces.shape[0], dtype=np.float32) mute[onset:offset] = 1 - win = scipy.signal.windows.cosine(self.apodization_factor) + win = scipy.signal.windows.cosine(self.apodization_samples) mute = np.maximum(0, 1 - scipy.signal.convolve(mute, win, mode="same")) traces = (traces.astype(np.float32) * mute[:, 
np.newaxis]).astype(traces.dtype) return traces diff --git a/src/spikeinterface/preprocessing/tests/test_silence_periods.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py index 4eee8646dc..e1338a8887 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_periods.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -48,8 +48,8 @@ def test_silence(create_cache_folder): assert not np.all(traces_mix[:-200] == 0) # test that the apodization creates a taper - apodization_factor = 10 - rec2 = silence_periods(rec, periods=periods, mode="apodization", apodization_factor=apodization_factor) + apodization_samples = 10 + rec2 = silence_periods(rec, periods=periods, mode="apodization", apodization_samples=apodization_samples) rec2 = rec2.save(format="memory", verbose=False, overwrite=True) traces_in0 = rec2.get_traces(segment_index=0, start_frame=0, end_frame=1000) traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000, end_frame=6000) @@ -58,10 +58,10 @@ def test_silence(create_cache_folder): assert np.all(traces_in1 == 0) # at margins, traces should not be all zero, but should be apodized - apodized_traces_in0 = rec2.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_factor) - apodized_traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000 - apodization_factor, end_frame=5000) - traces_raw_in0 = rec.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_factor) - traces_raw_in1 = rec.get_traces(segment_index=0, start_frame=5000 - apodization_factor, end_frame=5000) + apodized_traces_in0 = rec2.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_samples) + apodized_traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000 - apodization_samples, end_frame=5000) + traces_raw_in0 = rec.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_samples) + traces_raw_in1 = rec.get_traces(segment_index=0, start_frame=5000 - 
apodization_samples, end_frame=5000) # the apodized traces should be less than the raw traces in absolute value, # since they are multiplied by a cosine taper between 0 and 1 assert np.all(np.abs(apodized_traces_in0) <= np.abs(traces_raw_in0)) From 9db9cad82012abaa9acbd765f78aae1e687474f9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 4 May 2026 11:27:27 +0200 Subject: [PATCH 09/13] feat: add pipeline substitution logic --- doc/modules/preprocessing.rst | 31 ++++++++++++ .../preprocessing/detect_artifacts.py | 1 - src/spikeinterface/preprocessing/pipeline.py | 25 ++++++++-- .../preprocessing/tests/test_pipeline.py | 47 +++++++++++++++++++ 4 files changed, 100 insertions(+), 4 deletions(-) diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 1769ab7e1c..e587ac45a5 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -149,6 +149,37 @@ can also be obtained from the pipeline object directly: dict_used_to_make_pipeline = preprocessing_pipeline.preprocessor_dict +Some preprocessing steps, such as :code:`detect_and_remove_artifacts`, allow you to specify an input recording +and optionally another recording to perform some computation (e.g., detect artifacts on the output of a previous +preprocessor, but remove them on the the parent preprocessor). In this case, the string "pipeline[preprocessor_name]" +can be used in the dictionary to specify that the recording argument for this step should be the output of a previous +preprocessor in the same pipeline. For example, if we want to use the output of the "bandpass_filter" step as the +recording to detect artifacts, we can specify it as follows: + +.. 
code-block:: python + + preprocessing_dict = { + 'bandpass_filter': {'freq_min': 250}, + 'common_reference': {'operator': 'median', 'reference': 'global'}, + 'detect_and_remove_artifacts': {'recording_to_detect': 'pipeline[bandpass_filter]'}, + } + +This will detect artifacts on the output of the "bandpass_filter" step, but the artifacts will be removed on the output +of the "common_reference" step (since the parent recording for "detect_and_remove_artifacts" is by default the output of +the previous step in the pipeline, which is "common_reference" in this case). +To specify the "raw" recording, i.e., the input to the pipeline, we can use "pipeline[raw]". +For example, if we want to detect artifacts on the raw recording, we can specify it as follows: + + +.. code-block:: python + + preprocessing_dict = { + 'bandpass_filter': {'freq_min': 250}, + 'common_reference': {'operator': 'median', 'reference': 'global'}, + 'detect_and_remove_artifacts': {'recording_to_detect': 'pipeline[raw]'}, + } + + Impact on recording dtype ------------------------- diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 72acfd22cc..3c34cba68b 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -582,7 +582,6 @@ def _transform_internal_dtype_to_artifact_dtype( for seg_index in range(num_seg): mask = artifacts["segment_index"] == seg_index sub_thr = artifacts[mask] - print(sub_thr) if len(sub_thr) > 0: if not sub_thr["front"][0]: local_thr = np.zeros(1, dtype=np.dtype(base_period_dtype + [("front", "bool")])) diff --git a/src/spikeinterface/preprocessing/pipeline.py b/src/spikeinterface/preprocessing/pipeline.py index 0f94a2584d..65441bb800 100644 --- a/src/spikeinterface/preprocessing/pipeline.py +++ b/src/spikeinterface/preprocessing/pipeline.py @@ -99,11 +99,29 @@ def _apply(self, recording, apply_precomputed_kwargs=False): Preprocessed 
recording """ - - for preprocessor_name, kwargs in self.preprocessor_dict.items(): - + instantiated_recordings = {"raw": recording} + for preprocessor_name, kwargs_ in self.preprocessor_dict.items(): + kwargs = kwargs_.copy() dont_apply_kwargs = ["recording", "parent_recording"] + for k, v in kwargs.items(): + if isinstance(v, str) and "pipeline[" in v: + if "recording" not in k: + raise ValueError( + f"Cannot substitute recording for argument '{k}' of preprocessor '{preprocessor_name}' " + f"because this argument is not meant to be a recording object." + ) + if k in dont_apply_kwargs: + raise ValueError( + f"Cannot substitute recording for argument '{k}' of preprocessor '{preprocessor_name}' " + f"because this argument is reserved for the recording to be preprocessed." + ) + rec_name = v.split("pipeline[")[-1].split("]")[0] + substituted_recording = instantiated_recordings.get(rec_name) + if substituted_recording is None: + raise ValueError(f"Cannot find recording '{rec_name}' from previous steps in the pipeline.") + kwargs[k] = substituted_recording + if not apply_precomputed_kwargs: preprocessor_class = pp_names_to_classes[preprocessor_name] precomputable_kwarg_names = preprocessor_class._precomputable_kwarg_names @@ -112,6 +130,7 @@ def _apply(self, recording, apply_precomputed_kwargs=False): non_rec_kwargs = {key: value for key, value in kwargs.items() if key not in dont_apply_kwargs} pp_output = pp_names_to_functions[preprocessor_name](recording, **non_rec_kwargs) recording = pp_output + instantiated_recordings[preprocessor_name] = recording return recording diff --git a/src/spikeinterface/preprocessing/tests/test_pipeline.py b/src/spikeinterface/preprocessing/tests/test_pipeline.py index 27432bb156..be7ccd9048 100644 --- a/src/spikeinterface/preprocessing/tests/test_pipeline.py +++ b/src/spikeinterface/preprocessing/tests/test_pipeline.py @@ -1,3 +1,5 @@ +import pytest + from spikeinterface.generation import generate_recording, 
generate_ground_truth_recording from spikeinterface.preprocessing import ( apply_preprocessing_pipeline, @@ -212,6 +214,51 @@ def test_loading_from_analyzer(create_cache_folder): check_recordings_equal(pp_recording, pp_recording_from_zarr) +def test_pipeline_recording_arg_substitution(): + """ + Tests that if a preprocessing step in the pipeline has an argument that is a string of the form "pipeline[preprocessor_name]", + then this string is replaced by the recording output by the preprocessor with name "preprocessor_name". This allows users to + use outputs of previous preprocessors as arguments for later preprocessors in the same pipeline. + """ + from spikeinterface.preprocessing.filter import BandpassFilterRecording + from spikeinterface.preprocessing.common_reference import CommonReferenceRecording + from spikeinterface.preprocessing.detect_artifacts import DetectAndRemoveArtifactsRecording + + rec = generate_recording(durations=[1]) + + # "recording" argument is protected, as it is the default argument for the recording to preprocess + pipeline_dict_wrong = { + "common_reference": {}, + "bandpass_filter": {"recording": "pipeline[raw]"}, + } + with pytest.raises(ValueError): + pp_rec_from_pipeline = apply_preprocessing_pipeline(rec, pipeline_dict_wrong) + + # The argument using the pipeline substitution must be a string with "recording" as substring + pipeline_dict_wrong2 = { + "common_reference": {}, + "bandpass_filter": {"freq_min": "pipeline[raw]"}, + } + with pytest.raises(ValueError): + pp_rec_from_pipeline = apply_preprocessing_pipeline(rec, pipeline_dict_wrong2) + + # Correct usage: the "recording_to_detect" argument for the "detect_and_remove_artifacts" step is set to be the + # output of the "bandpass_filter" step, which is correctly substituted when applying the pipeline. 
+ # The "recording" argument for the "detect_and_remove_artifacts" step should be set to the output of the + # "common_reference" step, as this is the last preprocessor in the pipeline before it. + pipeline_dict_correct = { + "bandpass_filter": {}, + "common_reference": {}, + "detect_and_remove_artifacts": {"recording_to_detect": "pipeline[bandpass_filter]"}, + } + pp_rec_from_pipeline = apply_preprocessing_pipeline(rec, pipeline_dict_correct) + # Check that the recording argument for detect step is common ref, + # and that the recording_to_detect argument for detect_and_remove_artifacts is also the output of bandpass_filter + assert isinstance(pp_rec_from_pipeline._kwargs["recording_to_detect"], BandpassFilterRecording) + assert isinstance(pp_rec_from_pipeline._kwargs["recording"], CommonReferenceRecording) + assert isinstance(pp_rec_from_pipeline, DetectAndRemoveArtifactsRecording) + + if __name__ == "__main__": import tempfile from pathlib import Path From 9b9e87caf6c3a4f0cd70e1a240cd4401f4b52811 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 5 May 2026 11:35:02 +0200 Subject: [PATCH 10/13] Add margin and tests for silence with aopdization --- .../preprocessing/silence_periods.py | 51 +++++++++++++------ .../tests/test_silence_periods.py | 25 ++++++++- 2 files changed, 59 insertions(+), 17 deletions(-) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index ab8734e096..b8bc182216 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -1,12 +1,13 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_handling_dict_from_class -from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core import get_noise_levels +from spikeinterface.core.base import base_period_dtype +from spikeinterface.core.core_tools import define_function_handling_dict_from_class 
+from spikeinterface.core.recording_tools import get_noise_levels, get_chunk_with_margin from spikeinterface.core.generate import NoiseGeneratorRecording from spikeinterface.core.job_tools import split_job_kwargs -from spikeinterface.core.base import base_period_dtype + +from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment class SilencedPeriodsRecording(BasePreprocessor): @@ -181,10 +182,19 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg self.apodization_samples = apodization_samples def get_traces(self, start_frame, end_frame, channel_indices): - traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) + if self.mode in ("zeros", "noise"): + margin = 0 + elif self.mode == "apodization": + margin = self.apodization_samples + else: + raise ValueError(f"Unknown method {self.mode}") + + traces, left_margin, right_margin = get_chunk_with_margin( + self.parent_recording_segment, start_frame, end_frame, channel_indices, margin=margin + ) if self.periods.size > 0: - new_interval = np.array([start_frame, end_frame]) + new_interval = np.array([start_frame - margin, end_frame + margin]) lower_index = np.searchsorted(self.periods["end_sample_index"], new_interval[0]) upper_index = np.searchsorted(self.periods["start_sample_index"], new_interval[1]) @@ -193,9 +203,14 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces.copy() periods_in_interval = self.periods[lower_index:upper_index] + + # For apodization, we pre-allocate the mute function and cosine window + if self.mode == "apodization": + mute_mask = np.zeros(traces.shape[0], dtype=np.float32) + for period in periods_in_interval: - onset = max(0, period["start_sample_index"] - start_frame) - offset = min(period["end_sample_index"] - start_frame, end_frame) + onset = max(0, period["start_sample_index"] - start_frame - margin) + offset = min(period["end_sample_index"] - start_frame + margin, end_frame + 
margin) if self.mode == "zeros": traces[onset:offset, :] = 0 @@ -205,15 +220,19 @@ def get_traces(self, start_frame, end_frame, channel_indices): ] traces[onset:offset, :] = noise[onset:offset] elif self.mode == "apodization": - import scipy.signal - # apply a cosine taper to the saturation to create a mute function - mute = np.zeros(traces.shape[0], dtype=np.float32) - mute[onset:offset] = 1 - win = scipy.signal.windows.cosine(self.apodization_samples) - mute = np.maximum(0, 1 - scipy.signal.convolve(mute, win, mode="same")) - traces = (traces.astype(np.float32) * mute[:, np.newaxis]).astype(traces.dtype) - return traces + mute_mask[onset:offset] = 1 + + # For apodization, we apply the mute function including all periods to the whole trace, + # so that the edges of the silenced periods are smoothly tapered + if self.mode == "apodization": + import scipy.signal + + win = scipy.signal.windows.cosine(self.apodization_samples) + mute = np.maximum(0, 1 - scipy.signal.convolve(mute_mask, win, mode="same")) + traces = (traces.astype(np.float32) * mute[:, np.newaxis]).astype(traces.dtype) + # discard margin + return traces[left_margin : traces.shape[0] - right_margin, :] # function for API diff --git a/src/spikeinterface/preprocessing/tests/test_silence_periods.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py index e1338a8887..d3531f1a0f 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_periods.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -47,10 +47,17 @@ def test_silence(create_cache_folder): assert not np.all(traces_mix[:200] == 0) assert not np.all(traces_mix[:-200] == 0) + +def test_silence_with_apodization(create_cache_folder): + + cache_folder = create_cache_folder + + rec = generate_recording() + + periods = np.array([(0, 0, 1000), (0, 5000, 6000)], dtype=base_period_dtype) # test that the apodization creates a taper apodization_samples = 10 rec2 = silence_periods(rec, periods=periods, 
mode="apodization", apodization_samples=apodization_samples) - rec2 = rec2.save(format="memory", verbose=False, overwrite=True) traces_in0 = rec2.get_traces(segment_index=0, start_frame=0, end_frame=1000) traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000, end_frame=6000) # all apodized traces @@ -67,6 +74,22 @@ def test_silence(create_cache_folder): assert np.all(np.abs(apodized_traces_in0) <= np.abs(traces_raw_in0)) assert np.all(np.abs(apodized_traces_in1) <= np.abs(traces_raw_in1)) + # check that margins are handled correctly with apodization + extra_samples = 50 + traces_at_offset = rec2.get_traces(segment_index=0, start_frame=998, end_frame=1002) + traces_at_offset_extended = rec2.get_traces( + segment_index=0, start_frame=998 - extra_samples, end_frame=1002 + extra_samples + ) + # the traces at offset should be apodized, and the extended traces should have the same apodization in the overlapping region + assert np.array_equal(traces_at_offset, traces_at_offset_extended[extra_samples:-extra_samples]) + + traces_at_onset = rec2.get_traces(segment_index=0, start_frame=4997, end_frame=5003) + traces_at_onset_extended = rec2.get_traces( + segment_index=0, start_frame=4997 - extra_samples, end_frame=5003 + extra_samples + ) + # the traces at onset should be apodized, and the extended traces should have the same apodization in the overlapping region + assert np.array_equal(traces_at_onset, traces_at_onset_extended[extra_samples:-extra_samples]) + if __name__ == "__main__": cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" From 361cca67892c984eadf32c089bbe441b6708de6e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 5 May 2026 14:57:24 +0200 Subject: [PATCH 11/13] add test on dump/load and make silence_periods recording not JSON serializable --- .../preprocessing/detect_artifacts.py | 4 ++++ src/spikeinterface/preprocessing/pipeline.py | 6 ++++++ src/spikeinterface/preprocessing/silence_periods.py | 13 +++++++++++-- 
.../preprocessing/tests/test_pipeline.py | 12 +++++++++++- 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 3c34cba68b..49af692996 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -688,6 +688,10 @@ class DetectAndRemoveArtifactsRecording(SilencedPeriodsRecording): ---------- recording : BaseRecording The recording on which to detect artifact periods. + recording_to_detect : BaseRecording | None, default: None + The recording on which to perform artifact detection. + If ``None``, the same recording passed in ``recording`` is used. This allows users to perform detection on + another recording (e.g. the raw recording) while applying the silencing to the `recording`. method : "envelope" | "saturation", default: "envelope" Detection method to use. method_kwargs : dict | None, default: None diff --git a/src/spikeinterface/preprocessing/pipeline.py b/src/spikeinterface/preprocessing/pipeline.py index 65441bb800..4d90007964 100644 --- a/src/spikeinterface/preprocessing/pipeline.py +++ b/src/spikeinterface/preprocessing/pipeline.py @@ -324,6 +324,12 @@ def _load_pp_from_dict(prov_dict, kwargs_dict): for name, value in prov_dict["kwargs"].items(): if is_dict_extractor(value): this_level_kwargs[name] = _load_pp_from_dict(value, kwargs_dict) + elif isinstance(value, BaseRecording): + extractor_as_dict = value.to_dict() + if name in ["recording", "parent_recording"]: + this_level_kwargs[name] = _load_pp_from_dict(extractor_as_dict, kwargs_dict) + else: # this branch takes care of other arguments being a recording, e.g., `recording_to_detect` + this_level_kwargs[name] = value elif isinstance(value, dict): this_level_kwargs[name] = {k: prov_dict_to_kwargs_dict(v) for k, v in value.items()} elif isinstance(value, list): diff --git 
a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index b8bc182216..595ef46123 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -67,10 +67,16 @@ def __init__( # handle backward compatibility with previous version if list_periods is not None: - assert periods is None + assert periods is None, ( + "You cannot specify both list_periods and periods. " + f"Please specify only periods, which should be a np.array with dtype {base_period_dtype}" + ) periods = _all_period_list_to_periods_vec(list_periods, num_seg) else: - assert list_periods is None + assert list_periods is None, ( + "list_periods is deprecated. Please specify periods, which should be a np.array with " + f"dtype {base_period_dtype}" + ) if not isinstance(periods, np.ndarray): raise ValueError(f"periods must be a np.array with dtype {base_period_dtype}") @@ -124,6 +130,9 @@ def __init__( ) self.add_recording_segment(rec_segment) + # the base_period_dtype is a structured dtype, which is not json serializable + self._serializability["json"] = False + self._kwargs = dict( recording=recording, periods=periods, diff --git a/src/spikeinterface/preprocessing/tests/test_pipeline.py b/src/spikeinterface/preprocessing/tests/test_pipeline.py index be7ccd9048..124efdc0c4 100644 --- a/src/spikeinterface/preprocessing/tests/test_pipeline.py +++ b/src/spikeinterface/preprocessing/tests/test_pipeline.py @@ -214,7 +214,7 @@ def test_loading_from_analyzer(create_cache_folder): check_recordings_equal(pp_recording, pp_recording_from_zarr) -def test_pipeline_recording_arg_substitution(): +def test_pipeline_recording_arg_substitution(create_cache_folder): """ Tests that if a preprocessing step in the pipeline has an argument that is a string of the form "pipeline[preprocessor_name]", then this string is replaced by the recording output by the preprocessor with name "preprocessor_name". 
This allows users to @@ -258,6 +258,16 @@ def test_pipeline_recording_arg_substitution(): assert isinstance(pp_rec_from_pipeline._kwargs["recording"], CommonReferenceRecording) assert isinstance(pp_rec_from_pipeline, DetectAndRemoveArtifactsRecording) + # Test dumping the pipeline to pickle and loading it back with the correct substitution still works + pp_rec_from_pipeline.dump_to_pickle(create_cache_folder / "pipeline_substitution_test.pkl") + pp_rec_from_pkl = get_preprocessing_dict_from_file(create_cache_folder / "pipeline_substitution_test.pkl") + pp_rec_from_pipeline_substitution = apply_preprocessing_pipeline( + rec, pp_rec_from_pkl, apply_precomputed_kwargs=True + ) + assert isinstance(pp_rec_from_pipeline_substitution._kwargs["recording_to_detect"], BandpassFilterRecording) + assert isinstance(pp_rec_from_pipeline_substitution._kwargs["recording"], CommonReferenceRecording) + assert isinstance(pp_rec_from_pipeline_substitution, DetectAndRemoveArtifactsRecording) + if __name__ == "__main__": import tempfile From 676d09f7957c410905473e1480d6ba82d0eaf443 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 5 May 2026 15:17:46 +0200 Subject: [PATCH 12/13] allow channel_filters to be passed as a list (JSON-serializable) --- src/spikeinterface/preprocessing/detect_bad_channels.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 3b403ca73f..503c802de7 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -308,6 +308,9 @@ def detect_bad_channels( if channel_filters is None: channel_filters = allowed_filters + if isinstance(channel_filters, list): + channel_filters = set(channel_filters) + if not isinstance(channel_filters, set): raise ValueError(f"channel_filters must be None or a set of the following values : {allowed_filters} ") From 
d0c4e040b33cb041aa461f22bed5e546e769f93e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 6 May 2026 12:20:34 +0200 Subject: [PATCH 13/13] add test and automatically load saturation levels from probeinterface --- .../extractors/neoextractors/openephys.py | 10 ++++++++++ .../extractors/neoextractors/spikeglx.py | 9 +++++++++ .../preprocessing/tests/test_pipeline.py | 5 +++-- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index f421d285f5..d092be75b1 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -338,6 +338,16 @@ def __init__( if sample_shifts is not None: self.set_property("inter_sample_shift", sample_shifts) + # add saturation levels if available + saturation_threshold_uV = None + # LFP should be in the oe_stream_name for LFP streams, anything else is considered AP + if "LFP" in oe_stream_name: + saturation_threshold_uV = probe.annotations.get("lf_saturation_uV", None) + else: + saturation_threshold_uV = probe.annotations.get("ap_saturation_uV", None) + if saturation_threshold_uV is not None: + self.annotate(saturation_threshold_uV=saturation_threshold_uV) + # folder_path can point to different levels of the OE folder structure # (root, record node, experiment, or recording). We need to find the root folder # in order to load the sync timestamps and set them as times to the recording. 
diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index d8c58c773e..b99fd32f4a 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -91,6 +91,15 @@ def __init__( sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) if sample_shifts is not None: self.set_property("inter_sample_shift", sample_shifts) + + # add saturation levels if available + saturation_threshold_uV = None + if "ap" in self.stream_id: + saturation_threshold_uV = probe.annotations.get("ap_saturation_uV", None) + elif "lf" in self.stream_id: + saturation_threshold_uV = probe.annotations.get("lf_saturation_uV", None) + if saturation_threshold_uV is not None: + self.annotate(saturation_threshold_uV=saturation_threshold_uV) else: warning_message = ( "Unable to find a corresponding metadata file for the recording. " diff --git a/src/spikeinterface/preprocessing/tests/test_pipeline.py b/src/spikeinterface/preprocessing/tests/test_pipeline.py index 124efdc0c4..6a96d4a66c 100644 --- a/src/spikeinterface/preprocessing/tests/test_pipeline.py +++ b/src/spikeinterface/preprocessing/tests/test_pipeline.py @@ -1,5 +1,7 @@ import pytest +from spikeinterface.core.testing import check_recordings_equal +from spikeinterface.core import create_sorting_analyzer from spikeinterface.generation import generate_recording, generate_ground_truth_recording from spikeinterface.preprocessing import ( apply_preprocessing_pipeline, @@ -16,8 +18,6 @@ get_preprocessing_dict_from_file, get_preprocessing_dict_from_analyzer, ) -from spikeinterface.core.testing import check_recordings_equal -from spikeinterface.core import create_sorting_analyzer def test_pipeline_equiv_to_step(): @@ -267,6 +267,7 @@ def test_pipeline_recording_arg_substitution(create_cache_folder): assert isinstance(pp_rec_from_pipeline_substitution._kwargs["recording_to_detect"], BandpassFilterRecording) 
assert isinstance(pp_rec_from_pipeline_substitution._kwargs["recording"], CommonReferenceRecording) assert isinstance(pp_rec_from_pipeline_substitution, DetectAndRemoveArtifactsRecording) + check_recordings_equal(pp_rec_from_pipeline, pp_rec_from_pipeline_substitution) if __name__ == "__main__":