Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 73 additions & 13 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,18 @@ def get_unit_spike_train_in_seconds(

# Use the native spiking times if available
# Some instances might implement a method themselves to access spike times directly without having to convert
# (e.g. NWB extractors)
# (e.g. NWB extractors). The native times already include the extractor's `_native_t_start`,
# so we apply only the shift (`_t_start - _native_t_start`) on top.
if hasattr(segment, "get_unit_spike_train_in_seconds"):
return segment.get_unit_spike_train_in_seconds(unit_id=unit_id, start_time=start_time, end_time=end_time)
spike_times = segment.get_unit_spike_train_in_seconds(
unit_id=unit_id, start_time=start_time, end_time=end_time
)
t_start = segment._t_start if segment._t_start is not None else 0
native_t_start = segment._native_t_start if segment._native_t_start is not None else 0
shift = t_start - native_t_start
if shift != 0:
spike_times = spike_times + shift
return spike_times

# If no recording attached and all back to frame-based conversion
# Get spike train in frames and convert to times using traditional method
Expand Down Expand Up @@ -330,8 +339,12 @@ def register_recording(self, recording, check_spike_frames: bool = True):
# Copy the recording's start times into the sorting segments. This way,
# the sorting preserves the start time even if the recording is later
# detached (e.g. analyzer saved and reloaded without the recording).
# Also update `_native_t_start` so any subsequent `shift_times` call measures
# its delta from the recording's start time (not the extractor's original value).
for segment_index, segment in enumerate(self.segments):
segment._t_start = recording.get_start_time(segment_index=segment_index)
start_time = recording.get_start_time(segment_index=segment_index)
segment._t_start = start_time
segment._native_t_start = start_time

@property
def sorting_info(self):
Expand Down Expand Up @@ -374,11 +387,38 @@ def get_start_time(self, segment_index: int | None = None) -> float:
segment = self.segments[segment_index]
return segment._t_start if segment._t_start is not None else 0.0

def shift_times(self, shift: int | float, segment_index: int | None = None) -> None:
"""
Shift all times by a scalar value.

This modifies the sorting's own time offset without touching the registered
recording. When a recording is registered, the shift is applied on top of
the recording's time basis when resolving timestamps.

Parameters
----------
shift : int | float
The shift to apply. If positive, times will be increased by `shift`.
If negative, times will be decreased.
segment_index : int | None
The segment on which to shift the times.
If `None`, all segments will be shifted.
"""
if segment_index is None:
segments_to_shift = range(self.get_num_segments())
else:
segments_to_shift = (segment_index,)

for segment_index in segments_to_shift:
segment = self.segments[segment_index]
segment._t_start = (segment._t_start if segment._t_start is not None else 0) + shift

def get_end_time(self, segment_index: int | None = None) -> float:
"""Get the end time of the sorting segment.

If a recording is registered, returns the recording's end time.
Otherwise returns the time of the last spike in the segment.
If a recording is registered, returns the recording's end time (plus any
shift applied via `shift_times`). Otherwise returns the time of the last
spike in the segment.

Parameters
----------
Expand All @@ -392,7 +432,10 @@ def get_end_time(self, segment_index: int | None = None) -> float:
"""
segment_index = self._check_segment_index(segment_index)
if self.has_recording():
return self._recording.get_end_time(segment_index=segment_index)
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
return self._recording.get_end_time(segment_index=segment_index) + shift
else:
last_spike_frame = self.get_last_spike_frame(segment_index=segment_index)
return self.sample_index_to_time(last_spike_frame, segment_index=segment_index)
Expand Down Expand Up @@ -425,11 +468,19 @@ def get_times(self, segment_index=None):
* if the segment has a time_vector, then it is returned
* if not, a time_vector is constructed on the fly with sampling frequency

Any shift applied via `shift_times` is added to the returned times.

If there is no registered recording it returns None
"""
segment_index = self._check_segment_index(segment_index)
if self.has_recording():
return self._recording.get_times(segment_index=segment_index)
times = self._recording.get_times(segment_index=segment_index)
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
if shift != 0:
times = times + shift
return times
else:
return None

Expand Down Expand Up @@ -771,11 +822,13 @@ def time_to_sample_index(self, time, segment_index=0):
"""
Transform time in seconds into sample index
"""
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
if self.has_recording():
sample_index = self._recording.time_to_sample_index(time, segment_index=segment_index)
# Subtract the sorting's shift (relative to the recording's start) before delegating
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
sample_index = self._recording.time_to_sample_index(time - shift, segment_index=segment_index)
else:
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
sample_index = round((time - t_start) * self.get_sampling_frequency())

return sample_index
Expand All @@ -787,11 +840,13 @@ def sample_index_to_time(
Transform sample index into time in seconds
"""
segment_index = self._check_segment_index(segment_index)
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
if self.has_recording():
return self._recording.sample_index_to_time(sample_index, segment_index=segment_index)
# Add the sorting's shift (relative to the recording's start) after delegating
shift = t_start - self._recording.get_start_time(segment_index=segment_index)
return self._recording.sample_index_to_time(sample_index, segment_index=segment_index) + shift
else:
segment = self.segments[segment_index]
t_start = segment._t_start if segment._t_start is not None else 0
return (sample_index / self.get_sampling_frequency()) + t_start

def precompute_spike_trains(self):
Expand Down Expand Up @@ -1149,6 +1204,11 @@ class BaseSortingSegment(BaseSegment):

def __init__(self, t_start=None):
    # `_t_start` is the segment's mutable start time (later changed by
    # `shift_times`), while `_native_t_start` stays frozen at the value the
    # extractor provided at construction. The user-applied shift can then be
    # recovered as `_t_start - _native_t_start`, which lets extractors that
    # return native absolute times (e.g. NWB) apply only the shift on top,
    # without double-counting their own offset.
    self._native_t_start = self._t_start = t_start
    BaseSegment.__init__(self)

def get_unit_spike_train(
Expand Down
13 changes: 10 additions & 3 deletions src/spikeinterface/core/frameslicesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike

# link sorting segment
parent_segment = parent_sorting.segments[0]
sub_segment = FrameSliceSortingSegment(parent_segment, start_frame, end_frame)
sub_segment = FrameSliceSortingSegment(
parent_segment, start_frame, end_frame, sampling_frequency=parent_sorting.get_sampling_frequency()
)
self.add_sorting_segment(sub_segment)

# copy properties and annotations
Expand All @@ -96,8 +98,13 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike


class FrameSliceSortingSegment(BaseSortingSegment):
def __init__(self, parent_sorting_segment, start_frame, end_frame):
BaseSortingSegment.__init__(self)
def __init__(self, parent_sorting_segment, start_frame, end_frame, sampling_frequency):
    # Mirror FrameSliceRecordingSegment: the slice's start time is the
    # parent's start time pushed forward by the slice offset in seconds.
    # A parent with `_t_start=None` is treated as starting at 0, so the
    # slice always gets a concrete `start_frame / sampling_frequency`.
    if parent_sorting_segment._t_start is None:
        base_t_start = 0.0
    else:
        base_t_start = parent_sorting_segment._t_start
    slice_offset = start_frame / sampling_frequency
    BaseSortingSegment.__init__(self, t_start=base_t_start + slice_offset)
    self._parent_sorting_segment = parent_sorting_segment
    self.start_frame = start_frame
    self.end_frame = end_frame
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ def generate_sorting(
if t_starts is not None:
assert len(t_starts) == len(durations), "t_starts must have the same length as durations"
for segment_index, t_start in enumerate(t_starts):
sorting.segments[segment_index]._t_start = t_start
segment = sorting.segments[segment_index]
segment._t_start = float(t_start)
segment._native_t_start = float(t_start)

return sorting

Expand Down
20 changes: 20 additions & 0 deletions src/spikeinterface/core/tests/test_frameslicesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,25 @@ def test_FrameSliceSorting():
assert_raises(Exception, sorting_exceeding.frame_slice, None, None)


def test_time_slice_propagates_t_start():
    """`time_slice` delegates to `frame_slice`, so the slice's start time should
    come back equal to the requested `start_time`. Exercises both a parent
    without an explicit t_start and one with a shifted t_start (which stacks)."""
    sampling_frequency = 10.0
    unit_spike_frames = {"0": np.arange(100, 900)}

    # Case 1: parent has no explicit t_start (treated as 0).
    plain_sorting = NumpySorting.from_unit_dict([unit_spike_frames], sampling_frequency)
    plain_slice = plain_sorting.time_slice(start_time=20.0, end_time=50.0)
    assert plain_slice.get_start_time(segment_index=0) == 20.0

    # Case 2: parent carries an explicit t_start; the slice offset stacks on top.
    shifted_sorting = NumpySorting.from_unit_dict([unit_spike_frames], sampling_frequency)
    shifted_sorting.shift_times(shift=100.0)
    shifted_slice = shifted_sorting.time_slice(start_time=120.0, end_time=150.0)
    assert shifted_slice.get_start_time(segment_index=0) == 120.0


# Allow running this test module directly (outside pytest) for quick local debugging.
if __name__ == "__main__":
    test_FrameSliceSorting()
    test_time_slice_propagates_t_start()
126 changes: 115 additions & 11 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,8 @@ def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recordi
"""
_, times_recording, _ = time_vector_recording

sorting = si.generate_sorting(
durations=[times_recording.get_duration(s) for s in range(times_recording.get_num_segments())]
)
durations = [times_recording.get_duration(s) for s in range(times_recording.get_num_segments())]
sorting = si.generate_sorting(durations=durations)
sorting_analyzer = si.create_sorting_analyzer(sorting, recording=times_recording)

assert np.array_equal(sorting_analyzer.get_total_duration(), times_recording.get_total_duration())
Expand Down Expand Up @@ -461,10 +460,51 @@ def test_get_end_time_is_last_spike(self):
assert sorting.get_end_time(segment_index=0) == expected_time

def test_get_start_time_with_t_start(self):
sorting = generate_sorting(num_units=5, durations=[10])
sorting.segments[0]._t_start = 100.0
sorting = generate_sorting(num_units=5, durations=[10], t_starts=[100.0])
assert sorting.get_start_time(segment_index=0) == 100.0

def test_shift_times(self):
    """A positive shift moves both the reported start time and all spike times."""
    sorting = generate_sorting(num_units=5, durations=[10])
    first_unit = sorting.unit_ids[0]

    times_original = sorting.get_unit_spike_train(first_unit, segment_index=0, return_times=True)
    sorting.shift_times(shift=5.0)
    times_shifted = sorting.get_unit_spike_train(first_unit, segment_index=0, return_times=True)

    assert sorting.get_start_time(segment_index=0) == 5.0
    assert np.allclose(times_shifted, times_original + 5.0)

def test_shift_times_all_segments(self):
    """With no segment_index given, every segment's start time moves by the shift."""
    sorting = generate_sorting(num_units=5, durations=[10, 15], t_starts=[1.0, 2.0])
    sorting.shift_times(shift=3.0)

    expected_starts = (4.0, 5.0)
    for seg_index, expected in enumerate(expected_starts):
        assert sorting.get_start_time(segment_index=seg_index) == expected

def test_shift_times_single_segment(self):
    """Shifting only one segment leaves the other segment untouched."""
    sorting = generate_sorting(num_units=5, durations=[10, 15], t_starts=[1.0, 2.0])

    shift = 3.0
    sorting.shift_times(shift=shift, segment_index=1)

    # Segment 0 keeps its original start; segment 1 is moved by the shift.
    assert sorting.get_start_time(segment_index=0) == 1.0
    assert sorting.get_start_time(segment_index=1) == 2.0 + shift

def test_shift_times_with_native_spike_times(self):
    """The shift must also apply when the segment serves native spike times
    directly (as NWB extractors do via `get_unit_spike_train_in_seconds`)."""
    sorting = generate_sorting(num_units=5, durations=[10])
    unit = sorting.unit_ids[0]
    segment = sorting.segments[0]

    # Monkey-patch the segment so it mimics an extractor exposing native times.
    native_times = sorting.get_unit_spike_train(unit, segment_index=0, return_times=True).copy()
    segment.get_unit_spike_train_in_seconds = lambda unit_id, start_time, end_time: native_times

    sorting.shift_times(shift=5.0)
    shifted_times = sorting.get_unit_spike_train(unit, segment_index=0, return_times=True)
    assert np.allclose(shifted_times, native_times + 5.0)


class TestSortingTimeWithRecording:
"""
Expand All @@ -481,17 +521,16 @@ def test_get_start_end_time(self):
assert sorting.get_end_time(segment_index=0) == recording.get_end_time(segment_index=0)

def test_register_recording_copies_start_times(self):
"""Registering a recording copies its start times into the sorting segments."""
sorting = generate_sorting(num_units=5, durations=[10])
sorting.segments[0]._t_start = 100.0
"""Registering a recording overrides any pre-existing sorting start time."""
sorting = generate_sorting(num_units=5, durations=[10], t_starts=[100.0])

recording = generate_recording(num_channels=4, durations=[10])
recording.shift_times(shift=50.0)
sorting.register_recording(recording)

# _t_start now mirrors the recording's start time, preserving it across
# save/load cycles even when the recording is not attached.
assert sorting.segments[0]._t_start == recording.get_start_time(segment_index=0)
# The sorting's start time now mirrors the recording's start time, preserving it
# across save/load cycles even when the recording is later detached.
assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0)
assert sorting.get_start_time(segment_index=0) == 50.0

def test_with_recording_shifted_start(self):
Expand All @@ -503,3 +542,68 @@ def test_with_recording_shifted_start(self):
sorting.register_recording(recording)

assert sorting.get_start_time(segment_index=0) == 50.0

def test_shift_times(self):
    """Shifting a sorting with a registered recording moves only the sorting's
    times; the recording's time basis must be left untouched."""
    recording = generate_recording(num_channels=4, durations=[10])
    sorting = generate_sorting(num_units=5, durations=[10])
    sorting.register_recording(recording)
    unit = sorting.unit_ids[0]

    start_before = recording.get_start_time(segment_index=0)
    end_before = recording.get_end_time(segment_index=0)
    times_before = sorting.get_unit_spike_train(unit, segment_index=0, return_times=True)

    shift = 5.0
    sorting.shift_times(shift=shift)

    # The recording is unchanged...
    assert recording.get_start_time(segment_index=0) == start_before
    assert recording.get_end_time(segment_index=0) == end_before

    # ...while every sorting-side time moved by the shift.
    assert sorting.get_start_time(segment_index=0) == start_before + shift
    assert sorting.get_end_time(segment_index=0) == end_before + shift
    times_after = sorting.get_unit_spike_train(unit, segment_index=0, return_times=True)
    assert np.allclose(times_after, times_before + shift)

def test_time_conversion_roundtrip_after_shift(self):
    """`sample_index_to_time` and `time_to_sample_index` must remain inverses
    of each other after a shift."""
    recording = generate_recording(num_channels=4, durations=[10])
    sorting = generate_sorting(num_units=5, durations=[10])
    sorting.register_recording(recording)

    shift = 5.0
    sorting.shift_times(shift=shift)

    # Frame 30000 corresponds to 1.0 s in the recording; after the shift the
    # sorting should report that time plus the shift.
    sample_index = 30000
    shifted_time = sorting.sample_index_to_time(sample_index, segment_index=0)
    assert shifted_time == recording.sample_index_to_time(sample_index, segment_index=0) + shift

    # And the inverse mapping lands back on the same frame.
    assert sorting.time_to_sample_index(shifted_time, segment_index=0) == sample_index

def test_shift_times_with_time_vector(self):
    """The sorting-side shift composes with a recording that carries an explicit
    time vector, preserving the irregular sample spacing."""
    recording = generate_recording(num_channels=4, durations=[1.0])
    n_samples = recording.get_num_samples(segment_index=0)

    # Build irregularly spaced timestamps starting at 100.0 s.
    rng = np.random.RandomState(0)
    time_vector = 100.0 + np.cumsum(rng.uniform(0.5, 1.5, n_samples)) / recording.get_sampling_frequency()
    recording.set_times(time_vector, segment_index=0, with_warning=False)

    sorting = generate_sorting(num_units=5, durations=[1.0])
    sorting.register_recording(recording)
    unit = sorting.unit_ids[0]

    times_before = sorting.get_unit_spike_train(unit, segment_index=0, return_times=True)
    sorting.shift_times(shift=5.0)
    times_after = sorting.get_unit_spike_train(unit, segment_index=0, return_times=True)

    # Everything moved by 5.0 s with the irregular spacing preserved...
    assert np.allclose(times_after, times_before + 5.0)
    # ...and the recording's own time vector is untouched.
    assert np.allclose(recording.get_times(segment_index=0), time_vector)
Original file line number Diff line number Diff line change
Expand Up @@ -618,11 +618,10 @@ def __init__(
sampling_frequency,
neo_returns_frames,
):
BaseSortingSegment.__init__(self)
BaseSortingSegment.__init__(self, t_start=t_start)
self.neo_reader = neo_reader
self.segment_index = segment_index
self.block_index = block_index
self._t_start = t_start
self._sampling_frequency = sampling_frequency
self.neo_returns_frames = neo_returns_frames

Expand Down
Loading
Loading