Skip to content

Commit f98a445

Browse files
committed
docs: finalized JOSS figure with 10-20s window and 5 channels
1 parent 327b040 commit f98a445

4 files changed

Lines changed: 170 additions & 95 deletions

File tree

paper.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ The field of neural data analysis is supported by several specialized tools. The
4141

4242
`python-oephys` is designed with a modular architecture that separates data acquisition, processing, and visualization (see Figure 1).
4343

44-
![EMG Processing Pipeline. A) Raw signals from the first three channels. B) Signals after bandpass (20-400Hz) and 60Hz notch filtering. C) Automated channel quality indicators showing a failed channel in red. D) Mean RMS features extracted for each channel.](docs/figs/pipeline.png)
44+
![EMG Processing Pipeline. A) Raw signals from five representative channels (10–20 s). B) Signals after common average referencing (CAR), bandpass (20–500 Hz), and 60 Hz notch filtering. C) Automated channel quality indicators evaluated on the first 5 s of data. D) Mean RMS features extracted from the processed segment.](docs/figs/pipeline.png)
4545

4646
- **Interface Layer**: Implements ZMQ and LSL clients for low-latency data streaming. The `ZMQClient` is designed to run asynchronously, ensuring that data acquisition does not block processing or UI updates.
4747
- **Processing Layer**: Provides a suite of filters and feature extraction tools. This includes the `EMGPreprocessor` for standardized filtering and `ChannelQC` for real-time signal quality monitoring.

