Skip to content

Commit 85fb14a

Browse files
committed
Refine Examples: Add 1_build_dataset.py and standardize Gesture Classifier pipeline
1 parent add3360 commit 85fb14a

5 files changed

Lines changed: 175 additions & 1 deletion

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ print(results) # Status (Good/Bad) per channel
8787

8888
## Examples
8989
Check the `examples/` directory for complete scripts:
90-
- `examples/gesture_classifier/2v2_train_model.py`: Train a gesture classifier.
90+
- `examples/gesture_classifier/2_train_model.py`: Train a gesture classifier.
9191
- `examples/synchronization/sync_multimodal_data.py`: Align EMG with 3D hand landmarks.
9292
- `examples/analysis/run_channel_qc.py`: Run quality control checks.
9393
- `examples/interface/zmq_client.py`: Real-time ZMQ client example.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# EMG Gesture Classifier Configuration (Open Ephys)
2+
# =================================================
3+
4+
# Root directory containing Open Ephys recordings (parent of data subfolders)
5+
root_dir=./data
6+
7+
# Model label (used for naming output files, e.g., 'model/128ch_model.pkl')
8+
label=demo_model
9+
10+
# Enable verbose logging
11+
verbose=true
12+
13+
# ============================================================================
14+
# Dataset Building (1_build_dataset.py)
15+
# ============================================================================
16+
17+
# Output path for the training dataset (.npz)
18+
# If not specified, 1_build_dataset.py falls back to its --save_path default: training_dataset.npz in the current working directory
19+
save_path=./data/training_dataset.npz
20+
21+
# Feature extraction settings
22+
window_ms=200
23+
step_ms=50
24+
25+
# ============================================================================
26+
# Training (2_train_model.py)
27+
# ============================================================================
28+
29+
# Number of training epochs
30+
epochs=100
31+
32+
# Batch size
33+
batch_size=32
34+
35+
# Learning rate
36+
learning_rate=0.001
37+
38+
# Use K-Fold Cross Validation? (true/false)
39+
kfold=true
40+
41+
# ============================================================================
42+
# Prediction (3_predict_realtime.py)
43+
# ============================================================================
44+
45+
# ZMQ connection
46+
zmq_ip=tcp://127.0.0.1
47+
zmq_port=5556
48+
49+
# Inference Smoothing
50+
smooth_k=5
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Build EMG Gesture Classification Dataset from Open Ephys Data.
4+
5+
Walks a directory structure, finds Open Ephys recordings (structure.oebin),
6+
extracts per-window RMS features, labels each window with its folder name, and saves a consolidated .npz dataset.
7+
"""
8+
9+
import os
10+
import argparse
11+
import numpy as np
12+
import json
13+
import scipy.signal
14+
15+
from pyoephys.io import load_open_ephys_session
16+
from pyoephys.processing import normalize_emg, butter_bandpass_filter
17+
18+
def extract_windows(signal, fs, window_ms, step_ms):
    """Compute per-channel RMS features over a sliding window.

    Args:
        signal: 2-D array indexed as (samples, channels).
        fs: Sampling rate in Hz.
        window_ms: Window length in milliseconds.
        step_ms: Hop between consecutive window starts, in milliseconds.

    Returns:
        A ``(features, window_centers)`` tuple. ``features`` is a float32
        array of shape (n_windows, n_channels) holding the RMS of every
        window; ``window_centers`` holds each window's center time in
        seconds. Both are empty when the signal is shorter than one window.

    Raises:
        ValueError: If the window or the step is shorter than one sample.
    """
    window_samples = int(fs * window_ms / 1000)
    step_samples = int(fs * step_ms / 1000)

    # A zero-length step would divide by zero below and a zero-length window
    # would produce NaN features, so reject both with a clear message.
    if window_samples < 1 or step_samples < 1:
        raise ValueError(
            f"window_ms={window_ms} and step_ms={step_ms} must each span at "
            f"least one sample at fs={fs}"
        )

    n_samples, n_channels = signal.shape
    n_windows = (n_samples - window_samples) // step_samples + 1

    if n_windows <= 0:
        # Keep the dtype consistent with the non-empty path.
        return np.empty((0, n_channels), dtype=np.float32), np.empty((0,))

    features = np.zeros((n_windows, n_channels), dtype=np.float32)
    window_centers = np.zeros(n_windows)

    for i in range(n_windows):
        start = i * step_samples
        end = start + window_samples

        # Promote to float64 before squaring: raw integer samples (e.g.
        # int16) would otherwise wrap around silently.
        chunk = np.asarray(signal[start:end, :], dtype=np.float64)

        features[i, :] = np.sqrt(np.mean(np.square(chunk), axis=0))
        window_centers[i] = (start + end) / 2 / fs

    return features, window_centers
47+
48+
def load_labels(events_path, window_centers, duration_sec):
    """Return one integer label per window (placeholder implementation).

    ``events_path`` and ``duration_sec`` are accepted but not used yet;
    every window is assigned the dummy class 0.
    """
    # TODO: Implement proper event loading from Open Ephys events or CSV
    n_windows = len(window_centers)
    placeholder_labels = np.zeros(n_windows, dtype=int)
    return placeholder_labels
56+
57+
58+
def build_dataset(root_dir, save_path, window_ms=200, step_ms=50):
    """Build a windowed RMS dataset from every Open Ephys recording under a tree.

    Walks ``root_dir``; each directory containing ``structure.oebin`` is
    loaded, filtered, windowed with ``extract_windows`` and labelled with the
    lower-cased name of that directory. All windows are concatenated and
    written to ``save_path`` with ``np.savez``.

    Args:
        root_dir: Directory tree to scan for recordings.
        save_path: Destination of the ``.npz`` file; missing parent
            directories are created.
        window_ms: Window length in milliseconds.
        step_ms: Hop between consecutive windows in milliseconds.

    Returns:
        None. When no recording yields any window, nothing is written.
    """
    print(f"Scanning {root_dir} for Open Ephys recordings...")

    X_list = []
    y_list = []
    # Sample rate of the first recording that actually contributes windows;
    # this is the value stored alongside the dataset.
    dataset_fs = None

    for root, _dirs, files in os.walk(root_dir):
        if "structure.oebin" not in files:
            continue

        print(f"Processing: {root}")
        try:
            session = load_open_ephys_session(root)
            # Assumes the first continuous stream of the first recording is
            # the EMG stream; robust code would select by name or metadata.
            rec = session.recordnodes[0].recordings[0]
            stream = rec.continuous[0]
            data = stream.samples
            fs = stream.metadata['sample_rate']

            # Band-pass from 20 Hz to just below Nyquist to drop DC drift.
            # The transposes presumably match a (channels, samples) layout
            # expected by the filter helper -- TODO confirm.
            data_filt = butter_bandpass_filter(data.T, 20, fs/2 - 1, fs, order=4).T

            # NOTE: per-file normalization (normalize_emg) is not applied.

            feats, centers = extract_windows(data_filt, fs, window_ms, step_ms)

            if len(feats) > 0:
                # Label from the folder name (e.g. 'fist', 'rest'). abspath
                # strips any trailing separator, which would otherwise make
                # basename() return an empty label.
                folder_name = os.path.basename(os.path.abspath(root)).lower()
                label = folder_name  # string label for now

                labels = np.full(len(feats), label)

                if dataset_fs is None:
                    dataset_fs = fs
                elif fs != dataset_fs:
                    print(f"  Warning: sample rate {fs} differs from dataset rate {dataset_fs}")

                X_list.append(feats)
                y_list.append(labels)
                print(f"  -> Added {len(feats)} windows, Label: {label}")

        except Exception as e:
            # Best effort: one unreadable recording must not abort the scan.
            print(f"  Failed: {e}")

    if not X_list:
        print("No data found!")
        return

    X = np.concatenate(X_list, axis=0)
    y = np.concatenate(y_list, axis=0)

    # Labels are kept as strings; the trainer is expected to encode them.

    out_dir = os.path.dirname(save_path)
    if out_dir:
        # np.savez does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)

    print(f"Saving {X.shape[0]} samples to {save_path}")
    np.savez(save_path, X=X, y=y, fs=dataset_fs, window_ms=window_ms, step_ms=step_ms)
114+
115+
116+
if __name__ == "__main__":
    # Command-line entry point: parse the dataset options and run the build.
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_dir", required=True)
    parser.add_argument("--save_path", default="training_dataset.npz")
    # Both windowing options are integer millisecond values.
    for flag, fallback in (("--window_ms", 200), ("--step_ms", 50)):
        parser.add_argument(flag, type=int, default=fallback)
    cli = parser.parse_args()

    build_dataset(
        cli.root_dir,
        cli.save_path,
        window_ms=cli.window_ms,
        step_ms=cli.step_ms,
    )
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)