Skip to content

Commit 75912a0

Browse files
committed
gesture classifier: quiet-segment channel QC, ChannelQC class, label auto-discovery docs
Channel QC overhaul (1_build_dataset.py): - Replace ChannelQC/assess_channel_quality with direct scipy bandpass+notch - Quiet-segment method: rank 500ms windows by global RMS, use bottom 20% to assess per-channel noise floor during rest - Dead threshold: filtered RMS < 0.5 uV - Noise threshold: filtered RMS > --noise_threshold (default 30 uV) - New --qc_quiet_sec START END flag to pin an explicit rest window - New --noise_threshold CLI flag Label auto-discovery docs: - Document all accepted label filenames (emg.txt, labels.csv, events.csv, labels.txt, {recording}_emg.txt) in module docstring, param docstring, --labels_path help, and README - Add docstring to find_event_for_file() in _file_utils.py run_channel_qc.py example: fully rewritten to use correct ChannelQC API (update/evaluate), demonstrating both realtime and batch patterns README (gesture_classifier): updated Channel QC section, parameters table, label auto-discovery table, Option B examples Core fixes (earlier in session): - _session_loader: fix _extract_bitvolts multiplying by 1e6 when units=uV - _zmq_client: get_latest_window returns data [0] not timestamps [1] - 3_predict_realtime: N_channels lazy-init workaround, channels=[] normalisation - predict.py: metadata.json fallback path, manager_root derivation
1 parent 4525770 commit 75912a0

20 files changed

Lines changed: 2713 additions & 930 deletions
Lines changed: 137 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,144 @@
11
"""
22
Channel Quality Control Example
3-
-------------------------------
4-
Demonstrates how to use the ChannelQC class to identify bad channels
5-
based on line noise, impedance (if available), and signal artifacts.
3+
--------------------------------
4+
Demonstrates two usage patterns of ChannelQC + QCParams:
5+
6+
1. Realtime (streaming) — update() on each incoming chunk, evaluate() each window.
7+
2. Batch (offline) — slide over a full recording; flag channels bad if
8+
>50 % of evaluations mark them bad. Flat / zero-crossing
9+
checks are disabled because quiet rest-period windows are
10+
healthy, not stuck.
11+
12+
API summary
13+
-----------
14+
qc = ChannelQC(fs, n_channels, window_sec, params=QCParams(...))
15+
qc.update(chunk) # chunk: (samples, n_channels), any number of rows
16+
out = qc.evaluate()
17+
# out['bad'] – np.bool_[n_channels] single-eval verdict
18+
# out['excluded'] – set(int) hysteresis-stabilised bad set
19+
# out['metrics'] – dict with 'rms', 'robust_z', 'pl_ratio', etc.
620
"""
721