scripts/analyze_channels.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
import os
3+
import sys
4+
5+
# Ensure local src is in path
6+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
7+
8+
from pyoephys.processing import ChannelQC
9+
10+
def analyze_all_channels(data_path, *, fs=2000, qc_seconds=5.0):
    """Print every channel index, grouped by QC verdict and ranked by robust z.

    The first ``qc_seconds`` of the recording are fed to ``ChannelQC`` and the
    resulting ``'bad'`` mask and ``'robust_z'`` metric are used to print:

    * the indices of all channels not flagged bad,
    * those indices sorted by ascending robust z, and
    * the indices flagged bad, sorted by ascending robust z.

    Args:
        data_path: Path to an ``.npz`` archive containing an
            ``'amplifier_data'`` array indexed as ``[channel, sample]``.
        fs: Sampling rate in Hz used to size the QC window and configure
            ``ChannelQC``. Defaults to 2000, the value previously hard-coded.
        qc_seconds: Length in seconds of the leading window passed to QC.
            Defaults to 5.0, the value previously hard-coded.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only point this at trusted recordings.
    # The context manager closes the underlying archive once the array has
    # been read into memory.
    with np.load(data_path, allow_pickle=True) as data:
        raw_full = data['amplifier_data']

    qc_end = int(fs * qc_seconds)
    seg = raw_full[:, :qc_end]

    qc = ChannelQC(fs=fs, n_channels=raw_full.shape[0])
    # Transposed before the update; presumably ChannelQC expects
    # (samples, channels) -- confirm against the library.
    qc.update(seg.T)
    results = qc.evaluate()

    robust_z = np.asarray(results['metrics']['robust_z'])
    # Force a boolean mask so that `~` is a logical (not bitwise-integer) NOT.
    is_bad = np.asarray(results['bad'], dtype=bool)

    print(f"Channel Analysis (0-{qc_seconds:g}s):")
    not_bad = np.where(~is_bad)[0]
    print(f"Not Bad: {not_bad}")

    if len(not_bad) > 0:
        sorted_good = not_bad[np.argsort(robust_z[not_bad])]
        print(f"Sorted Not-Bad (lowest Z first): {sorted_good}")

    bad = np.where(is_bad)[0]
    if len(bad) > 0:
        sorted_bad = bad[np.argsort(robust_z[bad])]
        print(f"Sorted Bad (lowest Z first): {sorted_bad}")
43+
44+
if __name__ == "__main__":
    # The default points at the lab's shared-drive recording; pass a path as
    # the first command-line argument to analyze a different file.
    DATA_PATH = sys.argv[1] if len(sys.argv) > 1 else r"G:\Shared drives\NML_shared\DataShare\HDEMG Human Healthy\HD-EMG_Cuff\Jonathan\2025_07_31\raw\gestures\gestures_emg_data.npz"
    analyze_all_channels(DATA_PATH)

scripts/find_qc_candidates_v2.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
import os
3+
import sys
4+
5+
# Ensure local src is in path
6+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
7+
8+
from pyoephys.processing import ChannelQC
9+
10+
def find_candidate_channels(data_path, *, fs=2000, qc_seconds=5.0):
    """Suggest five channels (four passing, one failing) for the QC figure.

    Runs ``ChannelQC`` on the first ``qc_seconds`` of the recording. A channel
    counts as "good" only if it is flagged neither ``'bad'`` nor ``'watch'``.
    When at least one bad and four good channels exist, the first four good
    indices and the first bad index are printed in the order
    ``[pass, pass, fail, pass, pass]``; otherwise a notice is printed.

    Args:
        data_path: Path to an ``.npz`` archive containing an
            ``'amplifier_data'`` array indexed as ``[channel, sample]``.
        fs: Sampling rate in Hz. Defaults to 2000, the value previously
            hard-coded.
        qc_seconds: Length in seconds of the leading QC window. Defaults to
            5.0, the value previously hard-coded.
    """
    print(f"Analyzing {data_path} for 5-channel QC candidates (0-{qc_seconds:g}s window)...")

    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only point this at trusted recordings.
    # The context manager closes the underlying archive once the array has
    # been read into memory.
    with np.load(data_path, allow_pickle=True) as data:
        raw_full = data['amplifier_data']

    qc_end = int(fs * qc_seconds)
    seg = raw_full[:, :qc_end]

    qc = ChannelQC(fs=fs, n_channels=raw_full.shape[0])
    # Transposed before the update; presumably ChannelQC expects
    # (samples, channels) -- confirm against the library.
    qc.update(seg.T)
    results = qc.evaluate()

    # Force boolean masks so that `~` and `&` are logical operations.
    is_bad = np.asarray(results['bad'], dtype=bool)
    is_watch = np.asarray(results['watch'], dtype=bool)

    bad_indices = np.where(is_bad)[0]
    good_indices = np.where(~is_bad & ~is_watch)[0]

    print(f"Found {len(bad_indices)} bad channels.")
    print(f"Found {len(good_indices)} good channels.")

    if len(bad_indices) >= 1 and len(good_indices) >= 4:
        # Proposed set: [Pass, Pass, Fail, Pass, Pass]
        picks = (good_indices[0], good_indices[1], bad_indices[0],
                 good_indices[2], good_indices[3])
        # Cast to plain ints: a list of NumPy scalars prints as
        # "[np.int64(3), ...]" on NumPy >= 2, which cannot be pasted back
        # into a script as a channel list.
        final_set = [int(idx) for idx in picks]
        print(f"RECOMMENDED SET (4 pass, 1 fail): {final_set}")
    else:
        print("Could not find enough candidates with strict criteria.")
37+
38+
if __name__ == "__main__":
    # The default points at the lab's shared-drive recording; pass a path as
    # the first command-line argument to analyze a different file.
    DATA_PATH = sys.argv[1] if len(sys.argv) > 1 else r"G:\Shared drives\NML_shared\DataShare\HDEMG Human Healthy\HD-EMG_Cuff\Jonathan\2025_07_31\raw\gestures\gestures_emg_data.npz"
    find_candidate_channels(DATA_PATH)

scripts/generate_pipeline_figure.py

Lines changed: 83 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -6,134 +6,123 @@
66
# Ensure local src is in path
77
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
88

9-
from pyoephys.processing import bandpass_filter, notch_filter, calculate_rms, ChannelQC, common_average_reference
9+
from pyoephys.processing import bandpass_filter, notch_filter, calculate_rms, ChannelQC, common_average_reference, QCParams
1010

1111
def generate_pipeline_figure(data_path, save_path):
    """Render the four-panel JOSS pipeline figure and save it to disk.

    Layout (3x2 grid; panels A and B span both columns):

    - A) Raw signals of five hand-picked channels, 10-20s window.
    - B) The same channels after CAR, 60 Hz notch and 20-500 Hz bandpass.
    - C) Pass/Fail QC verdict per channel, evaluated on the 0-5s window.
    - D) Mean RMS of the processed 10-20s segment, side by side with C.

    Args:
        data_path: Path to an ``.npz`` archive providing ``'amplifier_data'``
            (indexed ``[channel, sample]``) and ``'t_amplifier'`` (one
            timestamp per sample).
        save_path: Destination image path; written at 300 dpi.
    """
    print(f"Loading data from {data_path}...")
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only point this at trusted recordings.
    # The context manager closes the archive once both arrays are in memory.
    with np.load(data_path, allow_pickle=True) as data:
        raw_full = data['amplifier_data']
        t_full = data['t_amplifier']
    fs = 2000

    # Cherry-picked indices (from 0-5s analysis):
    #   Pass: 114, 113, 121, 51
    #   Fail: 0
    viz_indices = [114, 113, 0, 121, 51]  # Fail in the middle for visual contrast
    n_viz = len(viz_indices)
    channel_labels = [f"Ch {idx}" for idx in viz_indices]

    # 1. QC Analysis (0-5s window)
    print("Running QC Analysis on 0-5s window...")
    qc_start_idx = 0
    qc_end_idx = int(fs * 5.0)
    qc_seg = raw_full[:, qc_start_idx:qc_end_idx]

    # Thresholds chosen by the original author as "slightly relaxed" so the
    # figure's Pass labels come out as intended.
    params = QCParams(robust_z_bad=6.0, robust_z_warn=4.0)
    qc = ChannelQC(fs=fs, n_channels=raw_full.shape[0], params=params)
    # Transposed before the update; presumably ChannelQC expects
    # (samples, channels) -- confirm against the library.
    qc.update(qc_seg.T)
    qc_results = qc.evaluate()
    qc_status = [not qc_results['bad'][idx] for idx in viz_indices]

    # 2. Main Visualization Prep (10-20s window)
    print("Preparing visualization for 10-20s window...")
    viz_start_idx = int(fs * 10.0)
    viz_end_idx = int(fs * 20.0)

    raw_seg = raw_full[viz_indices, viz_start_idx:viz_end_idx]
    t_seg = t_full[viz_start_idx:viz_end_idx]
    # Re-base the timestamps so the x-axis reads 10-20 s.
    t_plot = t_seg - t_seg[0] + 10.0

    # CAR is computed over ALL channels of the window; the five display
    # channels are selected only after filtering.
    car_data = common_average_reference(raw_full[:, viz_start_idx:viz_end_idx])
    filt = notch_filter(car_data, fs=fs, f0=60)
    filt = bandpass_filter(filt, lowcut=20, highcut=500, fs=fs)
    processed_viz = filt[viz_indices, :]

    # 3. RMS Calculation (on the viz segment), 100 ms windows
    rms_vals = calculate_rms(processed_viz, window_size=int(0.1 * fs))
    rms_avg = np.mean(rms_vals, axis=1)

    # --- Plotting ---
    fig = plt.figure(figsize=(10, 14))
    gs = fig.add_gridspec(3, 2, height_ratios=[2, 2, 1], hspace=0.4, wspace=0.3)

    ax_raw = fig.add_subplot(gs[0, :])
    ax_filt = fig.add_subplot(gs[1, :])
    ax_qc = fig.add_subplot(gs[2, 0])
    ax_rms = fig.add_subplot(gs[2, 1])

    colors_main = ['#3498db', '#e67e22', '#e74c3c', '#2ecc71', '#9b59b6']
    raw_offset = 600   # vertical spacing between stacked raw traces
    filt_offset = 250  # vertical spacing between stacked filtered traces

    # A. Raw Waterfall
    for row, trace in enumerate(raw_seg):
        # Remove each trace's mean so it sits on its own baseline.
        sig = trace - np.mean(trace)
        ax_raw.plot(t_plot, sig + row * raw_offset, color='black', alpha=0.7, linewidth=0.4)
    ax_raw.set_title("A) Raw High-Density EMG Signals (10–20s Window)", loc='left', fontsize=13, fontweight='bold')
    # Raw string: "\m" is not a valid Python escape sequence.
    ax_raw.set_ylabel(r"Amplitude ($\mu$V)")
    # Upper limit follows the number of stacked channels instead of a fixed 4.
    ax_raw.set_ylim(-400, (n_viz - 1) * raw_offset + 400)
    ax_raw.set_yticks([row * raw_offset for row in range(n_viz)])
    ax_raw.set_yticklabels(channel_labels)
    ax_raw.set_xlabel("Time (s)")
    ax_raw.grid(True, alpha=0.1)

    # B. Filtered Signal
    for row, trace in enumerate(processed_viz):
        sig = trace - np.mean(trace)
        ax_filt.plot(t_plot, sig + row * filt_offset, color=colors_main[row], linewidth=0.4)
    ax_filt.set_title("B) Preprocessed Waveforms (CAR + Bandpass + Notch)", loc='left', fontsize=13, fontweight='bold')
    ax_filt.set_ylabel(r"Amplitude ($\mu$V)")
    ax_filt.set_ylim(-150, (n_viz - 1) * filt_offset + 150)
    ax_filt.set_yticks([row * filt_offset for row in range(n_viz)])
    ax_filt.set_yticklabels(channel_labels)
    ax_filt.set_xlabel("Time (s)")
    ax_filt.grid(True, alpha=0.1)

    # C. QC Status (using the 0-5s results)
    qc_colors = ['green' if ok else 'red' for ok in qc_status]
    qc_labels = ['Pass' if ok else 'Fail' for ok in qc_status]
    for row, (bar_color, verdict) in enumerate(zip(qc_colors, qc_labels)):
        ax_qc.barh(row, 1, color=bar_color, alpha=0.6, height=0.7)
        ax_qc.text(0.5, row, f"{verdict}", ha='center', va='center', color='white', fontweight='bold', fontsize=10)
    ax_qc.set_title("C) Automated Channel QC (0-5s)", loc='left', fontsize=12, fontweight='bold')
    ax_qc.set_yticks(range(n_viz))
    ax_qc.set_yticklabels(channel_labels)
    ax_qc.set_xticks([])
    ax_qc.set_xlim(0, 1)

    # D. RMS Barplot (bars reuse the QC colors so failures stand out)
    ax_rms.bar(range(n_viz), rms_avg, color=qc_colors, alpha=0.8)
    ax_rms.set_title("D) RMS Features (10-20s)", loc='left', fontsize=12, fontweight='bold')
    ax_rms.set_xticks(range(n_viz))
    ax_rms.set_xticklabels(channel_labels, rotation=0)
    ax_rms.set_ylabel(r"RMS ($\mu$V)")
    ax_rms.grid(axis='y', alpha=0.2)
    ax_rms.set_xlabel("Channel Index")

    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Successfully generated refined figure at {save_path}")
131123

132124
if __name__ == "__main__":
    # Defaults reproduce the published figure; either may be overridden:
    #   python generate_pipeline_figure.py [DATA_PATH [SAVE_PATH]]
    DATA_PATH = sys.argv[1] if len(sys.argv) > 1 else r"G:\Shared drives\NML_shared\DataShare\HDEMG Human Healthy\HD-EMG_Cuff\Jonathan\2025_07_31\raw\gestures\gestures_emg_data.npz"
    SAVE_PATH = sys.argv[2] if len(sys.argv) > 2 else "docs/figs/pipeline.png"
    # os.makedirs("") raises, so only create the directory when the save
    # path actually has a directory component.
    save_dir = os.path.dirname(SAVE_PATH)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    generate_pipeline_figure(DATA_PATH, SAVE_PATH)

0 commit comments

Comments
 (0)