From d138d3b97f0ce86a9ae85a6e0702aa47071e0566 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 15:49:53 -0600 Subject: [PATCH 1/6] add shift times to sorting --- src/spikeinterface/core/basesorting.py | 71 +++++++++-- .../core/tests/test_time_handling.py | 111 ++++++++++++++++++ 2 files changed, 171 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index a5cbdff18e..f90990f133 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -280,7 +280,14 @@ def get_unit_spike_train_in_seconds( # Some instances might implement a method themselves to access spike times directly without having to convert # (e.g. NWB extractors) if hasattr(segment, "get_unit_spike_train_in_seconds"): - return segment.get_unit_spike_train_in_seconds(unit_id=unit_id, start_time=start_time, end_time=end_time) + spike_times = segment.get_unit_spike_train_in_seconds( + unit_id=unit_id, start_time=start_time, end_time=end_time + ) + # Apply the sorting's shift on top of the native times + t_start = segment._t_start if segment._t_start is not None else 0 + if t_start != 0: + spike_times = spike_times + t_start + return spike_times # If no recording attached and all back to frame-based conversion # Get spike train in frames and convert to times using traditional method @@ -374,11 +381,38 @@ def get_start_time(self, segment_index: int | None = None) -> float: segment = self.segments[segment_index] return segment._t_start if segment._t_start is not None else 0.0 + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: + """ + Shift all times by a scalar value. + + This modifies the sorting's own time offset without touching the registered + recording. When a recording is registered, the shift is applied on top of + the recording's time basis when resolving timestamps. + + Parameters + ---------- + shift : int | float + The shift to apply. If positive, times will be increased by `shift`. + If negative, times will be decreased. + segment_index : int | None + The segment on which to shift the times. + If `None`, all segments will be shifted. + """ + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) + else: + segments_to_shift = (segment_index,) + + for segment_index in segments_to_shift: + segment = self.segments[segment_index] + segment._t_start = (segment._t_start if segment._t_start is not None else 0) + shift + def get_end_time(self, segment_index: int | None = None) -> float: """Get the end time of the sorting segment. - If a recording is registered, returns the recording's end time. - Otherwise returns the time of the last spike in the segment. + If a recording is registered, returns the recording's end time (plus any + shift applied via `shift_times`). Otherwise returns the time of the last + spike in the segment. Parameters ---------- @@ -392,7 +426,10 @@ def get_end_time(self, segment_index: int | None = None) -> float: """ segment_index = self._check_segment_index(segment_index) if self.has_recording(): - return self._recording.get_end_time(segment_index=segment_index) + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + return self._recording.get_end_time(segment_index=segment_index) + shift else: last_spike_frame = self.get_last_spike_frame(segment_index=segment_index) return self.sample_index_to_time(last_spike_frame, segment_index=segment_index) @@ -425,11 +462,19 @@ def get_times(self, segment_index=None): * if the segment has a time_vector, then it is returned * if not, a time_vector is constructed on the fly with sampling frequency + Any shift applied via `shift_times` is added to the returned times. + If there is no registered recording it returns None """ segment_index = self._check_segment_index(segment_index) if self.has_recording(): - return self._recording.get_times(segment_index=segment_index) + times = self._recording.get_times(segment_index=segment_index) + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + if shift != 0: + times = times + shift + return times else: return None @@ -771,11 +816,13 @@ def time_to_sample_index(self, time, segment_index=0): """ Transform time in seconds into sample index """ + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 if self.has_recording(): - sample_index = self._recording.time_to_sample_index(time, segment_index=segment_index) + # Subtract the sorting's shift (relative to the recording's start) before delegating + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + sample_index = self._recording.time_to_sample_index(time - shift, segment_index=segment_index) else: - segment = self.segments[segment_index] - t_start = segment._t_start if segment._t_start is not None else 0 sample_index = round((time - t_start) * self.get_sampling_frequency()) return sample_index @@ -787,11 +834,13 @@ def sample_index_to_time( Transform sample index into time in seconds """ segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 if self.has_recording(): - return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) + # Add the sorting's shift (relative to the recording's start) after delegating + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) + shift else: - segment = self.segments[segment_index] - t_start = segment._t_start if segment._t_start is not None else 0 return (sample_index / self.get_sampling_frequency()) + t_start def precompute_spike_trains(self): diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index bd74ddfe02..263ed0b36e 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -465,6 +465,52 @@ def test_get_start_time_with_t_start(self): sorting.segments[0]._t_start = 100.0 assert sorting.get_start_time(segment_index=0) == 100.0 + def test_shift_times(self): + sorting = generate_sorting(num_units=5, durations=[10]) + unit_id = sorting.unit_ids[0] + + spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + + sorting.shift_times(shift=5.0) + + assert sorting.get_start_time(segment_index=0) == 5.0 + spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + assert np.allclose(spike_times_after, spike_times_before + 5.0) + + def test_shift_times_all_segments(self): + sorting = generate_sorting(num_units=5, durations=[10, 15]) + sorting.segments[0]._t_start = 1.0 + sorting.segments[1]._t_start = 2.0 + + sorting.shift_times(shift=3.0) + + assert sorting.get_start_time(segment_index=0) == 4.0 + assert sorting.get_start_time(segment_index=1) == 5.0 + + def test_shift_times_single_segment(self): + sorting = generate_sorting(num_units=5, durations=[10, 15]) + sorting.segments[0]._t_start = 1.0 + sorting.segments[1]._t_start = 2.0 + + sorting.shift_times(shift=3.0, segment_index=1) + + assert sorting.get_start_time(segment_index=0) == 1.0 + assert sorting.get_start_time(segment_index=1) == 5.0 + + def test_shift_times_with_native_spike_times(self): + """Shift must apply even when the segment provides native spike times (e.g. NWB extractors).""" + sorting = generate_sorting(num_units=5, durations=[10]) + unit_id = sorting.unit_ids[0] + segment = sorting.segments[0] + + # Simulate a segment that provides native spike times directly + original_times = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True).copy() + segment.get_unit_spike_train_in_seconds = lambda unit_id, start_time, end_time: original_times + + sorting.shift_times(shift=5.0) + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + assert np.allclose(spike_times, original_times + 5.0) + class TestSortingTimeWithRecording: """ @@ -503,3 +549,68 @@ def test_with_recording_shifted_start(self): sorting.register_recording(recording) assert sorting.get_start_time(segment_index=0) == 50.0 + + def test_shift_times(self): + recording = generate_recording(num_channels=4, durations=[10]) + sorting = generate_sorting(num_units=5, durations=[10]) + sorting.register_recording(recording) + unit_id = sorting.unit_ids[0] + + rec_start_before = recording.get_start_time(segment_index=0) + rec_end_before = recording.get_end_time(segment_index=0) + spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + + sorting.shift_times(shift=5.0) + + # The recording should be untouched + assert recording.get_start_time(segment_index=0) == rec_start_before + assert recording.get_end_time(segment_index=0) == rec_end_before + + # The sorting's times should be shifted + assert sorting.get_start_time(segment_index=0) == rec_start_before + 5.0 + assert sorting.get_end_time(segment_index=0) == rec_end_before + 5.0 + spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + assert np.allclose(spike_times_after, spike_times_before + 5.0) + + def test_time_conversion_roundtrip_after_shift(self): + """sample_index_to_time and time_to_sample_index must remain inverses after a shift.""" + recording = generate_recording(num_channels=4, durations=[10]) + sorting = generate_sorting(num_units=5, durations=[10]) + sorting.register_recording(recording) + + sorting.shift_times(shift=5.0) + + # Frame 30000 is 1.0s in the recording. After a 5.0s shift, the sorting should report 6.0s. + time = sorting.sample_index_to_time(30000, segment_index=0) + assert time == recording.sample_index_to_time(30000, segment_index=0) + 5.0 + + # The inverse: 6.0s in the sorting should map back to frame 30000. + frame = sorting.time_to_sample_index(time, segment_index=0) + assert frame == 30000 + + def test_shift_times_with_time_vector(self): + """Shift on sorting composes with a recording that has an explicit time vector, + preserving the irregular spacing.""" + recording = generate_recording(num_channels=4, durations=[1.0]) + num_samples = recording.get_num_samples(segment_index=0) + # Irregular timestamps starting at 100.0 + times = ( + 100.0 + + np.cumsum(np.random.RandomState(0).uniform(0.5, 1.5, num_samples)) / recording.get_sampling_frequency() + ) + recording.set_times(times, segment_index=0, with_warning=False) + + sorting = generate_sorting(num_units=5, durations=[1.0]) + sorting.register_recording(recording) + unit_id = sorting.unit_ids[0] + + spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + + sorting.shift_times(shift=5.0) + + spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + # Irregular spacing preserved, everything shifted by 5.0 + assert np.allclose(spike_times_after, spike_times_before + 5.0) + + # Recording is untouched + assert np.allclose(recording.get_times(segment_index=0), times) From 0cd029e1a662a30a28698e0425c279f28fdde7fb Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 20:15:50 -0600 Subject: [PATCH 2/6] ad native t_start --- src/spikeinterface/core/basesorting.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index f90990f133..e6bd6158e9 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -278,15 +278,17 @@ def get_unit_spike_train_in_seconds( # Use the native spiking times if available # Some instances might implement a method themselves to access spike times directly without having to convert - # (e.g. NWB extractors) + # (e.g. NWB extractors). The native times already include the extractor's `_native_t_start`, + # so we apply only the shift (`_t_start - _native_t_start`) on top. if hasattr(segment, "get_unit_spike_train_in_seconds"): spike_times = segment.get_unit_spike_train_in_seconds( unit_id=unit_id, start_time=start_time, end_time=end_time ) - # Apply the sorting's shift on top of the native times t_start = segment._t_start if segment._t_start is not None else 0 - if t_start != 0: - spike_times = spike_times + t_start + native_t_start = segment._native_t_start if segment._native_t_start is not None else 0 + shift = t_start - native_t_start + if shift != 0: + spike_times = spike_times + shift return spike_times # If no recording attached and all back to frame-based conversion @@ -337,8 +339,12 @@ def register_recording(self, recording, check_spike_frames: bool = True): # Copy the recording's start times into the sorting segments. This way, # the sorting preserves the start time even if the recording is later # detached (e.g. analyzer saved and reloaded without the recording). + # Also update `_native_t_start` so any subsequent `shift_times` call measures + # its delta from the recording's start time (not the extractor's original value). for segment_index, segment in enumerate(self.segments): - segment._t_start = recording.get_start_time(segment_index=segment_index) + start_time = recording.get_start_time(segment_index=segment_index) + segment._t_start = start_time + segment._native_t_start = start_time @property def sorting_info(self): @@ -1198,6 +1204,11 @@ class BaseSortingSegment(BaseSegment): def __init__(self, t_start=None): self._t_start = t_start + # Immutable reference to the start time as set by the extractor at init. + # Used to compute the user-applied shift as `_t_start - _native_t_start`, + # so `shift_times` can correctly propagate through extractors that return + # native absolute times (e.g. NWB) without double-counting the extractor's offset. + self._native_t_start = t_start BaseSegment.__init__(self) def get_unit_spike_train( From c2d74b1a156b8d2c280b623f962812d00fc92cd4 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 21 Apr 2026 20:35:16 -0600 Subject: [PATCH 3/6] nwb --- .../extractors/neoextractors/neobaseextractor.py | 3 +-- src/spikeinterface/extractors/nwbextractors.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py index 3d787f9519..fa91f0f4c2 100644 --- a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py +++ b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py @@ -618,11 +618,10 @@ def __init__( sampling_frequency, neo_returns_frames, ): - BaseSortingSegment.__init__(self) + BaseSortingSegment.__init__(self, t_start=t_start) self.neo_reader = neo_reader self.segment_index = segment_index self.block_index = block_index - self._t_start = t_start self._sampling_frequency = sampling_frequency self.neo_returns_frames = neo_returns_frames diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index b89999d088..b223b97398 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -1313,13 +1313,12 @@ def _fetch_properties(self, columns): class NwbSortingSegment(BaseSortingSegment): def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float): - BaseSortingSegment.__init__(self) + BaseSortingSegment.__init__(self, t_start=t_start) self.spike_times_data = spike_times_data self.spike_times_index_data = spike_times_index_data self.spike_times_data = spike_times_data self.spike_times_index_data = spike_times_index_data self._sampling_frequency = sampling_frequency - self._t_start = t_start def get_unit_spike_train( self, From 32e40b81c46870ba769b4f10c1dee8bcb372a473 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 6 May 2026 14:59:01 -0600 Subject: [PATCH 4/6] chris request --- src/spikeinterface/core/generate.py | 14 ++++++++++ src/spikeinterface/core/sortinganalyzer.py | 14 ++++++++++ .../core/tests/test_sortinganalyzer.py | 15 +++++++++++ .../core/tests/test_time_handling.py | 26 ++++++++----------- 4 files changed, 54 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 1c9ece728f..1d16a8f9b5 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -95,6 +95,7 @@ def generate_sorting( add_spikes_on_borders=False, num_spikes_per_border=3, border_size_samples=20, + t_starts=None, seed=None, ): """ @@ -122,6 +123,9 @@ def generate_sorting( The number of spikes to add close to the borders of the segments. border_size_samples : int, default: 20 The size of the border in samples to add border spikes. + t_starts : list of float | None, default: None + Per-segment start times in seconds. Must match the length of `durations`. + If None, all segments start at t=0. seed : int, default: None The random seed. @@ -177,6 +181,16 @@ def generate_sorting( sorting = NumpySorting(spikes, sampling_frequency, unit_ids) + if t_starts is not None: + if len(t_starts) != num_segments: + raise ValueError( + f"`t_starts` must have the same length as `durations` ({num_segments}), got {len(t_starts)}." + ) + for segment_index, t_start in enumerate(t_starts): + segment = sorting.segments[segment_index] + segment._t_start = float(t_start) + segment._native_t_start = float(t_start) + return sorting diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8e16757bcc..98bd82840c 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -363,6 +363,20 @@ def create( f"recording: {recording.sampling_frequency} - sorting: {sorting.sampling_frequency}. " "Ensure that you are associating the correct Recording and Sorting when creating a SortingAnalyzer." ) + + # Check that sorting and recording start times match per segment. A mismatch typically + # means the user shifted one but not the other (e.g. via `shift_times` on only one side). + for segment_index in range(sorting.get_num_segments()): + sorting_start = sorting.get_start_time(segment_index=segment_index) + recording_start = recording.get_start_time(segment_index=segment_index) + if not math.isclose(sorting_start, recording_start, abs_tol=1e-6, rel_tol=1e-6): + raise ValueError( + f"Sorting and Recording start times do not match for segment {segment_index}: " + f"recording: {recording_start} - sorting: {sorting_start}. " + "Call `sorting.register_recording(recording)` to align them, or apply the " + "matching `shift_times` to the side that is out of sync." + ) + # check that multiple probes are non-overlapping all_probes = recording.get_probegroup().probes check_probe_do_not_overlap(all_probes) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..14aaa02f66 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -304,6 +304,21 @@ def test_SortingAnalyzer_tmp_recording(dataset): sorting_analyzer.set_temporary_recording(recording_sliced) +def test_create_sorting_analyzer_start_time_mismatch(dataset): + """create_sorting_analyzer must raise when the sorting and recording have different start times.""" + recording, sorting = dataset + + # Shift only the sorting, leaving the recording at t=0. + sorting.shift_times(shift=5.0) + + with pytest.raises(ValueError, match="start times do not match"): + create_sorting_analyzer(sorting, recording, format="memory", sparse=False) + + # Aligning them via register_recording resolves the mismatch. + sorting.register_recording(recording) + create_sorting_analyzer(sorting, recording, format="memory", sparse=False) + + def test_SortingAnalyzer_interleaved_probegroup(dataset): from probeinterface import generate_linear_probe, ProbeGroup diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 263ed0b36e..b35fd7fb16 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -286,8 +286,10 @@ def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recordi """ _, times_recording, _ = time_vector_recording + num_segments = times_recording.get_num_segments() sorting = si.generate_sorting( - durations=[times_recording.get_duration(s) for s in range(times_recording.get_num_segments())] + durations=[times_recording.get_duration(s) for s in range(num_segments)], + t_starts=[times_recording.get_start_time(segment_index=s) for s in range(num_segments)], ) sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording) @@ -461,8 +463,7 @@ def test_get_end_time_is_last_spike(self): assert sorting.get_end_time(segment_index=0) == expected_time def test_get_start_time_with_t_start(self): - sorting = generate_sorting(num_units=5, durations=[10]) - sorting.segments[0]._t_start = 100.0 + sorting = generate_sorting(num_units=5, durations=[10], t_starts=[100.0]) assert sorting.get_start_time(segment_index=0) == 100.0 def test_shift_times(self): @@ -478,9 +479,7 @@ def test_shift_times(self): assert np.allclose(spike_times_after, spike_times_before + 5.0) def test_shift_times_all_segments(self): - sorting = generate_sorting(num_units=5, durations=[10, 15]) - sorting.segments[0]._t_start = 1.0 - sorting.segments[1]._t_start = 2.0 + sorting = generate_sorting(num_units=5, durations=[10, 15], t_starts=[1.0, 2.0]) sorting.shift_times(shift=3.0) @@ -488,9 +487,7 @@ def test_shift_times_all_segments(self): assert sorting.get_start_time(segment_index=1) == 5.0 def test_shift_times_single_segment(self): - sorting = generate_sorting(num_units=5, durations=[10, 15]) - sorting.segments[0]._t_start = 1.0 - sorting.segments[1]._t_start = 2.0 + sorting = generate_sorting(num_units=5, durations=[10, 15], t_starts=[1.0, 2.0]) sorting.shift_times(shift=3.0, segment_index=1) @@ -527,17 +524,16 @@ def test_get_start_end_time(self): assert sorting.get_end_time(segment_index=0) == recording.get_end_time(segment_index=0) def test_register_recording_copies_start_times(self): - """Registering a recording copies its start times into the sorting segments.""" - sorting = generate_sorting(num_units=5, durations=[10]) - sorting.segments[0]._t_start = 100.0 + """Registering a recording overrides any pre-existing sorting start time.""" + sorting = generate_sorting(num_units=5, durations=[10], t_starts=[100.0]) recording = generate_recording(num_channels=4, durations=[10]) recording.shift_times(shift=50.0) sorting.register_recording(recording) - # _t_start now mirrors the recording's start time, preserving it across - # save/load cycles even when the recording is not attached. - assert sorting.segments[0]._t_start == recording.get_start_time(segment_index=0) + # The sorting's start time now mirrors the recording's start time, preserving it + # across save/load cycles even when the recording is later detached. + assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0) assert sorting.get_start_time(segment_index=0) == 50.0 def test_with_recording_shifted_start(self): From 8e6a6498016af4bf26793541852b813459b60ab6 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 6 May 2026 16:42:49 -0600 Subject: [PATCH 5/6] re-do --- src/spikeinterface/core/frameslicesorting.py | 13 ++++++++++--- src/spikeinterface/core/sortinganalyzer.py | 14 -------------- .../core/tests/test_sortinganalyzer.py | 15 --------------- .../core/tests/test_time_handling.py | 7 ++----- 4 files changed, 12 insertions(+), 37 deletions(-) diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index a337e83707..2d3195e8a5 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ b/src/spikeinterface/core/frameslicesorting.py @@ -76,7 +76,9 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike # link sorting segment parent_segment = parent_sorting.segments[0] - sub_segment = FrameSliceSortingSegment(parent_segment, start_frame, end_frame) + sub_segment = FrameSliceSortingSegment( + parent_segment, start_frame, end_frame, sampling_frequency=parent_sorting.get_sampling_frequency() + ) self.add_sorting_segment(sub_segment) # copy properties and annotations @@ -96,8 +98,13 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike class FrameSliceSortingSegment(BaseSortingSegment): - def __init__(self, parent_sorting_segment, start_frame, end_frame): - BaseSortingSegment.__init__(self) + def __init__(self, parent_sorting_segment, start_frame, end_frame, sampling_frequency): + # Propagate the parent's start time forward by the slice offset, mirroring + # what FrameSliceRecordingSegment does. A parent with `_t_start=None` is + # treated as starting at 0, so the slice gets a concrete `start_frame / fs`. + parent_t_start = parent_sorting_segment._t_start if parent_sorting_segment._t_start is not None else 0.0 + t_start = parent_t_start + start_frame / sampling_frequency + BaseSortingSegment.__init__(self, t_start=t_start) self._parent_sorting_segment = parent_sorting_segment self.start_frame = start_frame self.end_frame = end_frame diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 98bd82840c..8e16757bcc 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -363,20 +363,6 @@ def create( f"recording: {recording.sampling_frequency} - sorting: {sorting.sampling_frequency}. " "Ensure that you are associating the correct Recording and Sorting when creating a SortingAnalyzer." ) - - # Check that sorting and recording start times match per segment. A mismatch typically - # means the user shifted one but not the other (e.g. via `shift_times` on only one side). - for segment_index in range(sorting.get_num_segments()): - sorting_start = sorting.get_start_time(segment_index=segment_index) - recording_start = recording.get_start_time(segment_index=segment_index) - if not math.isclose(sorting_start, recording_start, abs_tol=1e-6, rel_tol=1e-6): - raise ValueError( - f"Sorting and Recording start times do not match for segment {segment_index}: " - f"recording: {recording_start} - sorting: {sorting_start}. " - "Call `sorting.register_recording(recording)` to align them, or apply the " - "matching `shift_times` to the side that is out of sync." - ) - # check that multiple probes are non-overlapping all_probes = recording.get_probegroup().probes check_probe_do_not_overlap(all_probes) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 14aaa02f66..a9bd71b5c0 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -304,21 +304,6 @@ def test_SortingAnalyzer_tmp_recording(dataset): sorting_analyzer.set_temporary_recording(recording_sliced) -def test_create_sorting_analyzer_start_time_mismatch(dataset): - """create_sorting_analyzer must raise when the sorting and recording have different start times.""" - recording, sorting = dataset - - # Shift only the sorting, leaving the recording at t=0. - sorting.shift_times(shift=5.0) - - with pytest.raises(ValueError, match="start times do not match"): - create_sorting_analyzer(sorting, recording, format="memory", sparse=False) - - # Aligning them via register_recording resolves the mismatch. - sorting.register_recording(recording) - create_sorting_analyzer(sorting, recording, format="memory", sparse=False) - - def test_SortingAnalyzer_interleaved_probegroup(dataset): from probeinterface import generate_linear_probe, ProbeGroup diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index b35fd7fb16..8df1079734 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -286,11 +286,8 @@ def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recordi """ _, times_recording, _ = time_vector_recording - num_segments = times_recording.get_num_segments() - sorting = si.generate_sorting( - durations=[times_recording.get_duration(s) for s in range(num_segments)], - t_starts=[times_recording.get_start_time(segment_index=s) for s in range(num_segments)], - ) + durations = [times_recording.get_duration(s) for s in range(times_recording.get_num_segments())] + sorting = si.generate_sorting(durations=durations) sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording) assert np.array_equal(sorting_analyzer.get_total_duration(), times_recording.get_total_duration()) From f6430d5175b4c24e80f22a08860f400f8e3ce20a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 6 May 2026 16:49:13 -0600 Subject: [PATCH 6/6] add time slice test --- .../core/tests/test_frameslicesorting.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/spikeinterface/core/tests/test_frameslicesorting.py b/src/spikeinterface/core/tests/test_frameslicesorting.py index 7c1ad0ae34..990233229e 100644 --- a/src/spikeinterface/core/tests/test_frameslicesorting.py +++ b/src/spikeinterface/core/tests/test_frameslicesorting.py @@ -91,5 +91,25 @@ def test_FrameSliceSorting(): assert_raises(Exception, sorting_exceeding.frame_slice, None, None) +def test_time_slice_propagates_t_start(): + """`time_slice` goes through `frame_slice`, so the propagated start time should + equal the requested `start_time`. Covers both the parent-with-no-t_start case + and the parent-with-explicit-t_start case (which should stack).""" + sf = 10.0 + spike_times = {"0": np.arange(100, 900)} + + # Parent has no explicit t_start (treated as 0). + sorting = NumpySorting.from_unit_dict([spike_times], sf) + sub = sorting.time_slice(start_time=20.0, end_time=50.0) + assert sub.get_start_time(segment_index=0) == 20.0 + + # Parent has an explicit t_start; the slice offset stacks on top. + sorting_shifted = NumpySorting.from_unit_dict([spike_times], sf) + sorting_shifted.shift_times(shift=100.0) + sub_shifted = sorting_shifted.time_slice(start_time=120.0, end_time=150.0) + assert sub_shifted.get_start_time(segment_index=0) == 120.0 + + if __name__ == "__main__": test_FrameSliceSorting() + test_time_slice_propagates_t_start()