Skip to content

Commit 327b040

Browse files
committed
docs: standardize y-scales and units in processing pipeline figure
1 parent 15a7389 commit 327b040

3 files changed

Lines changed: 71 additions & 34 deletions

File tree

scripts/find_qc_candidates.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import numpy as np
2+
import os
3+
import sys
4+
5+
# Ensure local src is in path
6+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
7+
8+
from pyoephys.processing import ChannelQC
9+
10+
def find_candidate_channels(data_path):
11+
print(f"Analyzing {data_path} for QC candidates...")
12+
data = np.load(data_path, allow_pickle=True)
13+
raw_full = data['amplifier_data']
14+
fs = 2000
15+
16+
# Analyze a representative 1-second segment
17+
start_idx = int(fs * 2.5)
18+
end_idx = start_idx + int(fs * 1.0)
19+
seg = raw_full[:, start_idx:end_idx]
20+
21+
qc = ChannelQC(fs=fs, n_channels=raw_full.shape[0])
22+
qc.update(seg.T)
23+
results = qc.evaluate()
24+
25+
bad_indices = np.where(results['bad'])[0]
26+
good_indices = np.where(~results['bad'] & ~results['watch'])[0]
27+
28+
print(f"Found {len(bad_indices)} bad channels: {bad_indices[:10]}...")
29+
print(f"Found {len(good_indices)} good channels: {good_indices[:10]}...")
30+
31+
if len(bad_indices) > 0 and len(good_indices) >= 2:
32+
print(f"RECOMMENDED: Good=[{good_indices[0]}, {good_indices[1]}], Bad=[{bad_indices[0]}]")
33+
else:
34+
# If no bad ones, find the "worst" good one
35+
metrics = results['metrics']
36+
worst_idx = np.argmax(metrics['robust_z'])
37+
print(f"No naturally 'bad' channels found. Worst Z-score is channel {worst_idx}")
38+
39+
if __name__ == "__main__":
40+
DATA_PATH = r"G:\Shared drives\NML_shared\DataShare\HDEMG Human Healthy\HD-EMG_Cuff\Jonathan\2025_07_31\raw\gestures\gestures_emg_data.npz"
41+
find_candidate_channels(DATA_PATH)

scripts/generate_pipeline_figure.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,18 @@ def generate_pipeline_figure(data_path, save_path):
2424
t_full = data['t_amplifier']
2525
# fs = float(data['sample_rate'])
2626
fs = 2000 # Typical for this dataset, or extract if scalar
27+
# Cherry-picked indices: 51 (pass), 0 (fail), 113 (pass)
28+
viz_indices = [51, 0, 113]
29+
2730
if raw_full.shape[1] > 10000:
2831
# Take a 1-second segment for visualization
2932
start_idx = int(fs * 2.5) # Pick a middle segment
3033
end_idx = start_idx + int(fs * 1.0)
31-
raw = raw_full[:3, start_idx:end_idx]
34+
raw = raw_full[viz_indices, start_idx:end_idx]
3235
t = t_full[start_idx:end_idx]
3336
t = t - t[0] # Zero relative time
3437
else:
35-
raw = raw_full[:3, :]
38+
raw = raw_full[viz_indices, :]
3639
t = t_full - t_full[0]
3740

3841
# 1. Processing Pipeline
@@ -43,7 +46,7 @@ def generate_pipeline_figure(data_path, save_path):
4346
filtered_full_seg = notch_filter(car_data, fs=fs, f0=60)
4447
filtered_full_seg = bandpass_filter(filtered_full_seg, lowcut=20, highcut=500, fs=fs)
4548

46-
processed_viz = filtered_full_seg[:3, :]
49+
processed_viz = filtered_full_seg[viz_indices, :]
4750

4851
# 2. QC Status (Run on full segment or whole array)
4952
qc = ChannelQC(fs=fs, n_channels=raw_full.shape[0])
@@ -53,61 +56,67 @@ def generate_pipeline_figure(data_path, save_path):
5356
qc.update(raw_full[:, start_idx:end_idx].T) # Transpose to (samples, channels)
5457
qc_results = qc.evaluate()
5558

56-
# Force one "Fail" for demonstration if none are naturally failing in the first 3
57-
qc_status = [not qc_results['bad'][i] for i in range(3)]
58-
# For the figure, we want to show a failure. If they all pass, let's mock one
59-
if all(qc_status):
60-
qc_status[1] = False
61-
# Add some noise to the filtered Ch 1 for visual consistency
62-
processed_viz[1] += np.random.randn(processed_viz.shape[1]) * 20.0
63-
59+
# Get actual results for the cherry-picked indices
60+
qc_status = [not qc_results['bad'][i] for i in viz_indices]
61+
6462
# 3. RMS Calculation
6563
rms = calculate_rms(processed_viz, window_size=int(0.1 * fs)) # 100ms windows
6664
rms_avg = np.mean(rms, axis=1)
6765