822
import numpy as np
9-
from pyoephys.processing import ChannelQC
10-
11-
def generate_bad_data(n_channels=8, n_samples=10000):
12-
"""
13-
Generates data where:
14-
- Ch 0 is good
15-
- Ch 1 is railed (saturated)
16-
- Ch 2 has high 60Hz noise
17-
- Ch 3 is disconnected (approx zero)
18-
"""
19-
t = np.linspace(0, 5, n_samples)
20-
data = np.random.randn(n_channels, n_samples) * 10 # Good baseline (10 uV)
21-
22-
# Ch 1: Railed
23-
data[1, :] = 5000.0
24-
25-
# Ch 2: 60Hz Noise
26-
data[2, :] += 500 * np.sin(2 * np.pi * 60 * t)
27-
28-
# Ch 3: Dead/Disconnected
29-
data[3, :] = np.random.randn(n_samples) * 0.1
30-
31-
return data
32-
33-
def main():
34-
fs = 2000.0
35-
print("Generating sample data with 8 channels...")
36-
data = generate_bad_data()
37-
38-
print("Initializing Channel QC...")
39-
qc = ChannelQC(fs=fs)
40-
41-
print("Running QC analysis...")
42-
# Analyze a window
43-
results = qc.compute_qc(data, window_idx=0)
44-
45-
print("\n--- QC Results ---")
46-
print(f"Total Channels: {len(results)}")
47-
48-
for ch, metrics in results.items():
49-
status = "BAD" if not metrics['status'] else "GOOD"
50-
print(f"Channel {ch}: {status}")
51-
if not metrics['status']:
52-
print(f" Issues: {metrics.get('reasons', [])}")
53-
54-
# Visualize if requested (requires matplotlib)
55-
try:
56-
import matplotlib.pyplot as plt
57-
print("\nPlotting results...")
58-
qc.plot_qc_summary(data)
59-
plt.show()
60-
except ImportError:
61-
print("\nMatplotlib not found, skipping plot.")
23+
from pyoephys.processing import ChannelQC, QCParams
24+
25+
26+
# ---------------------------------------------------------------------------
27+
# Synthetic data helpers
28+
# ---------------------------------------------------------------------------
29+
30+
def make_test_recording(n_channels: int = 16, duration_sec: float = 5.0, fs: float = 2000.0):
31+
"""Return (samples, n_channels) array with a few planted bad channels."""
32+
n = int(duration_sec * fs)
33+
t = np.arange(n) / fs
34+
rng = np.random.default_rng(42)
35+
36+
# Good EMG: band-limited noise, ~30 µV RMS
37+
data = rng.normal(0, 30, (n, n_channels)).astype(np.float32)
38+
39+
# Ch 1 – railed / saturated
40+
data[:, 1] = 4500.0
41+
42+
# Ch 2 – heavy 60 Hz powerline
43+
data[:, 2] += 800 * np.sin(2 * np.pi * 60 * t)
44+
45+
# Ch 3 – dead electrode (~0.05 µV RMS)
46+
data[:, 3] = rng.normal(0, 0.05, n)
47+
48+
return data, fs
49+
50+
51+
# ---------------------------------------------------------------------------
52+
# 1. Realtime pattern
53+
# ---------------------------------------------------------------------------
54+
55+
def demo_realtime(data: np.ndarray, fs: float):
56+
"""Process data chunk-by-chunk as it would arrive from hardware."""
57+
n_samples, n_ch = data.shape
58+
chunk_ms = 200
59+
chunk_size = int(fs * chunk_ms / 1000)
60+
61+
params = QCParams(
62+
robust_z_bad=3.0,
63+
pl_ratio_thresh=0.30,
64+
flat_std_min=1.0, # enabled in realtime — sustained flat = stuck ADC
65+
zc_min_hz=3.0,
66+
)
67+
qc = ChannelQC(fs=int(fs), n_channels=n_ch, window_sec=chunk_ms / 1000, params=params)
68+
69+
print("\n=== Realtime pattern ===")
70+
for start in range(0, n_samples - chunk_size, chunk_size):
71+
chunk = data[start : start + chunk_size]
72+
qc.update(chunk)
73+
out = qc.evaluate()
74+
75+
# Final stabilised verdict
76+
excluded = out["excluded"]
77+
m = out["metrics"]
78+
print(f" Excluded channels (hysteresis): {sorted(excluded)}")
79+
print(f" Median RMS: {m['median_rms']:.1f} µV")
80+
for ch in sorted(excluded):
81+
print(f" ch {ch:3d} rms={m['rms'][ch]:.1f} z={m['robust_z'][ch]:.2f} pl={m['pl_ratio'][ch]:.3f}")
82+
83+
84+
# ---------------------------------------------------------------------------
85+
# 2. Batch / offline pattern
86+
# ---------------------------------------------------------------------------
87+
88+
def demo_batch(data: np.ndarray, fs: float):
89+
"""Score every channel over the full recording; vote across windows."""
90+
n_samples, n_ch = data.shape
91+
chunk_size = int(fs * 0.5) # 500 ms windows
92+
93+
# Disable robust Z (HD-EMG has a wide biological RMS spread; active muscle
94+
# channels would be flagged as Z-score outliers at ±3 SD, which is wrong).
95+
# Disable flatline/ZC checks too — quiet rest-period channels are healthy.
96+
# Only powerline ratio + absolute dead-channel threshold are kept.
97+
dead_rms_uv = 0.5
98+
params = QCParams(
99+
robust_z_bad=999.0, # effectively disabled
100+
robust_z_warn=999.0,
101+
pl_ratio_thresh=0.30,
102+
flat_std_min=0.0, # disabled
103+
zc_min_hz=0.0, # disabled
104+
)
105+
qc = ChannelQC(fs=int(fs), n_channels=n_ch, window_sec=0.5, params=params)
106+
107+
bad_vote = np.zeros(n_ch, dtype=int)
108+
n_evals = 0
109+
for start in range(0, n_samples - chunk_size, chunk_size):
110+
qc.update(data[start : start + chunk_size])
111+
out = qc.evaluate()
112+
bad_vote += out["bad"].astype(int)
113+
n_evals += 1
114+
115+
bad_channels = sorted(i for i in range(n_ch) if bad_vote[i] > n_evals * 0.5)
116+
117+
# Also catch truly dead electrodes via absolute RMS threshold
118+
dead = [i for i in range(n_ch) if out["metrics"]["rms"][i] < dead_rms_uv]
119+
bad_channels = sorted(set(bad_channels) | set(dead))
120+
good_channels = sorted(set(range(n_ch)) - set(bad_channels))
121+
m = out["metrics"]
122+
123+
print("\n=== Batch / offline pattern ===")
124+
print(f" {len(good_channels)} good, {len(bad_channels)} bad channels")
125+
print(f" Median RMS: {m['median_rms']:.1f} µV")
126+
print(f" Bad channels: {bad_channels}")
127+
for ch in bad_channels:
128+
print(f" ch {ch:3d} rms={m['rms'][ch]:.1f} z={m['robust_z'][ch]:.2f} pl={m['pl_ratio'][ch]:.3f}")
129+
return good_channels
130+
131+
132+
# ---------------------------------------------------------------------------
133+
# Main
134+
# ---------------------------------------------------------------------------
62135

