|
| 1 | +import numpy as np |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import os |
| 4 | + |
| 5 | +def generate_pipeline_figure(save_path): |
| 6 | + """ |
| 7 | + Generates a professional 4-panel figure for the JOSS paper: |
| 8 | + 1. Raw signal (Waterfall) |
| 9 | + 2. Filtered signal |
| 10 | + 3. QC Status |
| 11 | + 4. RMS Barplot |
| 12 | + """ |
| 13 | + fs = 1000 |
| 14 | + t = np.arange(fs) / fs |
| 15 | + n_channels = 3 |
| 16 | + |
| 17 | + # Generate Synthetic Data |
| 18 | + # Channel 0: High signal, Channel 1: Noisy (failed QC), Channel 2: Normal |
| 19 | + np.random.seed(42) |
| 20 | + raw = np.zeros((n_channels, len(t))) |
| 21 | + raw[0] = 5 * np.sin(2 * np.pi * 10 * t) + np.random.randn(len(t)) * 0.5 |
| 22 | + raw[1] = 2 * np.sin(2 * np.pi * 10 * t) + np.random.randn(len(t)) * 10.0 # Noisy |
| 23 | + raw[2] = 3 * np.sin(2 * np.pi * 10 * t) + np.random.randn(len(t)) * 0.5 |
| 24 | + |
| 25 | + # Filtered (simulated bandpass) |
| 26 | + filtered = np.zeros_like(raw) |
| 27 | + filtered[0] = 5 * np.sin(2 * np.pi * 10 * t) + np.random.randn(len(t)) * 0.1 |
| 28 | + filtered[1] = 2 * np.sin(2 * np.pi * 10 * t) + np.random.randn(len(t)) * 0.1 |
| 29 | + filtered[2] = 3 * np.sin(2 * np.pi * 10 * t) + np.random.randn(len(t)) * 0.1 |
| 30 | + |
| 31 | + # QC Status |
| 32 | + qc_pass = [True, False, True] # Channel 1 fails |
| 33 | + |
| 34 | + # RMS |
| 35 | + rms = np.sqrt(np.mean(filtered**2, axis=1)) |
| 36 | + |
| 37 | + # Create Figure |
| 38 | + fig, axes = plt.subplots(4, 1, figsize=(10, 12), gridspec_kw={'height_ratios': [2, 2, 1, 2]}) |
| 39 | + plt.subplots_adjust(hspace=0.4) |
| 40 | + |
| 41 | + # 1. Raw Waterfall |
| 42 | + for i in range(n_channels): |
| 43 | + axes[0].plot(t, raw[i] + i * 25, color='black', alpha=0.7) |
| 44 | + axes[0].set_title("A) Raw High-Density Signals", loc='left', fontsize=14, fontweight='bold') |
| 45 | + axes[0].set_ylabel("Amplitude (Offset)") |
| 46 | + axes[0].set_yticks([]) |
| 47 | + axes[0].grid(True, alpha=0.3) |
| 48 | + |
| 49 | + # 2. Filtered |
| 50 | + for i in range(n_channels): |
| 51 | + axes[1].plot(t, filtered[i] + i * 10, label=f"Ch {i}") |
| 52 | + axes[1].set_title("B) Filtered Waveforms (Bandpass + Notch)", loc='left', fontsize=14, fontweight='bold') |
| 53 | + axes[1].set_ylabel("Amplitude (Offset)") |
| 54 | + axes[1].set_yticks([]) |
| 55 | + axes[1].grid(True, alpha=0.3) |
| 56 | + |
| 57 | + # 3. QC Status |
| 58 | + colors = ['green', 'red', 'green'] |
| 59 | + labels = ['Pass', 'Fail', 'Pass'] |
| 60 | + for i in range(n_channels): |
| 61 | + axes[2].barh(i, 1, color=colors[i], alpha=0.6, height=0.6) |
| 62 | + axes[2].text(0.5, i, labels[i], ha='center', va='center', color='white', fontweight='bold') |
| 63 | + axes[2].set_title("C) Automated Channel Quality Assessment (QC)", loc='left', fontsize=14, fontweight='bold') |
| 64 | + axes[2].set_yticks(range(n_channels)) |
| 65 | + axes[2].set_yticklabels([f"Ch {i}" for i in range(n_channels)]) |
| 66 | + axes[2].set_xticks([]) |
| 67 | + axes[2].set_xlim(0, 1) |
| 68 | + |
| 69 | + # 4. RMS Barplot |
| 70 | + axes[3].bar(range(n_channels), rms, color=['#3498db', '#e74c3c', '#2ecc71'], alpha=0.8) |
| 71 | + axes[3].set_title("D) Extracted RMS Features", loc='left', fontsize=14, fontweight='bold') |
| 72 | + axes[3].set_xticks(range(n_channels)) |
| 73 | + axes[3].set_xticklabels([f"Ch {i}" for i in range(n_channels)]) |
| 74 | + axes[3].set_ylabel("RMS Amplitude") |
| 75 | + axes[3].grid(axis='y', alpha=0.3) |
| 76 | + |
| 77 | + # Common X Axis |
| 78 | + axes[0].set_xlabel("Time (s)") |
| 79 | + axes[1].set_xlabel("Time (s)") |
| 80 | + axes[3].set_xlabel("Channel Index") |
| 81 | + |
| 82 | + plt.tight_layout() |
| 83 | + plt.savefig(save_path, dpi=300) |
| 84 | + print(f"Figure saved to {save_path}") |
| 85 | + |
| 86 | +if __name__ == "__main__": |
| 87 | + output_dir = "docs/figs" |
| 88 | + if not os.path.exists(output_dir): |
| 89 | + os.makedirs(output_dir) |
| 90 | + generate_pipeline_figure(os.path.join(output_dir, "pipeline.png")) |
0 commit comments