diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index fe0cd03641..fc8373cdfb 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -278,9 +278,18 @@ def get_unit_spike_train_in_seconds( # Use the native spiking times if available # Some instances might implement a method themselves to access spike times directly without having to convert - # (e.g. NWB extractors) + # (e.g. NWB extractors). The native times already include the extractor's `_native_t_start`, + # so we apply only the shift (`_t_start - _native_t_start`) on top. if hasattr(segment, "get_unit_spike_train_in_seconds"): - return segment.get_unit_spike_train_in_seconds(unit_id=unit_id, start_time=start_time, end_time=end_time) + spike_times = segment.get_unit_spike_train_in_seconds( + unit_id=unit_id, start_time=start_time, end_time=end_time + ) + t_start = segment._t_start if segment._t_start is not None else 0 + native_t_start = segment._native_t_start if segment._native_t_start is not None else 0 + shift = t_start - native_t_start + if shift != 0: + spike_times = spike_times + shift + return spike_times # If no recording attached and all back to frame-based conversion # Get spike train in frames and convert to times using traditional method @@ -330,8 +339,12 @@ def register_recording(self, recording, check_spike_frames: bool = True): # Copy the recording's start times into the sorting segments. This way, # the sorting preserves the start time even if the recording is later # detached (e.g. analyzer saved and reloaded without the recording). + # Also update `_native_t_start` so any subsequent `shift_times` call measures + # its delta from the recording's start time (not the extractor's original value). for segment_index, segment in enumerate(self.segments): - segment._t_start = recording.get_start_time(segment_index=segment_index) + start_time = recording.get_start_time(segment_index=segment_index) + segment._t_start = start_time + segment._native_t_start = start_time @property def sorting_info(self): @@ -374,11 +387,38 @@ def get_start_time(self, segment_index: int | None = None) -> float: segment = self.segments[segment_index] return segment._t_start if segment._t_start is not None else 0.0 + def shift_times(self, shift: int | float, segment_index: int | None = None) -> None: + """ + Shift all times by a scalar value. + + This modifies the sorting's own time offset without touching the registered + recording. When a recording is registered, the shift is applied on top of + the recording's time basis when resolving timestamps. + + Parameters + ---------- + shift : int | float + The shift to apply. If positive, times will be increased by `shift`. + If negative, times will be decreased. + segment_index : int | None + The segment on which to shift the times. + If `None`, all segments will be shifted. + """ + if segment_index is None: + segments_to_shift = range(self.get_num_segments()) + else: + segments_to_shift = (segment_index,) + + for segment_index in segments_to_shift: + segment = self.segments[segment_index] + segment._t_start = (segment._t_start if segment._t_start is not None else 0) + shift + def get_end_time(self, segment_index: int | None = None) -> float: """Get the end time of the sorting segment. - If a recording is registered, returns the recording's end time. - Otherwise returns the time of the last spike in the segment. + If a recording is registered, returns the recording's end time (plus any + shift applied via `shift_times`). Otherwise returns the time of the last + spike in the segment. Parameters ---------- @@ -392,7 +432,10 @@ def get_end_time(self, segment_index: int | None = None) -> float: """ segment_index = self._check_segment_index(segment_index) if self.has_recording(): - return self._recording.get_end_time(segment_index=segment_index) + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + return self._recording.get_end_time(segment_index=segment_index) + shift else: last_spike_frame = self.get_last_spike_frame(segment_index=segment_index) return self.sample_index_to_time(last_spike_frame, segment_index=segment_index) @@ -430,11 +473,19 @@ def get_times( * if the segment has a time_vector, then it is returned * if not, a time_vector is constructed on the fly with sampling frequency + Any shift applied via `shift_times` is added to the returned times. + If there is no registered recording it returns None """ segment_index = self._check_segment_index(segment_index) if self.has_recording(): - return self._recording.get_times(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame) + times = self._recording.get_times(segment_index=segment_index, start_frame=start_frame, end_frame=end_frame) + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + if shift != 0: + times = times + shift + return times else: return None @@ -776,11 +827,13 @@ def time_to_sample_index(self, time, segment_index=0): """ Transform time in seconds into sample index """ + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 if self.has_recording(): - sample_index = self._recording.time_to_sample_index(time, segment_index=segment_index) + # Subtract the sorting's shift (relative to the recording's start) before delegating + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + sample_index = self._recording.time_to_sample_index(time - shift, segment_index=segment_index) else: - segment = self.segments[segment_index] - t_start = segment._t_start if segment._t_start is not None else 0 sample_index = round((time - t_start) * self.get_sampling_frequency()) return sample_index @@ -792,11 +845,13 @@ def sample_index_to_time( Transform sample index into time in seconds """ segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] + t_start = segment._t_start if segment._t_start is not None else 0 if self.has_recording(): - return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) + # Add the sorting's shift (relative to the recording's start) after delegating + shift = t_start - self._recording.get_start_time(segment_index=segment_index) + return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) + shift else: - segment = self.segments[segment_index] - t_start = segment._t_start if segment._t_start is not None else 0 return (sample_index / self.get_sampling_frequency()) + t_start def precompute_spike_trains(self): @@ -1154,6 +1209,11 @@ class BaseSortingSegment(BaseSegment): def __init__(self, t_start=None): self._t_start = t_start + # Immutable reference to the start time as set by the extractor at init. + # Used to compute the user-applied shift as `_t_start - _native_t_start`, + # so `shift_times` can correctly propagate through extractors that return + # native absolute times (e.g. NWB) without double-counting the extractor's offset. + self._native_t_start = t_start BaseSegment.__init__(self) def get_unit_spike_train( diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py index a337e83707..2d3195e8a5 100644 --- a/src/spikeinterface/core/frameslicesorting.py +++ b/src/spikeinterface/core/frameslicesorting.py @@ -76,7 +76,9 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike # link sorting segment parent_segment = parent_sorting.segments[0] - sub_segment = FrameSliceSortingSegment(parent_segment, start_frame, end_frame) + sub_segment = FrameSliceSortingSegment( + parent_segment, start_frame, end_frame, sampling_frequency=parent_sorting.get_sampling_frequency() + ) self.add_sorting_segment(sub_segment) # copy properties and annotations @@ -96,8 +98,13 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike class FrameSliceSortingSegment(BaseSortingSegment): - def __init__(self, parent_sorting_segment, start_frame, end_frame): - BaseSortingSegment.__init__(self) + def __init__(self, parent_sorting_segment, start_frame, end_frame, sampling_frequency): + # Propagate the parent's start time forward by the slice offset, mirroring + # what FrameSliceRecordingSegment does. A parent with `_t_start=None` is + # treated as starting at 0, so the slice gets a concrete `start_frame / fs`. + parent_t_start = parent_sorting_segment._t_start if parent_sorting_segment._t_start is not None else 0.0 + t_start = parent_t_start + start_frame / sampling_frequency + BaseSortingSegment.__init__(self, t_start=t_start) self._parent_sorting_segment = parent_sorting_segment self.start_frame = start_frame self.end_frame = end_frame diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 64593c736e..8b3edd048a 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -193,7 +193,9 @@ def generate_sorting( if t_starts is not None: assert len(t_starts) == len(durations), "t_starts must have the same length as durations" for segment_index, t_start in enumerate(t_starts): - sorting.segments[segment_index]._t_start = t_start + segment = sorting.segments[segment_index] + segment._t_start = float(t_start) + segment._native_t_start = float(t_start) return sorting diff --git a/src/spikeinterface/core/tests/test_frameslicesorting.py b/src/spikeinterface/core/tests/test_frameslicesorting.py index 7c1ad0ae34..990233229e 100644 --- a/src/spikeinterface/core/tests/test_frameslicesorting.py +++ b/src/spikeinterface/core/tests/test_frameslicesorting.py @@ -91,5 +91,25 @@ def test_FrameSliceSorting(): assert_raises(Exception, sorting_exceeding.frame_slice, None, None) +def test_time_slice_propagates_t_start(): + """`time_slice` goes through `frame_slice`, so the propagated start time should + equal the requested `start_time`. Covers both the parent-with-no-t_start case + and the parent-with-explicit-t_start case (which should stack).""" + sf = 10.0 + spike_times = {"0": np.arange(100, 900)} + + # Parent has no explicit t_start (treated as 0). + sorting = NumpySorting.from_unit_dict([spike_times], sf) + sub = sorting.time_slice(start_time=20.0, end_time=50.0) + assert sub.get_start_time(segment_index=0) == 20.0 + + # Parent has an explicit t_start; the slice offset stacks on top. + sorting_shifted = NumpySorting.from_unit_dict([spike_times], sf) + sorting_shifted.shift_times(shift=100.0) + sub_shifted = sorting_shifted.time_slice(start_time=120.0, end_time=150.0) + assert sub_shifted.get_start_time(segment_index=0) == 120.0 + + if __name__ == "__main__": test_FrameSliceSorting() + test_time_slice_propagates_t_start() diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index 76576e54d9..a1e5aa47bf 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -286,9 +286,8 @@ def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recordi """ _, times_recording, _ = time_vector_recording - sorting = si.generate_sorting( - durations=[times_recording.get_duration(s) for s in range(times_recording.get_num_segments())] - ) + durations = [times_recording.get_duration(s) for s in range(times_recording.get_num_segments())] + sorting = si.generate_sorting(durations=durations) sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording) assert np.array_equal(sorting_analyzer.get_total_duration(), times_recording.get_total_duration()) @@ -484,10 +483,51 @@ def test_get_end_time_is_last_spike(self): assert sorting.get_end_time(segment_index=0) == expected_time def test_get_start_time_with_t_start(self): - sorting = generate_sorting(num_units=5, durations=[10]) - sorting.segments[0]._t_start = 100.0 + sorting = generate_sorting(num_units=5, durations=[10], t_starts=[100.0]) assert sorting.get_start_time(segment_index=0) == 100.0 + def test_shift_times(self): + sorting = generate_sorting(num_units=5, durations=[10]) + unit_id = sorting.unit_ids[0] + + spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + + sorting.shift_times(shift=5.0) + + assert sorting.get_start_time(segment_index=0) == 5.0 + spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + assert np.allclose(spike_times_after, spike_times_before + 5.0) + + def test_shift_times_all_segments(self): + sorting = generate_sorting(num_units=5, durations=[10, 15], t_starts=[1.0, 2.0]) + + sorting.shift_times(shift=3.0) + + assert sorting.get_start_time(segment_index=0) == 4.0 + assert sorting.get_start_time(segment_index=1) == 5.0 + + def test_shift_times_single_segment(self): + sorting = generate_sorting(num_units=5, durations=[10, 15], t_starts=[1.0, 2.0]) + + sorting.shift_times(shift=3.0, segment_index=1) + + assert sorting.get_start_time(segment_index=0) == 1.0 + assert sorting.get_start_time(segment_index=1) == 5.0 + + def test_shift_times_with_native_spike_times(self): + """Shift must apply even when the segment provides native spike times (e.g. NWB extractors).""" + sorting = generate_sorting(num_units=5, durations=[10]) + unit_id = sorting.unit_ids[0] + segment = sorting.segments[0] + + # Simulate a segment that provides native spike times directly + original_times = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True).copy() + segment.get_unit_spike_train_in_seconds = lambda unit_id, start_time, end_time: original_times + + sorting.shift_times(shift=5.0) + spike_times = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + assert np.allclose(spike_times, original_times + 5.0) + class TestSortingTimeWithRecording: """ @@ -504,17 +544,16 @@ def test_get_start_end_time(self): assert sorting.get_end_time(segment_index=0) == recording.get_end_time(segment_index=0) def test_register_recording_copies_start_times(self): - """Registering a recording copies its start times into the sorting segments.""" - sorting = generate_sorting(num_units=5, durations=[10]) - sorting.segments[0]._t_start = 100.0 + """Registering a recording overrides any pre-existing sorting start time.""" + sorting = generate_sorting(num_units=5, durations=[10], t_starts=[100.0]) recording = generate_recording(num_channels=4, durations=[10]) recording.shift_times(shift=50.0) sorting.register_recording(recording) - # _t_start now mirrors the recording's start time, preserving it across - # save/load cycles even when the recording is not attached. - assert sorting.segments[0]._t_start == recording.get_start_time(segment_index=0) + # The sorting's start time now mirrors the recording's start time, preserving it + # across save/load cycles even when the recording is later detached. + assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0) assert sorting.get_start_time(segment_index=0) == 50.0 def test_with_recording_shifted_start(self): @@ -526,3 +565,68 @@ def test_with_recording_shifted_start(self): sorting.register_recording(recording) assert sorting.get_start_time(segment_index=0) == 50.0 + + def test_shift_times(self): + recording = generate_recording(num_channels=4, durations=[10]) + sorting = generate_sorting(num_units=5, durations=[10]) + sorting.register_recording(recording) + unit_id = sorting.unit_ids[0] + + rec_start_before = recording.get_start_time(segment_index=0) + rec_end_before = recording.get_end_time(segment_index=0) + spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + + sorting.shift_times(shift=5.0) + + # The recording should be untouched + assert recording.get_start_time(segment_index=0) == rec_start_before + assert recording.get_end_time(segment_index=0) == rec_end_before + + # The sorting's times should be shifted + assert sorting.get_start_time(segment_index=0) == rec_start_before + 5.0 + assert sorting.get_end_time(segment_index=0) == rec_end_before + 5.0 + spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + assert np.allclose(spike_times_after, spike_times_before + 5.0) + + def test_time_conversion_roundtrip_after_shift(self): + """sample_index_to_time and time_to_sample_index must remain inverses after a shift.""" + recording = generate_recording(num_channels=4, durations=[10]) + sorting = generate_sorting(num_units=5, durations=[10]) + sorting.register_recording(recording) + + sorting.shift_times(shift=5.0) + + # Frame 30000 is 1.0s in the recording. After a 5.0s shift, the sorting should report 6.0s. + time = sorting.sample_index_to_time(30000, segment_index=0) + assert time == recording.sample_index_to_time(30000, segment_index=0) + 5.0 + + # The inverse: 6.0s in the sorting should map back to frame 30000. + frame = sorting.time_to_sample_index(time, segment_index=0) + assert frame == 30000 + + def test_shift_times_with_time_vector(self): + """Shift on sorting composes with a recording that has an explicit time vector, + preserving the irregular spacing.""" + recording = generate_recording(num_channels=4, durations=[1.0]) + num_samples = recording.get_num_samples(segment_index=0) + # Irregular timestamps starting at 100.0 + times = ( + 100.0 + + np.cumsum(np.random.RandomState(0).uniform(0.5, 1.5, num_samples)) / recording.get_sampling_frequency() + ) + recording.set_times(times, segment_index=0, with_warning=False) + + sorting = generate_sorting(num_units=5, durations=[1.0]) + sorting.register_recording(recording) + unit_id = sorting.unit_ids[0] + + spike_times_before = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + + sorting.shift_times(shift=5.0) + + spike_times_after = sorting.get_unit_spike_train(unit_id, segment_index=0, return_times=True) + # Irregular spacing preserved, everything shifted by 5.0 + assert np.allclose(spike_times_after, spike_times_before + 5.0) + + # Recording is untouched + assert np.allclose(recording.get_times(segment_index=0), times) diff --git a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py index 3d787f9519..fa91f0f4c2 100644 --- a/src/spikeinterface/extractors/neoextractors/neobaseextractor.py +++ b/src/spikeinterface/extractors/neoextractors/neobaseextractor.py @@ -618,11 +618,10 @@ def __init__( sampling_frequency, neo_returns_frames, ): - BaseSortingSegment.__init__(self) + BaseSortingSegment.__init__(self, t_start=t_start) self.neo_reader = neo_reader self.segment_index = segment_index self.block_index = block_index - self._t_start = t_start self._sampling_frequency = sampling_frequency self.neo_returns_frames = neo_returns_frames diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index b89999d088..b223b97398 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -1313,13 +1313,12 @@ def _fetch_properties(self, columns): class NwbSortingSegment(BaseSortingSegment): def __init__(self, spike_times_data, spike_times_index_data, sampling_frequency: float, t_start: float): - BaseSortingSegment.__init__(self) + BaseSortingSegment.__init__(self, t_start=t_start) self.spike_times_data = spike_times_data self.spike_times_index_data = spike_times_index_data self.spike_times_data = spike_times_data self.spike_times_index_data = spike_times_index_data self._sampling_frequency = sampling_frequency - self._t_start = t_start def get_unit_spike_train( self,