6866
# --- Plotting ---
69-
fig, axes = plt.subplots(4, 1, figsize=(10, 12), gridspec_kw={'height_ratios': [2, 2, 1, 2]})
70-
plt.subplots_adjust(hspace=0.4)
67+
fig, axes = plt.subplots(4, 1, figsize=(10, 13), gridspec_kw={'height_ratios': [2, 2, 1, 2]})
68+
plt.subplots_adjust(hspace=0.45)
7169

7270
# colors for panels
7371
colors_main = ['#3498db', '#e67e22', '#2ecc71']
7472

73+
# Scale parameters for visual consistency
74+
raw_offset = 500 # Offset between channels in raw plot
75+
raw_ylim = (-200, 1200) # Suitable range for 3 channels @ 500 offset
76+
77+
filt_offset = 200 # Offset between channels in filtered plot
78+
filt_ylim = (-150, 550) # Suitable range for 3 channels @ 200 offset
79+
7580
# A. Raw Waterfall
7681
for i in range(3):
77-
# Subtract mean and add offset
82+
# Center signal around zero before adding offset
7883
sig = raw[i] - np.mean(raw[i])
79-
axes[0].plot(t, sig + i * 150, color='black', alpha=0.7, linewidth=0.8)
84+
axes[0].plot(t, sig + i * raw_offset, color='black', alpha=0.7, linewidth=0.8)
8085
axes[0].set_title("A) Raw High-Density EMG Signals", loc='left', fontsize=14, fontweight='bold')
8186
axes[0].set_ylabel("Amplitude ($\mu$V)")
82-
axes[0].set_yticks([])
87+
axes[0].set_ylim(raw_ylim)
88+
axes[0].set_yticks([0, raw_offset, 2*raw_offset])
89+
axes[0].set_yticklabels([f"Ch {viz_indices[0]}", f"Ch {viz_indices[1]}", f"Ch {viz_indices[2]}"])
8390
axes[0].grid(True, alpha=0.2)
8491

8592
# B. Filtered Signal
8693
for i in range(3):
8794
sig = processed_viz[i] - np.mean(processed_viz[i])
88-
axes[1].plot(t, sig + i * 100, color=colors_main[i], linewidth=1.0)
95+
axes[1].plot(t, sig + i * filt_offset, color=colors_main[i], linewidth=1.0)
8996
axes[1].set_title("B) Preprocessed Waveforms (Bandpass, Notch, CAR)", loc='left', fontsize=14, fontweight='bold')
9097
axes[1].set_ylabel("Amplitude ($\mu$V)")
91-
axes[1].set_yticks([])
98+
axes[1].set_ylim(filt_ylim)
99+
axes[1].set_yticks([0, filt_offset, 2*filt_offset])
100+
axes[1].set_yticklabels([f"Ch {viz_indices[0]}", f"Ch {viz_indices[1]}", f"Ch {viz_indices[2]}"])
92101
axes[1].grid(True, alpha=0.2)
93102

94103
# C. QC Status
95104
qc_colors = ['green' if s else 'red' for s in qc_status]
96105
qc_labels = ['Pass' if s else 'Fail' for s in qc_status]
97106
for i in range(3):
98107
axes[2].barh(i, 1, color=qc_colors[i], alpha=0.6, height=0.6)
99-
axes[2].text(0.5, i, qc_labels[i], ha='center', va='center', color='white', fontweight='bold', fontsize=12)
108+
axes[2].text(0.5, i, f"{qc_labels[i]} (Ch {viz_indices[i]})", ha='center', va='center', color='white', fontweight='bold', fontsize=12)
100109
axes[2].set_title("C) Automated Channel Quality Monitoring", loc='left', fontsize=14, fontweight='bold')
101110
axes[2].set_yticks(range(3))
102-
axes[2].set_yticklabels([f"Channel {i}" for i in range(3)])
111+
axes[2].set_yticklabels([f"Ch {viz_indices[i]}" for i in range(3)])
103112
axes[2].set_xticks([])
104113
axes[2].set_xlim(0, 1)
105114

106115
# D. RMS Barplot
107116
axes[3].bar(range(3), rms_avg, color=qc_colors, alpha=0.8)
108117
axes[3].set_title("D) Extracted RMS Activation Features", loc='left', fontsize=14, fontweight='bold')
109118
axes[3].set_xticks(range(3))
110-
axes[3].set_xticklabels([f"Ch {i}" for i in range(3)])
119+
axes[3].set_xticklabels([f"Ch {viz_indices[i]}" for i in range(3)])
111120
axes[3].set_ylabel("RMS Amplitude ($\mu$V)")
112121
axes[3].grid(axis='y', alpha=0.3)
113122

scripts/inspect_data.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments
 (0)