|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Build EMG Gesture Classification Dataset from Open Ephys Data. |
| 4 | +
|
| 5 | +Walks a directory structure, finds Open Ephys recordings (structure.oebin), |
| 6 | +extracts per-window RMS features from the EMG signal, labels each recording by its folder name, and saves a consolidated .npz dataset. |
| 7 | +""" |
| 8 | + |
| 9 | +import os |
| 10 | +import argparse |
| 11 | +import numpy as np |
| 12 | +import json |
| 13 | +import scipy.signal |
| 14 | + |
| 15 | +from pyoephys.io import load_open_ephys_session |
| 16 | +from pyoephys.processing import normalize_emg, butter_bandpass_filter |
| 17 | + |
def extract_windows(signal, fs, window_ms, step_ms):
    """Compute per-channel RMS features over sliding windows.

    Args:
        signal: 2-D array of shape (n_samples, n_channels).
        fs: Sampling rate in Hz.
        window_ms: Window length in milliseconds.
        step_ms: Hop between consecutive window starts in milliseconds.

    Returns:
        A tuple ``(features, window_centers)``. ``features`` is a float32
        array of shape (n_windows, n_channels) holding the RMS of each
        window; ``window_centers`` is a float array of shape (n_windows,)
        with the centre of each window in seconds from the start of
        ``signal``. Both are empty when ``signal`` is shorter than one
        window.

    Raises:
        ValueError: If the window or the step is shorter than one sample
            at the given sampling rate.
    """
    window_samples = int(fs * window_ms / 1000)
    step_samples = int(fs * step_ms / 1000)

    # A zero-sample step would divide by zero below, and a zero-sample
    # window would produce NaN features; reject both explicitly.
    if window_samples < 1 or step_samples < 1:
        raise ValueError(
            f"window_ms={window_ms} and step_ms={step_ms} must each span "
            f"at least one sample at fs={fs}"
        )

    n_samples, n_channels = signal.shape
    n_windows = (n_samples - window_samples) // step_samples + 1

    if n_windows <= 0:
        # Same dtype as the non-empty result so callers can concatenate.
        return np.empty((0, n_channels), dtype=np.float32), np.empty((0,))

    features = np.zeros((n_windows, n_channels), dtype=np.float32)
    starts = np.arange(n_windows) * step_samples

    for i, start in enumerate(starts):
        # Promote to float64 before squaring: raw integer samples
        # (e.g. int16) would otherwise overflow silently in chunk ** 2.
        chunk = np.asarray(signal[start:start + window_samples, :], dtype=np.float64)
        features[i, :] = np.sqrt(np.mean(chunk ** 2, axis=0))

    # Centre of each window, (start + end) / 2, converted to seconds.
    window_centers = (2 * starts + window_samples) / 2 / fs

    return features, window_centers
| 47 | + |
def load_labels(events_path, window_centers, duration_sec):
    """Return one integer class label per window (placeholder).

    Every window currently receives class ``0``. ``events_path`` and
    ``duration_sec`` are accepted but not used yet.

    Args:
        events_path: Location of an events file (currently ignored).
        window_centers: Sequence with one entry per window; only its
            length is used.
        duration_sec: Recording duration in seconds (currently ignored).

    Returns:
        Integer array of zeros with ``len(window_centers)`` elements.
    """
    # TODO: Implement proper event loading from Open Ephys events or CSV
    n_windows = len(window_centers)
    placeholder = np.zeros(n_windows, dtype=int)
    return placeholder
| 56 | + |
| 57 | + |
def build_dataset(root_dir, save_path, window_ms=200, step_ms=50):
    """Build an RMS feature dataset from every recording under ``root_dir``.

    Walks ``root_dir``; each directory containing ``structure.oebin`` is
    loaded, band-pass filtered, cut into sliding windows by
    ``extract_windows`` and labelled with the lower-cased name of that
    directory. All windows are concatenated and written with ``np.savez``
    under the keys ``X``, ``y``, ``fs``, ``window_ms`` and ``step_ms``.

    A recording that raises while loading or processing is reported and
    skipped so that one bad recording does not abort the whole build.

    Args:
        root_dir: Directory tree to scan.
        save_path: Destination passed to ``np.savez``.
        window_ms: Window length in milliseconds.
        step_ms: Hop between windows in milliseconds.

    Returns:
        None. Nothing is written when no recording yields any window.
    """
    print(f"Scanning {root_dir} for Open Ephys recordings...")

    X_list = []
    y_list = []
    # Sample rate of the first recording that actually contributed windows.
    dataset_fs = None

    for root, _dirs, files in os.walk(root_dir):
        if "structure.oebin" not in files:
            continue

        print(f"Processing: {root}")
        try:
            session = load_open_ephys_session(root)
            # Assumes the first continuous stream of the first recording
            # is the EMG signal; robust code would select by name/metadata.
            rec = session.recordnodes[0].recordings[0]
            data = rec.continuous[0].samples
            fs = rec.continuous[0].metadata['sample_rate']

            # Band-pass from 20 Hz to just below Nyquist: removes DC drift
            # and low-frequency artefacts. The transposes presumably adapt
            # (samples, channels) data to a channels-first filter -- TODO
            # confirm against pyoephys.
            data_filt = butter_bandpass_filter(data.T, 20, fs/2 - 1, fs, order=4).T

            feats, _centers = extract_windows(data_filt, fs, window_ms, step_ms)
        except Exception as e:
            # Best-effort batch job: report the failure and move on.
            print(f" Failed: {e}")
            continue

        if len(feats) == 0:
            continue

        # Label by folder name (e.g. 'fist', 'rest'). normpath drops a
        # trailing separator, which would otherwise make basename return ''.
        label = os.path.basename(os.path.normpath(root)).lower()
        labels = np.full(len(feats), label)

        if dataset_fs is None:
            dataset_fs = fs
        elif fs != dataset_fs:
            print(
                f" Warning: sample rate {fs} differs from {dataset_fs} "
                "used by earlier recordings"
            )

        X_list.append(feats)
        y_list.append(labels)
        print(f" -> Added {len(feats)} windows, Label: {label}")

    if not X_list:
        print("No data found!")
        return

    X = np.concatenate(X_list, axis=0)
    y = np.concatenate(y_list, axis=0)

    # Labels stay as strings; converting them to integers is left to the
    # consumer of the dataset.
    print(f"Saving {X.shape[0]} samples to {save_path}")
    np.savez(save_path, X=X, y=y, fs=dataset_fs, window_ms=window_ms, step_ms=step_ms)
| 114 | + |
| 115 | + |
if __name__ == "__main__":
    # Command-line entry point: parse the options and build the dataset.
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--root_dir", {"required": True}),
        ("--save_path", {"default": "training_dataset.npz"}),
        ("--window_ms", {"type": int, "default": 200}),
        ("--step_ms", {"type": int, "default": 50}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    cli = parser.parse_args()

    build_dataset(
        cli.root_dir,
        cli.save_path,
        window_ms=cli.window_ms,
        step_ms=cli.step_ms,
    )
0 commit comments