Skip to content

Commit 81f872e

Browse files
committed
feat: empirical fs measurement + round to nearest standard rate, show header vs measured
1 parent f1ed273 commit 81f872e

1 file changed

Lines changed: 62 additions & 10 deletions

File tree

examples/joint_angle_regression/open_ephys_lsl_streamer.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,43 @@ def __init__(
149149
self.last_chunk = 0
150150
self.last_error = ""
151151
self.detected_fs = 0.0 # filled after connect
152+
self._header_fs = 0.0 # from ZMQ header field
153+
self._measured_fs = 0 # empirical throughput
152154
self._prev_written = 0 # track ref-channel total_samples_written
153155

156+
@staticmethod
157+
def _round_fs(raw: float) -> float:
158+
"""Round a measured fs to the nearest 'standard' rate."""
159+
# Common DAQ rates
160+
standard = [
161+
1000, 1250, 1500, 2000, 2500, 3000, 3333, 4000, 5000,
162+
6250, 8000, 10000, 12500, 15000, 20000, 25000, 30000, 40000, 50000,
163+
]
164+
best = min(standard, key=lambda s: abs(s - raw))
165+
# Only snap if within 10 %
166+
if abs(best - raw) / max(best, 1) < 0.10:
167+
return float(best)
168+
return round(raw)
169+
154170
def _wait_for_channels(self, timeout=3.0):
155-
"""Poll seen_nums until the count stabilises for 0.5 s or *timeout* expires."""
171+
"""Poll seen_nums until the count stabilises for 0.5 s or *timeout* expires.
172+
173+
Also measures empirical sample throughput so the caller can
174+
cross-validate the header-reported ``sample_rate``.
175+
176+
Returns ``(n_channels, measured_fs)`` where *measured_fs* is
177+
samples-per-second computed from ``total_samples_written``.
178+
"""
156179
import time as _t
157180

158181
start = _t.time()
159182
prev_count = 0
160183
stable_since = start
184+
185+
# snapshot sample counter at start
186+
with self.client._lock:
187+
samples_t0 = int(self.client.total_samples_written)
188+
161189
while (_t.time() - start) < timeout:
162190
with self.client._lock:
163191
n = len(self.client.seen_nums)
@@ -167,7 +195,12 @@ def _wait_for_channels(self, timeout=3.0):
167195
elif (_t.time() - stable_since) >= 0.5:
168196
break
169197
_t.sleep(0.05)
170-
return prev_count
198+
199+
elapsed = max(_t.time() - start, 1e-6)
200+
with self.client._lock:
201+
samples_t1 = int(self.client.total_samples_written)
202+
measured_fs = (samples_t1 - samples_t0) / elapsed
203+
return prev_count, measured_fs
171204

172205
def start(self):
173206
if self.running:
@@ -200,7 +233,7 @@ def start(self):
200233
)
201234

202235
# Wait for channel count to stabilise (auto-detect)
203-
n_detected = self._wait_for_channels(timeout=3.0)
236+
n_detected, measured_fs = self._wait_for_channels(timeout=3.0)
204237

205238
with self.client._lock:
206239
detected = sorted(self.client.seen_nums)
@@ -215,14 +248,28 @@ def start(self):
215248
ch_idx = detected[: self.emg_channels]
216249
self.client.set_channel_index(ch_idx)
217250

218-
# Infer sampling rate from the stream (client.fs is updated from ZMQ headers)
219-
client_fs = float(self.client.fs)
220-
if client_fs > 0 and (self.expected_fs <= 0 or self.expected_fs == 5000.0):
221-
self.detected_fs = client_fs
222-
elif self.expected_fs > 0:
251+
# ---- Infer sampling rate ----
252+
# Three possible sources (best → worst):
253+
# 1. User-supplied expected_fs (if > 0)
254+
# 2. Empirical throughput measured during channel stabilisation
255+
# 3. Header-reported sample_rate (client.fs)
256+
# The empirical rate is the most trustworthy when available because
257+
# it reflects actual data throughput rather than a header field that
258+
# some plugins may set incorrectly.
259+
header_fs = float(self.client.fs)
260+
self._header_fs = header_fs
261+
self._measured_fs = round(measured_fs)
262+
263+
if self.expected_fs > 0:
264+
# User explicitly chose a rate – honour it.
223265
self.detected_fs = self.expected_fs
266+
elif measured_fs > 100:
267+
# Round to nearest "nice" rate (multiple of 250 or 1000)
268+
self.detected_fs = self._round_fs(measured_fs)
269+
elif header_fs > 0:
270+
self.detected_fs = header_fs
224271
else:
225-
self.detected_fs = client_fs if client_fs > 0 else 2000.0
272+
self.detected_fs = 2000.0
226273
fs = self.detected_fs
227274
self.emg_outlet, self.imu_outlet = build_outlets(
228275
self.emg_stream_name, self.imu_stream_name, fs, self.emg_channels
@@ -632,17 +679,21 @@ def _tick(self):
632679
info = self.streamer.poll_once()
633680
ch = info["channels"]
634681
fs_str = f"{self.streamer.detected_fs:.0f}" if self.streamer.detected_fs > 0 else "?"
682+
hdr = f"{self.streamer._header_fs:.0f}" if self.streamer._header_fs > 0 else "?"
683+
meas = f"{self.streamer._measured_fs}" if self.streamer._measured_fs > 0 else "?"
635684
self.samples.setText(
636685
f"Samples: {info['total_emg']:,} | chunk ({ch}ch, {info['chunk']}) @ {fs_str} Hz"
637686
)
687+
self.rate.setText(
688+
f"Rate: {info['rate_hz']:.1f} Hz | fs: header={hdr} measured={meas}"
689+
)
638690
if info["chunk"] > 0:
639691
self.emg_stats.setText(
640692
f"EMG RMS: {info['emg_rms']:.3f} | \u03c3: {info['emg_std']:.3f}"
641693
)
642694
self.imu_stats.setText(
643695
f"IMU \u03c3: {info['imu_std']:.3f} | Mag \u03c3: {info['mag_std']:.3f}"
644696
)
645-
self.rate.setText(f"Rate: {info['rate_hz']:.1f} Hz")
646697
except Exception as exc:
647698
self.status.setText(f"Error: {exc}")
648699
self.status.setStyleSheet(
@@ -674,6 +725,7 @@ def run_cli(args):
674725
print(
675726
f"Streaming LSL: EMG='{args.emg_stream_name}', IMU='{args.imu_stream_name}'"
676727
f" | {streamer.emg_channels}ch @ {streamer.detected_fs:.0f} Hz"
728+
f" (header={streamer._header_fs:.0f}, measured={streamer._measured_fs})"
677729
)
678730
try:
679731
while True:

0 commit comments

Comments
 (0)