63136
if __name__ == "__main__":
64-
main()
137+
data, fs = make_test_recording()
138+
print(f"Data shape: {data.shape} (samples × channels), fs={fs:.0f} Hz")
139+
print("Planted bad channels: 1 (railed), 2 (60 Hz), 3 (dead)")
140+
141+
demo_realtime(data, fs)
142+
good = demo_batch(data, fs)
143+
print(f"\nChannels safe for downstream use: {good}")
144+
Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,48 @@
11
# EMG Gesture Classifier Configuration (Open Ephys)
22
# =================================================
3+
# Copy this file to .gesture_config and edit as needed.
4+
# All three pipeline scripts read settings from .gesture_config
5+
# (CLI arguments override these values).
36

4-
# Root directory containing Open Ephys recordings (parent of data subfolders)
5-
root_dir=./data
7+
# ── Shared ───────────────────────────────────────────────────────────────────
68

7-
# Model label (used for naming output files, e.g., 'model/128ch_model.pkl')
8-
label=demo_model
9+
# Path to the EMG recording (CSV, .npz, or Open Ephys folder / .oebin).
10+
# Defaults to the included example data when not set.
11+
data_path=./data/gestures
912

10-
# Enable verbose logging
11-
verbose=true
13+
# Path to the labels/events file (CSV with Sample Index,Label columns).
14+
# Auto-discovered alongside data_path when not set.
15+
labels_path=./data/labels.csv
1216

13-
# ============================================================================
14-
# Dataset Building (1_build_dataset.py)
15-
# ============================================================================
17+
# Directory where the trained model will be saved and loaded from.
18+
# This is the *root* data directory — model files go in a model/ sub-folder.
19+
root_dir=./data/gesture_model
1620

17-
# Output path for the training dataset (.npz)
18-
# If not specified, defaults to {root_dir}/training_dataset.npz
19-
save_path=./data/training_dataset.npz
21+
# Model label / name tag (used for output file naming).
22+
label=gesture_model
2023

21-
# Feature extraction settings
22-
window_ms=200
23-
step_ms=50
24+
# Enable verbose logging.
25+
verbose=false
2426

25-
# ============================================================================
26-
# Training (2_train_model.py)
27-
# ============================================================================
27+
# ── Dataset Building (1_build_dataset.py) ────────────────────────────────────
2828

29-
# Number of training epochs
30-
epochs=100
29+
# Output path for the windowed feature dataset (.npz).
30+
dataset_path=./data/training_dataset.npz
3131

32-
# Batch size
33-
batch_size=32
32+
# Window and step lengths in milliseconds.
33+
window_ms=200
34+
step_ms=50
3435

35-
# Learning rate
36-
learning_rate=0.001
36+
# ── Training (2_train_model.py) ──────────────────────────────────────────────
3737

38-
# Use K-Fold Cross Validation? (true/false)
38+
# Use K-Fold Cross Validation (true/false).
3939
kfold=true
4040

41-
# ============================================================================
42-
# Prediction (3_predict_realtime.py)
43-
# ============================================================================
41+
# ── Real-Time Prediction (3_predict_realtime.py) ─────────────────────────────
4442

45-
# ZMQ connection
46-
zmq_ip=tcp://127.0.0.1
47-
zmq_port=5556
43+
# Open Ephys machine IP and ZMQ data port.
44+
host=127.0.0.1
45+
port=5556
4846

49-
# Inference Smoothing
47+
# Majority-vote smoothing window (number of consecutive predictions).
5048
smooth_k=5

0 commit comments

Comments
 (0)