diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst index 1769ab7e1c..e587ac45a5 100644 --- a/doc/modules/preprocessing.rst +++ b/doc/modules/preprocessing.rst @@ -149,6 +149,37 @@ can also be obtained from the pipeline object directly: dict_used_to_make_pipeline = preprocessing_pipeline.preprocessor_dict +Some preprocessing steps, such as :code:`detect_and_remove_artifacts`, allow you to specify an input recording +and optionally another recording to perform some computation (e.g., detect artifacts on the output of a previous +preprocessor, but remove them on the parent preprocessor). In this case, the string "pipeline[preprocessor_name]" +can be used in the dictionary to specify that the recording argument for this step should be the output of a previous +preprocessor in the same pipeline. For example, if we want to use the output of the "bandpass_filter" step as the +recording to detect artifacts, we can specify it as follows: + +.. code-block:: python + + preprocessing_dict = { + 'bandpass_filter': {'freq_min': 250}, + 'common_reference': {'operator': 'median', 'reference': 'global'}, + 'detect_and_remove_artifacts': {'recording_to_detect': 'pipeline[bandpass_filter]'}, + } + +This will detect artifacts on the output of the "bandpass_filter" step, but the artifacts will be removed on the output +of the "common_reference" step (since the parent recording for "detect_and_remove_artifacts" is by default the output of +the previous step in the pipeline, which is "common_reference" in this case). +To specify the "raw" recording, i.e., the input to the pipeline, we can use "pipeline[raw]". +For example, if we want to detect artifacts on the raw recording, we can specify it as follows: + + +.. 
code-block:: python + + preprocessing_dict = { + 'bandpass_filter': {'freq_min': 250}, + 'common_reference': {'operator': 'median', 'reference': 'global'}, + 'detect_and_remove_artifacts': {'recording_to_detect': 'pipeline[raw]'}, + } + + Impact on recording dtype ------------------------- diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index f421d285f5..d092be75b1 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -338,6 +338,16 @@ def __init__( if sample_shifts is not None: self.set_property("inter_sample_shift", sample_shifts) + # add saturation levels if available + saturation_threshold_uV = None + # LFP should be in the oe_stream_name for LFP streams, anything else is considered AP + if "LFP" in oe_stream_name: + saturation_threshold_uV = probe.annotations.get("lf_saturation_uV", None) + else: + saturation_threshold_uV = probe.annotations.get("ap_saturation_uV", None) + if saturation_threshold_uV is not None: + self.annotate(saturation_threshold_uV=saturation_threshold_uV) + # folder_path can point to different levels of the OE folder structure # (root, record node, experiment, or recording). We need to find the root folder # in order to load the sync timestamps and set them as times to the recording. 
diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index d8c58c773e..b99fd32f4a 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -91,6 +91,15 @@ def __init__( sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) if sample_shifts is not None: self.set_property("inter_sample_shift", sample_shifts) + + # add saturation levels if available + saturation_threshold_uV = None + if "ap" in self.stream_id: + saturation_threshold_uV = probe.annotations.get("ap_saturation_uV", None) + elif "lf" in self.stream_id: + saturation_threshold_uV = probe.annotations.get("lf_saturation_uV", None) + if saturation_threshold_uV is not None: + self.annotate(saturation_threshold_uV=saturation_threshold_uV) else: warning_message = ( "Unable to find a corresponding metadata file for the recording. " diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py index 2f2c740616..49af692996 100644 --- a/src/spikeinterface/preprocessing/detect_artifacts.py +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -4,19 +4,24 @@ from spikeinterface.core import BaseRecording from spikeinterface.core.base import base_period_dtype +from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.recording_tools import get_random_data_chunks +from spikeinterface.core.node_pipeline import PeakDetector, run_node_pipeline, PipelineNode from spikeinterface.preprocessing.rectify import RectifyRecording from spikeinterface.preprocessing.common_reference import CommonReferenceRecording from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording -from spikeinterface.core.job_tools import fix_job_kwargs -from spikeinterface.core.recording_tools import 
get_noise_levels, get_random_data_chunks -from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype, run_node_pipeline, PipelineNode +from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording artifact_dtype = base_period_dtype -# this will be extend with channel boundaries if needed -# extended_artifact_dtype = artifact_dtype + [ -# # TODO -# ] +signed_artifact_dtype = np.dtype(artifact_dtype + [("sign", "U8")]) + + +def _indent_docstring(docstring: str, indent: int = 4) -> str: + indent_str = " " * indent + indented_lines = [(indent_str + line) if line.strip() else line for line in docstring.splitlines()] + return "\n".join(indented_lines) def _collapse_events(events: np.ndarray) -> np.ndarray: @@ -77,6 +82,7 @@ def __init__( saturation_threshold_uV: float, diff_threshold_uV: float | None, proportion: float, + signed: bool, ) -> None: """ Parameters @@ -91,6 +97,11 @@ def __init__( proportion : float Fraction of channels that must exceed the threshold for a sample to be labelled as saturated (0 < proportion < 1). + signed : bool + If ``True``, the sign of the saturation is returned as an additional field in the output array, with values + ``"positive"`` for positive saturation and ``"negative"`` for negative saturation. If ``False``, + the output array has the standard ``artifact_dtype`` with fields ``"start_sample_index"``, ``"end_sample_index"``, + and ``"segment_index"``. 
""" PipelineNode.__init__(self, recording, return_output=True) @@ -102,6 +113,7 @@ def __init__( # slightly lower than the documented saturation point of the probe self.sampling_frequency = recording.get_sampling_frequency() self.proportion = proportion + self.signed = signed self._dtype = np.dtype(artifact_dtype) self.gain = recording.get_channel_gains() self.offset = recording.get_channel_offsets() @@ -122,6 +134,27 @@ def get_dtype(self) -> np.dtype: """Return the NumPy dtype of the output array produced by :meth:`compute`.""" return self._dtype + def detect_in_chunk(self, traces, saturation_threshold, diff_threshold, proportion) -> np.ndarray: + saturation = np.mean(traces > saturation_threshold, axis=1) + detected_by_value = saturation > proportion + + if diff_threshold is not None: + # then compute the derivative of the voltage saturation + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= diff_threshold, axis=1) + + # Note this means the velocity is not checked for the last sample in the + # check because we are taking the forward derivative + n_diff_saturated = np.r_[n_diff_saturated, 0] + + # if either of those reaches more than the proportion of channels labels the sample as saturated + detected_by_diff = n_diff_saturated > proportion + saturation = np.logical_or(detected_by_value, detected_by_diff) + else: + saturation = detected_by_value + + intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) + return intervals + def compute( self, traces: np.ndarray, @@ -165,33 +198,62 @@ def compute( # cast to float32 to prevent overflow when applying thresholds in unscaled ADC units traces = traces.astype("float32") - saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) - detected_by_value = saturation > self.proportion - - if self.diff_threshold_unscaled is not None: - # then compute the derivative of the voltage saturation - n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= 
self.diff_threshold_unscaled, axis=1) - - # Note this means the velocity is not checked for the last sample in the - # check because we are taking the forward derivative - n_diff_saturated = np.r_[n_diff_saturated, 0] - - # if either of those reaches more than the proportion of channels labels the sample as saturated - detected_by_diff = n_diff_saturated > self.proportion - saturation = np.logical_or(detected_by_value, detected_by_diff) + if not self.signed: + traces = np.abs(traces) + intervals = self.detect_in_chunk( + traces, self.saturation_threshold_unscaled, self.diff_threshold_unscaled, self.proportion + ) + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=artifact_dtype) + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["end_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index else: - saturation = detected_by_value + all_events = [] + for sign in (1, -1): + traces_signed = sign * traces + intervals = self.detect_in_chunk( + traces_signed, self.saturation_threshold_unscaled, self.diff_threshold_unscaled, self.proportion + ) + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=signed_artifact_dtype) + events["sign"] = "positive" if sign == 1 else "negative" + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["end_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index + all_events.append(events) + all_events = np.concatenate(all_events) + # sort by start sample index + order = np.argsort(all_events["start_sample_index"]) + events = all_events[order] - intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) - n_events = len(intervals) // 2 # Number of saturation periods - events = np.zeros(n_events, 
dtype=artifact_dtype) + return (events,) - for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): - events[i]["start_sample_index"] = start + start_frame - events[i]["end_sample_index"] = stop + start_frame - events[i]["segment_index"] = segment_index - return (events,) +_detect_saturation_periods_params = """saturation_threshold_uV : float | None, default: None + Voltage saturation threshold in μV. The appropriate value depends on + the probe and amplifier gain settings; for Neuropixels 1.0 probes IBL + recommend **1200 μV**. NP2 probes are harder to saturate than NP1. + If ``None``, the value is read from the ``"saturation_threshold_uV"`` + annotation of ``recording``. +diff_threshold_uV : float | None, default: None + First-derivative threshold in μV/sample. Periods where the + sample-to-sample voltage change exceeds this value in the required + fraction of channels are flagged as saturation. Pass ``None`` to + disable derivative-based detection and rely solely on + ``saturation_threshold_uV``. IBL use **300 μV/sample** for NP1 probes. +proportion : float, default: 0.2 + Fraction of channels (0 < proportion < 1) that must exceed the + threshold for a sample to be considered saturated. +signed : bool, default: False + If ``True``, the sign of the saturation is returned as an additional field in the output array, with values + ``"positive"`` for positive saturation and ``"negative"`` for negative saturation. 
If ``False``, + the output array has the standard ``artifact_dtype`` with fields ``"start_sample_index"``, ``"end_sample_index"``, + and ``"segment_index"``.""" def detect_saturation_periods( @@ -199,6 +261,7 @@ def detect_saturation_periods( saturation_threshold_uV: float | None = None, diff_threshold_uV: float | None = None, proportion: float = 0.2, + signed: bool = False, job_kwargs: dict | None = None, ) -> np.ndarray: """ @@ -224,21 +287,7 @@ def detect_saturation_periods( ---------- recording : BaseRecording The recording on which to detect saturation events. - saturation_threshold_uV : float | None, default: None - Voltage saturation threshold in μV. The appropriate value depends on - the probe and amplifier gain settings; for Neuropixels 1.0 probes IBL - recommend **1200 μV**. NP2 probes are harder to saturate than NP1. - If ``None``, the value is read from the ``"saturation_threshold_uV"`` - annotation of ``recording``. - diff_threshold_uV : float | None, default: None - First-derivative threshold in μV/sample. Periods where the - sample-to-sample voltage change exceeds this value in the required - fraction of channels are flagged as saturation. Pass ``None`` to - disable derivative-based detection and rely solely on - ``saturation_threshold_uV``. IBL use **300 μV/sample** for NP1 probes. - proportion : float, default: 0.2 - Fraction of channels (0 < proportion < 1) that must exceed the - threshold for a sample to be considered saturated. + {} job_kwargs : dict | None, default: None Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. ``n_jobs``, ``chunk_duration``). 
@@ -270,6 +319,7 @@ def detect_saturation_periods( saturation_threshold_uV=saturation_threshold_uV, diff_threshold_uV=diff_threshold_uV, proportion=proportion, + signed=signed, ) saturation_periods = run_node_pipeline( @@ -278,6 +328,9 @@ def detect_saturation_periods( return _collapse_events(saturation_periods) +detect_saturation_periods.__doc__ = detect_saturation_periods.__doc__.format(_detect_saturation_periods_params) + + ## detect_artifact_periods_by_envelope zone class _DetectThresholdCrossing(PeakDetector): """ @@ -382,6 +435,23 @@ def compute( return (threshold_crossings,) +_detect_artifacts_by_envelope_params = """detect_threshold : float, default: 5 + Detection threshold as a multiple of the estimated per-channel noise + level of the envelope. +freq_max : float, default: 20.0 + Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the + rectified signal when building the envelope. +seed : int | None, default: None + Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. + If ``None``, ``get_noise_levels`` uses ``seed=0``. +job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). +random_slices_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the ``random_slices_kwargs`` + argument of :func:`~spikeinterface.core.get_noise_levels`.""" + + def detect_artifact_periods_by_envelope( recording: BaseRecording, detect_threshold: float = 5, @@ -410,21 +480,7 @@ def detect_artifact_periods_by_envelope( ---------- recording : BaseRecording The recording extractor from which to detect artefact periods. - detect_threshold : float, default: 5 - Detection threshold as a multiple of the estimated per-channel noise - level of the envelope. - freq_max : float, default: 20.0 - Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the - rectified signal when building the envelope. 
- seed : int | None, default: None - Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. - If ``None``, ``get_noise_levels`` uses ``seed=0``. - job_kwargs : dict | None, default: None - Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. - ``n_jobs``, ``chunk_duration``). - random_slices_kwargs : dict | None, default: None - Additional keyword arguments forwarded to the ``random_slices_kwargs`` - argument of :func:`~spikeinterface.core.get_noise_levels`. + {} return_envelope : bool, default: False If ``True``, also return the intermediate envelope recording so that it can be inspected or plotted. @@ -475,7 +531,6 @@ def detect_artifact_periods_by_envelope( artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] artifacts = _collapse_events(artifacts) if return_envelope: @@ -484,6 +539,11 @@ def detect_artifact_periods_by_envelope( return artifacts +detect_artifact_periods_by_envelope.__doc__ = detect_artifact_periods_by_envelope.__doc__.format( + _detect_artifacts_by_envelope_params +) + + def _transform_internal_dtype_to_artifact_dtype( artifacts: np.ndarray, recording: BaseRecording, @@ -522,7 +582,6 @@ def _transform_internal_dtype_to_artifact_dtype( for seg_index in range(num_seg): mask = artifacts["segment_index"] == seg_index sub_thr = artifacts[mask] - print(sub_thr) if len(sub_thr) > 0: if not sub_thr["front"][0]: local_thr = np.zeros(1, dtype=np.dtype(base_period_dtype + [("front", "bool")])) @@ -561,35 +620,41 @@ def detect_artifact_periods( job_kwargs: dict | None = None, ) -> np.ndarray: """ - Detect artifact periods using one of several available methods. + Detect artifact periods using one of several available methods. - Available methods: + Available methods: - * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified - channel envelope. 
- * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. + * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified + channel envelope. + * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. - See the documentation of each sub-function for a full description of their - parameters, which can be forwarded via ``method_kwargs``. + Parameters + ---------- + recording : BaseRecording + The recording on which to detect artifact periods. + method : "envelope" | "saturation", default: "envelope" + Detection method to use. + method_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the selected detection + function. Pass ``None`` to use that function's defaults. - Parameters - ---------- - recording : BaseRecording - The recording on which to detect artifact periods. - method : {"envelope", "saturation"}, default: "envelope" - Detection method to use. - method_kwargs : dict | None, default: None - Additional keyword arguments forwarded to the selected detection - function. Pass ``None`` to use that function's defaults. - job_kwargs : dict | None, default: None - Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. - ``n_jobs``, ``chunk_duration``). + Method-specific parameters include: - Returns - ------- - np.ndarray - Array with dtype ``artifact_dtype`` describing each detected artifact - period. + - ``"envelope"`` + + {artifacts_by_envelope_params} + - ``"saturation"``, see :func:`detect_saturation_periods` + + {saturation_periods_params} + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + + Returns + ------- + np.ndarray + Array with dtype ``artifact_dtype`` describing each detected artifact + period. 
""" assert ( method in _method_to_function @@ -600,3 +665,122 @@ def detect_artifact_periods( artifact_periods = _method_to_function[method](recording, job_kwargs=job_kwargs, **method_kwargs) return artifact_periods + + +detect_artifact_periods.__doc__ = detect_artifact_periods.__doc__.format( + artifacts_by_envelope_params=_indent_docstring(_detect_artifacts_by_envelope_params, 12), + saturation_periods_params=_indent_docstring(_detect_saturation_periods_params, 12), +) + + +class DetectAndRemoveArtifactsRecording(SilencedPeriodsRecording): + """ + Detect and remove artifact periods using one of several available methods. + + Available methods: + + * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified + channel envelope. + * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. + + + Parameters + ---------- + recording : BaseRecording + The recording on which to detect artifact periods. + recording_to_detect : BaseRecording | None, default: None + The recording on which to perform artifact detection. + If ``None``, the same recording passed in ``recording`` is used. This allows users to perform detection on + the another recording (e.g. the raw recording) while applying the silencing to the `recording`. + method : "envelope" | "saturation", default: "envelope" + Detection method to use. + method_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the selected detection + function. Pass ``None`` to use that function's defaults. + + Method-specific parameters include: + + - ``"envelope"`` + + {artifacts_by_envelope_params} + - ``"saturation"``, see :func:`detect_saturation_periods` + + {saturation_periods_params} + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). 
+ noise_levels_kwargs : dict | None, default: None + Keyword arguments for `spikeinterface.core.get_noise_levels()` function. + If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. + mode : "zeros" | "noise" | "apodization", default: "zeros" + Determines what periods are replaced by. Can be one of the following: + + - "zeros": Artifacts are replaced by zeros. + + - "noise": The periods are filled with a gaussian noise that has the + same variance as the one in the recordings, on a per channel + basis + - "apodization": The periods are zeroed, but are apodized with a cosine taper (using `apodization_samples`) + apodization_samples : int, default: 7 + The factor used for the cosine taper when mode is "apodization". Higher values create a wider taper. + seed : int | None, default: None + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. + artifact_periods : np.ndarray | None, default: None + Optionally, pre-computed artifact periods can be passed directly to the constructor to skip the + detection step. 
If ``None``, artifact periods are detected on the fly using the specified method + """ + + def __init__( + self, + recording: BaseRecording, + recording_to_detect: BaseRecording | None = None, + method: Literal["envelope", "saturation"] = "envelope", + method_kwargs: dict | None = None, + job_kwargs: dict | None = None, + mode: Literal["zeros", "noise", "apodization"] = "zeros", + noise_levels_kwargs: dict | None = None, + apodization: int = 7, + seed: int | None = None, + artifact_periods=None, + ) -> None: + if artifact_periods is not None: + artifact_periods = artifact_periods + else: + if recording_to_detect is None: + recording_to_detect = recording + artifact_periods = detect_artifact_periods( + recording_to_detect, method=method, method_kwargs=method_kwargs, job_kwargs=job_kwargs + ) + super().__init__( + recording, + periods=artifact_periods, + mode=mode, + noise_levels_kwargs=noise_levels_kwargs, + seed=seed, + apodization=apodization, + ) + + self._kwargs = dict( + recording=recording, + recording_to_detect=recording_to_detect, + method=method, + method_kwargs=method_kwargs, + job_kwargs=job_kwargs, + mode=mode, + noise_levels_kwargs=noise_levels_kwargs, + seed=seed, + artifact_periods=artifact_periods, + apodization=apodization, + ) + + +# function for API +detect_and_remove_artifacts = define_function_handling_dict_from_class( + source_class=DetectAndRemoveArtifactsRecording, name="detect_and_remove_artifacts" +) + +detect_and_remove_artifacts.__doc__ = detect_and_remove_artifacts.__doc__.format( + artifacts_by_envelope_params=_indent_docstring(_detect_artifacts_by_envelope_params, 12), + saturation_periods_params=_indent_docstring(_detect_saturation_periods_params, 12), +) diff --git a/src/spikeinterface/preprocessing/detect_bad_channels.py b/src/spikeinterface/preprocessing/detect_bad_channels.py index 3b403ca73f..503c802de7 100644 --- a/src/spikeinterface/preprocessing/detect_bad_channels.py +++ 
b/src/spikeinterface/preprocessing/detect_bad_channels.py @@ -308,6 +308,9 @@ def detect_bad_channels( if channel_filters is None: channel_filters = allowed_filters + if isinstance(channel_filters, list): + channel_filters = set(channel_filters) + if not isinstance(channel_filters, set): raise ValueError(f"channel_filters must be None or a set of the following values : {allowed_filters} ") diff --git a/src/spikeinterface/preprocessing/pipeline.py b/src/spikeinterface/preprocessing/pipeline.py index 0f94a2584d..4d90007964 100644 --- a/src/spikeinterface/preprocessing/pipeline.py +++ b/src/spikeinterface/preprocessing/pipeline.py @@ -99,11 +99,29 @@ def _apply(self, recording, apply_precomputed_kwargs=False): Preprocessed recording """ - - for preprocessor_name, kwargs in self.preprocessor_dict.items(): - + instantiated_recordings = {"raw": recording} + for preprocessor_name, kwargs_ in self.preprocessor_dict.items(): + kwargs = kwargs_.copy() dont_apply_kwargs = ["recording", "parent_recording"] + for k, v in kwargs.items(): + if isinstance(v, str) and "pipeline[" in v: + if "recording" not in k: + raise ValueError( + f"Cannot substitute recording for argument '{k}' of preprocessor '{preprocessor_name}' " + f"because this argument is not meant to be a recording object." + ) + if k in dont_apply_kwargs: + raise ValueError( + f"Cannot substitute recording for argument '{k}' of preprocessor '{preprocessor_name}' " + f"because this argument is reserved for the recording to be preprocessed." 
+ ) + rec_name = v.split("pipeline[")[-1].split("]")[0] + substituted_recording = instantiated_recordings.get(rec_name) + if substituted_recording is None: + raise ValueError(f"Cannot find recording '{rec_name}' from previous steps in the pipeline.") + kwargs[k] = substituted_recording + if not apply_precomputed_kwargs: preprocessor_class = pp_names_to_classes[preprocessor_name] precomputable_kwarg_names = preprocessor_class._precomputable_kwarg_names @@ -112,6 +130,7 @@ def _apply(self, recording, apply_precomputed_kwargs=False): non_rec_kwargs = {key: value for key, value in kwargs.items() if key not in dont_apply_kwargs} pp_output = pp_names_to_functions[preprocessor_name](recording, **non_rec_kwargs) recording = pp_output + instantiated_recordings[preprocessor_name] = recording return recording @@ -305,6 +324,12 @@ def _load_pp_from_dict(prov_dict, kwargs_dict): for name, value in prov_dict["kwargs"].items(): if is_dict_extractor(value): this_level_kwargs[name] = _load_pp_from_dict(value, kwargs_dict) + elif isinstance(value, BaseRecording): + extractor_as_dict = value.to_dict() + if name in ["recording", "parent_recording"]: + this_level_kwargs[name] = _load_pp_from_dict(extractor_as_dict, kwargs_dict) + else: # this branch takes care of other arguments being a recording, e.g., `recording_to_detect` + this_level_kwargs[name] = value elif isinstance(value, dict): this_level_kwargs[name] = {k: prov_dict_to_kwargs_dict(v) for k, v in value.items()} elif isinstance(value, list): diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index 47e3c0906b..6ada71de44 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -48,6 +48,7 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed 
+from .detect_artifacts import DetectAndRemoveArtifactsRecording, detect_and_remove_artifacts # from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts @@ -72,6 +73,8 @@ # bad channel detection/interpolation DetectAndRemoveBadChannelsRecording: detect_and_remove_bad_channels, DetectAndInterpolateBadChannelsRecording: detect_and_interpolate_bad_channels, + # artifact/saturation handling + DetectAndRemoveArtifactsRecording: detect_and_remove_artifacts, # misc RectifyRecording: rectify, ClipRecording: clip, diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 393c712919..595ef46123 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -1,12 +1,13 @@ import numpy as np -from spikeinterface.core.core_tools import define_function_handling_dict_from_class -from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from spikeinterface.core import get_noise_levels +from spikeinterface.core.base import base_period_dtype +from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.core.recording_tools import get_noise_levels, get_chunk_with_margin from spikeinterface.core.generate import NoiseGeneratorRecording from spikeinterface.core.job_tools import split_job_kwargs -from spikeinterface.core.base import base_period_dtype + +from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment class SilencedPeriodsRecording(BasePreprocessor): @@ -21,14 +22,11 @@ class SilencedPeriodsRecording(BasePreprocessor): ---------- recording : RecordingExtractor The recording extractor to silance periods - list_periods : list of lists/arrays - One list per segment of tuples (start_frame, end_frame) to silence - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels` and 
`NoiseGeneratorRecording`. - If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. - mode : "zeros" | "noise, default: "zeros" + periods : np.array + A numpy array with dtype `base_period_dtype` and fields + "segment_index", "start_sample_index", "end_sample_index". + Each row corresponds to a period to silence. + mode : "zeros" | "noise" | "apodization", default: "zeros" Determines what periods are replaced by. Can be one of the following: - "zeros": Artifacts are replaced by zeros. @@ -36,6 +34,14 @@ class SilencedPeriodsRecording(BasePreprocessor): - "noise": The periods are filled with a gaussion noise that has the same variance that the one in the recordings, on a per channel basis + - "apodization": The periods zeroed, but are apodized with a cosine taper (using `apodization_samples`) + apodization_samples : int, default: 7 + The factor used for the cosine taper when mode is "apodization". Higher values create a wider taper. + noise_levels : array + Noise levels if already computed + seed : int | None, default: None + Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. + If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. 
**noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function Returns @@ -48,22 +54,29 @@ def __init__( self, recording, periods=None, - # this is keep for backward compatibility + # this is kept for backward compatibility list_periods=None, mode="zeros", + apodization_samples=7, noise_levels=None, seed=None, **noise_levels_kwargs, ): - available_modes = ("zeros", "noise") + available_modes = ("zeros", "noise", "apodization") num_seg = recording.get_num_segments() # handle backward compatibility with previous version if list_periods is not None: - assert periods is None + assert periods is None, ( + "You cannot specify both list_periods and periods. " + f"Please specify only periods, which should be a np.array with dtype {base_period_dtype}" + ) periods = _all_period_list_to_periods_vec(list_periods, num_seg) else: - assert list_periods is None + assert list_periods is None, ( + "list_periods is deprecated. Please specify periods, which should be a np.array with " + f"dtype {base_period_dtype}" + ) if not isinstance(periods, np.ndarray): raise ValueError(f"periods must be a np.array with dtype {base_period_dtype}") @@ -108,11 +121,26 @@ def __init__( i1 = seg_limits[seg_index + 1] periods_in_seg = periods[i0:i1] rec_segment = SilencedPeriodsRecordingSegment( - parent_segment, periods_in_seg, mode, noise_generator, seg_index + parent_segment, + periods_in_seg, + mode, + noise_generator, + seg_index, + apodization_samples=apodization_samples, ) self.add_recording_segment(rec_segment) - self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels) + # the base_period_dtype is a structured dtype, which is not json serializable + self._serializability["json"] = False + + self._kwargs = dict( + recording=recording, + periods=periods, + mode=mode, + seed=seed, + noise_levels=noise_levels, + apodization_samples=apodization_samples, + ) def _all_period_list_to_periods_vec(list_periods, num_seg): 
@@ -154,18 +182,28 @@ def _check_periods(periods, num_seg): class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): - def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index): + def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg_index, apodization_samples=7): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.periods = periods self.mode = mode self.seg_index = seg_index self.noise_generator = noise_generator + self.apodization_samples = apodization_samples def get_traces(self, start_frame, end_frame, channel_indices): - traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) + if self.mode in ("zeros", "noise"): + margin = 0 + elif self.mode == "apodization": + margin = self.apodization_samples + else: + raise ValueError(f"Unknown method {self.mode}") + + traces, left_margin, right_margin = get_chunk_with_margin( + self.parent_recording_segment, start_frame, end_frame, channel_indices, margin=margin + ) if self.periods.size > 0: - new_interval = np.array([start_frame, end_frame]) + new_interval = np.array([start_frame - margin, end_frame + margin]) lower_index = np.searchsorted(self.periods["end_sample_index"], new_interval[0]) upper_index = np.searchsorted(self.periods["start_sample_index"], new_interval[1]) @@ -174,9 +212,14 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces.copy() periods_in_interval = self.periods[lower_index:upper_index] + + # For apodization, we pre-allocate the mute function and cosine window + if self.mode == "apodization": + mute_mask = np.zeros(traces.shape[0], dtype=np.float32) + for period in periods_in_interval: - onset = max(0, period["start_sample_index"] - start_frame) - offset = min(period["end_sample_index"] - start_frame, end_frame) + onset = max(0, period["start_sample_index"] - start_frame - margin) + offset = min(period["end_sample_index"] - start_frame + margin, end_frame + 
margin) if self.mode == "zeros": traces[onset:offset, :] = 0 @@ -185,8 +228,20 @@ def get_traces(self, start_frame, end_frame, channel_indices): :, channel_indices ] traces[onset:offset, :] = noise[onset:offset] - - return traces + elif self.mode == "apodization": + # apply a cosine taper to the saturation to create a mute function + mute_mask[onset:offset] = 1 + + # For apodization, we apply the mute function including all periods to the whole trace, + # so that the edges of the silenced periods are smoothly tapered + if self.mode == "apodization": + import scipy.signal + + win = scipy.signal.windows.cosine(self.apodization_samples) + mute = np.maximum(0, 1 - scipy.signal.convolve(mute_mask, win, mode="same")) + traces = (traces.astype(np.float32) * mute[:, np.newaxis]).astype(traces.dtype) + # discard margin + return traces[left_margin : traces.shape[0] - right_margin, :] # function for API diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py index 1d99206bd2..76347f56b8 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -5,6 +5,7 @@ detect_artifact_periods, detect_saturation_periods, detect_artifact_periods_by_envelope, + detect_and_remove_artifacts, ) @@ -238,6 +239,119 @@ def test_detect_saturation_periods(debug_plots): assert np.array_equal(periods, periods_entry_with_annotation) +def test_detect_saturation_signed(): + import scipy.signal + + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 + job_kwargs = {"chunk_size": chunk_size} + + sat_value = 1200 + noise_level = 10 + rng = np.random.default_rng(0) + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(90000, num_chans)) * 10 + + sos = scipy.signal.butter(N=3, Wn=8000 / (sampling_frequency / 2), btype="low", output="sos") + data = scipy.signal.sosfiltfilt(sos, data, axis=0) + + # Inject positive 
saturation in first third, negative in second third + pos_start, pos_stop = 15000, 15500 + neg_start, neg_stop = 45000, 45500 + data[pos_start:pos_stop, :] = sat_value + data[neg_start:neg_stop, :] = -sat_value + + gain = 2.34 + offset = 0 + data_int16 = np.clip(np.rint((data - offset) / gain), -32768, 32767).astype(np.int16) + + recording = NumpyRecording(data_int16, sampling_frequency) + recording.set_channel_gains(gain) + recording.set_channel_offsets([offset] * num_chans) + + periods = detect_saturation_periods( + recording, + saturation_threshold_uV=sat_value * 0.98, + signed=True, + job_kwargs=job_kwargs, + ) + + # Output dtype must include the "sign" field + assert "sign" in periods.dtype.names + + pos_periods = periods[periods["sign"] == "positive"] + neg_periods = periods[periods["sign"] == "negative"] + assert len(pos_periods) > 0, "No positive saturation periods detected" + assert len(neg_periods) > 0, "No negative saturation periods detected" + + # Positive period should be near the injected positive saturation + tolerance = 1 + assert np.any(np.abs(pos_periods["start_sample_index"] - pos_start) <= tolerance) + assert np.any(np.abs(pos_periods["end_sample_index"] - pos_stop) <= tolerance) + + # Negative period should be near the injected negative saturation + assert np.any(np.abs(neg_periods["start_sample_index"] - neg_start) <= tolerance) + assert np.any(np.abs(neg_periods["end_sample_index"] - neg_stop) <= tolerance) + + # Positive periods must not contain any sample indices from the negative injection + for p in pos_periods: + assert not (p["start_sample_index"] < neg_stop and p["end_sample_index"] > neg_start) + + # Negative periods must not contain any sample indices from the positive injection + for p in neg_periods: + assert not (p["start_sample_index"] < pos_stop and p["end_sample_index"] > pos_start) + + +def test_detect_and_remove_artifacts(): + import scipy.signal + + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 + 
job_kwargs = {"chunk_size": chunk_size} + + sat_value = 1200 + noise_level = 10 + rng = np.random.default_rng(0) + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(90000, num_chans)) * 10 + + sos = scipy.signal.butter(N=3, Wn=8000 / (sampling_frequency / 2), btype="low", output="sos") + data = scipy.signal.sosfiltfilt(sos, data, axis=0) + + sat_start, sat_stop = 15000, 15500 + data[sat_start:sat_stop, :] = sat_value + + gain = 2.34 + offset = 0 + data_int16 = np.clip(np.rint((data - offset) / gain), -32768, 32767).astype(np.int16) + + recording = NumpyRecording(data_int16, sampling_frequency) + recording.set_channel_gains(gain) + recording.set_channel_offsets([offset] * num_chans) + + # Basic usage: detect and zero out saturation in one step + cleaned = detect_and_remove_artifacts( + recording, + method="saturation", + method_kwargs=dict(saturation_threshold_uV=sat_value * 0.98), + job_kwargs=job_kwargs, + ) + traces = cleaned.get_traces(segment_index=0) + assert traces[sat_start + 100, 0] == 0, "Saturated samples should be zeroed" + assert traces[0, 0] != 0, "Non-saturated samples should not be zeroed" + + # recording_to_detect: detect on raw recording, silence a separate (processed) recording + # We use the same recording here just to exercise the code path + cleaned_with_detect = detect_and_remove_artifacts( + recording, + recording_to_detect=recording, + method="saturation", + method_kwargs=dict(saturation_threshold_uV=sat_value * 0.98), + job_kwargs=job_kwargs, + ) + assert np.array_equal(cleaned.get_traces(), cleaned_with_detect.get_traces()) + + if __name__ == "__main__": # test_detect_artifact_by_envelope(True) test_detect_saturation_periods(False) diff --git a/src/spikeinterface/preprocessing/tests/test_pipeline.py b/src/spikeinterface/preprocessing/tests/test_pipeline.py index 27432bb156..6a96d4a66c 100644 --- a/src/spikeinterface/preprocessing/tests/test_pipeline.py +++ b/src/spikeinterface/preprocessing/tests/test_pipeline.py @@ -1,3 +1,7 @@ 
+import pytest + +from spikeinterface.core.testing import check_recordings_equal +from spikeinterface.core import create_sorting_analyzer from spikeinterface.generation import generate_recording, generate_ground_truth_recording from spikeinterface.preprocessing import ( apply_preprocessing_pipeline, @@ -14,8 +18,6 @@ get_preprocessing_dict_from_file, get_preprocessing_dict_from_analyzer, ) -from spikeinterface.core.testing import check_recordings_equal -from spikeinterface.core import create_sorting_analyzer def test_pipeline_equiv_to_step(): @@ -212,6 +214,62 @@ def test_loading_from_analyzer(create_cache_folder): check_recordings_equal(pp_recording, pp_recording_from_zarr) +def test_pipeline_recording_arg_substitution(create_cache_folder): + """ + Tests that if a preprocessing step in the pipeline has an argument that is a string of the form "pipeline[preprocessor_name]", + then this string is replaced by the recording output by the preprocessor with name "preprocessor_name". This allows users to + use outputs of previous preprocessors as arguments for later preprocessors in the same pipeline. 
+ """ + from spikeinterface.preprocessing.filter import BandpassFilterRecording + from spikeinterface.preprocessing.common_reference import CommonReferenceRecording + from spikeinterface.preprocessing.detect_artifacts import DetectAndRemoveArtifactsRecording + + rec = generate_recording(durations=[1]) + + # "recording" argument is protected, as it is the default argument for the recording to preprocess + pipeline_dict_wrong = { + "common_reference": {}, + "bandpass_filter": {"recording": "pipeline[raw]"}, + } + with pytest.raises(ValueError): + pp_rec_from_pipeline = apply_preprocessing_pipeline(rec, pipeline_dict_wrong) + + # The argument using the pipeline substitution must be a string with "recording" as substring + pipeline_dict_wrong2 = { + "common_reference": {}, + "bandpass_filter": {"freq_min": "pipeline[raw]"}, + } + with pytest.raises(ValueError): + pp_rec_from_pipeline = apply_preprocessing_pipeline(rec, pipeline_dict_wrong2) + + # Correct usage: the "recording_to_detect" argument for the "detect_and_remove_artifacts" step is set to be the + # output of the "bandpass_filter" step, which is correctly substituted when applying the pipeline. + # The "recording" argument for the "detect_and_remove_artifacts" step should be set to the output of the + # "common_reference" step, as this is the last preprocessor in the pipeline before it. 
+ pipeline_dict_correct = { + "bandpass_filter": {}, + "common_reference": {}, + "detect_and_remove_artifacts": {"recording_to_detect": "pipeline[bandpass_filter]"}, + } + pp_rec_from_pipeline = apply_preprocessing_pipeline(rec, pipeline_dict_correct) + # Check that the recording argument for detect step is common ref, + # and that the recording_to_detect argument for detect_and_remove_artifacts is also the output of bandpass_filter + assert isinstance(pp_rec_from_pipeline._kwargs["recording_to_detect"], BandpassFilterRecording) + assert isinstance(pp_rec_from_pipeline._kwargs["recording"], CommonReferenceRecording) + assert isinstance(pp_rec_from_pipeline, DetectAndRemoveArtifactsRecording) + + # Test dumping the pipeline to pickle and loading it back with the correct substitution still works + pp_rec_from_pipeline.dump_to_pickle(create_cache_folder / "pipeline_substitution_test.pkl") + pp_rec_from_pkl = get_preprocessing_dict_from_file(create_cache_folder / "pipeline_substitution_test.pkl") + pp_rec_from_pipeline_substitution = apply_preprocessing_pipeline( + rec, pp_rec_from_pkl, apply_precomputed_kwargs=True + ) + assert isinstance(pp_rec_from_pipeline_substitution._kwargs["recording_to_detect"], BandpassFilterRecording) + assert isinstance(pp_rec_from_pipeline_substitution._kwargs["recording"], CommonReferenceRecording) + assert isinstance(pp_rec_from_pipeline_substitution, DetectAndRemoveArtifactsRecording) + check_recordings_equal(pp_rec_from_pipeline, pp_rec_from_pipeline_substitution) + + if __name__ == "__main__": import tempfile from pathlib import Path diff --git a/src/spikeinterface/preprocessing/tests/test_silence_periods.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py index 44bd205f1b..d3531f1a0f 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence_periods.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -39,7 +39,6 @@ def test_silence(create_cache_folder): data1 = rec.get_traces(0, 400, 600) 
data2 = rec.get_traces(0, 500, 700) assert np.all(data1[100:] == data2[:100]) - traces_mix = rec0.get_traces(segment_index=0, start_frame=900, end_frame=5100) traces_original = rec.get_traces(segment_index=0, start_frame=900, end_frame=5100) assert np.all(traces_original[100:-100] == traces_mix[100:-100]) @@ -49,6 +48,49 @@ def test_silence(create_cache_folder): assert not np.all(traces_mix[:-200] == 0) +def test_silence_with_apodization(create_cache_folder): + + cache_folder = create_cache_folder + + rec = generate_recording() + + periods = np.array([(0, 0, 1000), (0, 5000, 6000)], dtype=base_period_dtype) + # test that the apodization creates a taper + apodization_samples = 10 + rec2 = silence_periods(rec, periods=periods, mode="apodization", apodization_samples=apodization_samples) + traces_in0 = rec2.get_traces(segment_index=0, start_frame=0, end_frame=1000) + traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000, end_frame=6000) + # all apodized traces + assert np.all(traces_in0 == 0) + assert np.all(traces_in1 == 0) + + # at margins, traces should not be all zero, but should be apodized + apodized_traces_in0 = rec2.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_samples) + apodized_traces_in1 = rec2.get_traces(segment_index=0, start_frame=5000 - apodization_samples, end_frame=5000) + traces_raw_in0 = rec.get_traces(segment_index=0, start_frame=1000, end_frame=1000 + apodization_samples) + traces_raw_in1 = rec.get_traces(segment_index=0, start_frame=5000 - apodization_samples, end_frame=5000) + # the apodized traces should be less than the raw traces in absolute value, + # since they are multiplied by a cosine taper between 0 and 1 + assert np.all(np.abs(apodized_traces_in0) <= np.abs(traces_raw_in0)) + assert np.all(np.abs(apodized_traces_in1) <= np.abs(traces_raw_in1)) + + # check that margins are handled correctly with apodization + extra_samples = 50 + traces_at_offset = rec2.get_traces(segment_index=0, start_frame=998, 
end_frame=1002) + traces_at_offset_extended = rec2.get_traces( + segment_index=0, start_frame=998 - extra_samples, end_frame=1002 + extra_samples + ) + # the traces at offset should be apodized, and the extended traces should have the same apodization in the overlapping region + assert np.array_equal(traces_at_offset, traces_at_offset_extended[extra_samples:-extra_samples]) + + traces_at_onset = rec2.get_traces(segment_index=0, start_frame=4997, end_frame=5003) + traces_at_onset_extended = rec2.get_traces( + segment_index=0, start_frame=4997 - extra_samples, end_frame=5003 + extra_samples + ) + # the traces at onset should be apodized, and the extended traces should have the same apodization in the overlapping region + assert np.array_equal(traces_at_onset, traces_at_onset_extended[extra_samples:-extra_samples]) + + if __name__ == "__main__": cache_folder = Path(__file__).resolve().parents[4] / "cache_folder" test_silence(cache_folder)