Skip to content

Commit 5ebddcf

Browse files
committed
New quantization, some fixes in evaluation scripts
1 parent 4aa5a72 commit 5ebddcf

27 files changed

Lines changed: 696 additions & 543 deletions

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ RawHash performs real-time mapping of nanopore raw signals. When the prefix of r
2020

2121
# Recent changes
2222

23+
* We came up with a better and more accurate quantization mechanism in RawHash2. The new mechanism dynamically adjusts the size of the bucket into which each signal value is quantized, depending on the normalized distribution of the signal values. **This provides significant improvements in both accuracy and performance.**
24+
2325
* We have integrated the signal alignment functionality with DTW as proposed in RawAlign (see the citation below). The parameters may not yet be well optimized, as this feature is still in an experimental stage. Use it with caution.
2426

2527
* Offline overlapping functionality is integrated.

src/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,4 +252,4 @@ hit.o: rmap.h kalloc.h khash.h
252252
rmap.o: rindex.h rsig.h kthread.h rh_kvec.h rutils.h rsketch.h revent.h sequence_until.h dtw.h
253253
revent.o: roptions.h kalloc.h
254254
rindex.o: roptions.h rutils.h rsketch.h rsig.h bseq.h khash.h rh_kvec.h kthread.h
255-
main:o rawhash.h ketopt.h rutils.h
255+
# Header dependencies of main.o ("main:o" declared a target named "main" with a prerequisite "o").
main.o: rawhash.h ketopt.h rutils.h

src/main.cpp

Lines changed: 99 additions & 87 deletions
Large diffs are not rendered by default.

src/revent.c

Lines changed: 179 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <assert.h>
44
#include <float.h>
55
#include <math.h>
6+
#include "rutils.h"
67

78
//Some of the functions here are adopted from the Sigmap implementation (https://github.com/haowenz/sigmap/tree/c9a40483264c9514587a36555b5af48d3f054f6f). We have optimized the Sigmap implementation to work with the hash tables efficiently.
89

@@ -19,8 +20,11 @@ typedef struct ri_detect_s {
1920
int valid_peak;
2021
}ri_detect_t;
2122

22-
static inline void comp_prefix_prefixsq(const float *sig, uint32_t s_len, float* prefix_sum, float* prefix_sum_square) {
23-
23+
static inline void comp_prefix_prefixsq(const float *sig,
24+
const uint32_t s_len,
25+
float* prefix_sum,
26+
float* prefix_sum_square)
27+
{
2428
assert(s_len > 0);
2529

2630
prefix_sum[0] = 0.0f;
@@ -31,22 +35,17 @@ static inline void comp_prefix_prefixsq(const float *sig, uint32_t s_len, float*
3135
}
3236
}
3337

34-
static inline float* comp_tstat(void *km, const float *prefix_sum, const float *prefix_sum_square, uint32_t s_len, uint32_t w_len) {
35-
38+
static inline float* comp_tstat(void *km,
39+
const float *prefix_sum,
40+
const float *prefix_sum_square,
41+
const uint32_t s_len,
42+
const uint32_t w_len)
43+
{
3644
const float eta = FLT_MIN;
37-
38-
// rh_kvec_t(float) tstat = {0,0,0};
39-
// rh_kv_resize(float, 0, tstat, s_len+1);
40-
// rh_kv_pushp(float, 0, tstat, &s);
41-
4245
float* tstat = (float*)ri_kcalloc(km, s_len+1, sizeof(float));
43-
// Quick return:
44-
// t-test not defined for number of points less than 2
45-
// need at least as many points as twice the window length
4646
if (s_len < 2*w_len || w_len < 2) return tstat;
47-
// fudge boundaries
4847
memset(tstat, 0, w_len*sizeof(float));
49-
// get to work on the rest
48+
5049
for (uint32_t i = w_len; i <= s_len - w_len; ++i) {
5150
float sum1 = prefix_sum[i];
5251
float sumsq1 = prefix_sum_square[i];
@@ -58,46 +57,59 @@ static inline float* comp_tstat(void *km, const float *prefix_sum, const float *
5857
float sumsq2 = prefix_sum_square[i + w_len] - prefix_sum_square[i];
5958
float mean1 = sum1 / w_len;
6059
float mean2 = sum2 / w_len;
61-
float combined_var = sumsq1 / w_len - mean1 * mean1 + sumsq2 / w_len - mean2 * mean2;
60+
float combined_var = (sumsq1/w_len - mean1*mean1 + sumsq2/w_len - mean2*mean2)/w_len;
6261
// Prevent problem due to very small variances
6362
combined_var = fmaxf(combined_var, eta);
6463
// t-stat
6564
// Formula is a simplified version of Student's t-statistic for the
6665
// special case where there are two samples of equal size with
6766
// differing variance
6867
const float delta_mean = mean2 - mean1;
69-
tstat[i] = fabs(delta_mean) / sqrt(combined_var / w_len);
68+
tstat[i] = fabs(delta_mean) / sqrt(combined_var);
7069
}
7170
// fudge boundaries
7271
memset(tstat+s_len-w_len+1, 0, (w_len)*sizeof(float));
7372

7473
return tstat;
7574
}
7675

77-
static inline uint32_t gen_peaks(ri_detect_t *short_detector, ri_detect_t *long_detector, const float peak_height, uint32_t* peaks) {
78-
79-
assert(short_detector->s_len == long_detector->s_len);
76+
// static inline float calculate_adaptive_peak_height(const float *prefix_sum, const float *prefix_sum_square, uint32_t current_index, uint32_t window_length, float base_peak_height) {
77+
// // Ensure we don't go beyond signal bounds
78+
// uint32_t start_index = current_index > window_length ? current_index - window_length : 0;
79+
// uint32_t end_index = current_index + window_length;
8080

81-
uint32_t curInd = 0;
81+
// float sum = prefix_sum[end_index] - prefix_sum[start_index];
82+
// float sumsq = prefix_sum_square[end_index] - prefix_sum_square[start_index];
83+
// float mean = sum / (end_index - start_index);
84+
// float variance = (sumsq / (end_index - start_index)) - (mean * mean);
85+
// float stddev = sqrtf(variance);
86+
87+
// // Example adaptive strategy: Increase peak height in high-variance regions
88+
// return base_peak_height * (1 + stddev);
89+
// }
90+
91+
static inline uint32_t gen_peaks(ri_detect_t **detectors,
92+
const uint32_t n_detectors,
93+
const float peak_height,
94+
const float *prefix_sum,
95+
const float *prefix_sum_square,
96+
uint32_t* peaks) {
8297

83-
uint32_t ndetector = 2;
84-
ri_detect_t *detectors[ndetector]; // = {short_detector, long_detector};
85-
detectors[0] = short_detector;
86-
detectors[1] = long_detector;
87-
for (uint32_t i = 0; i < short_detector->s_len; i++) {
88-
for (uint32_t k = 0; k < ndetector; k++) {
98+
uint32_t curInd = 0;
99+
for (uint32_t i = 0; i < detectors[0]->s_len; i++) {
100+
for (uint32_t k = 0; k < n_detectors; k++) {
89101
ri_detect_t *detector = detectors[k];
90-
// Carry on if we've been masked out
91102
if (detector->masked_to >= i) continue;
92103

93104
float current_value = detector->sig[i];
105+
// float adaptive_peak_height = calculate_adaptive_peak_height(prefix_sum, prefix_sum_square, i, detector->window_length, peak_height);
106+
94107
if (detector->peak_pos == detector->DEF_PEAK_POS) {
95-
// CASE 1: We've not yet recorded a maximum
96-
if (current_value < detector->peak_value) {
97-
// Either record a deeper minimum...
108+
// CASE 1: We've not yet recorded any maximum
109+
if (current_value < detector->peak_value) { // A deeper minimum:
98110
detector->peak_value = current_value;
99-
} else if (current_value - detector->peak_value > peak_height) { // TODO(Haowen): this might cause overflow, need to fix this
100-
// ...or we've seen a qualifying maximum
111+
} else if (current_value - detector->peak_value > peak_height) {
112+
// ...or a qualifying maximum:
101113
detector->peak_value = current_value;
102114
detector->peak_pos = i;
103115
// otherwise, wait to rise high enough to be considered a peak
@@ -109,22 +121,22 @@ static inline uint32_t gen_peaks(ri_detect_t *short_detector, ri_detect_t *long_
109121
detector->peak_value = current_value;
110122
detector->peak_pos = i;
111123
}
112-
// Dominate other tstat signals if we're going to fire at some point
113-
if (detector == short_detector) {
114-
if (detector->peak_value > detector->threshold) {
115-
long_detector->masked_to = detector->peak_pos + detector->window_length;
116-
long_detector->peak_pos = long_detector->DEF_PEAK_POS;
117-
long_detector->peak_value = long_detector->DEF_PEAK_VAL;
118-
long_detector->valid_peak = 0;
124+
// Tell other detectors no need to check for a peak until a certain point
125+
if (detector->peak_value > detector->threshold) {
126+
for(int n_d = k+1; n_d < n_detectors; n_d++){
127+
detectors[n_d]->masked_to = detector->peak_pos + detectors[0]->window_length;
128+
detectors[n_d]->peak_pos = detectors[n_d]->DEF_PEAK_POS;
129+
detectors[n_d]->peak_value = detectors[n_d]->DEF_PEAK_VAL;
130+
detectors[n_d]->valid_peak = 0;
119131
}
120132
}
121-
// Have we convinced ourselves we've seen a peak
122-
if (detector->peak_value - current_value > peak_height && detector->peak_value > detector->threshold) {
133+
// There is a good peak
134+
if (detector->peak_value - current_value > peak_height &&
135+
detector->peak_value > detector->threshold) {
123136
detector->valid_peak = 1;
124137
}
125-
// Finally, check the distance if this is a good peak
138+
// Check if we are now further away from the current peak
126139
if (detector->valid_peak && (i - detector->peak_pos) > detector->window_length / 2) {
127-
// Emit the boundary and reset
128140
peaks[curInd++] = detector->peak_pos;
129141
detector->peak_pos = detector->DEF_PEAK_POS;
130142
detector->peak_value = current_value;
@@ -137,78 +149,165 @@ static inline uint32_t gen_peaks(ri_detect_t *short_detector, ri_detect_t *long_
137149
return curInd;
138150
}
139151

152+
/**
 * @brief qsort() comparator that orders two floats ascending.
 *
 * @param a Pointer to the first float.
 * @param b Pointer to the second float.
 * @return int -1 if *a < *b, 1 if *a > *b, and 0 otherwise (equal, or either
 *         value is NaN so that neither comparison holds).
 */
int compare_floats(const void* a, const void* b) {
	const float lhs = *(const float*)a;
	const float rhs = *(const float*)b;
	// Three-way compare from two boolean results; never subtracts the floats,
	// so there is no truncation when converting to int.
	return (lhs > rhs) - (lhs < rhs);
}

/**
 * @brief Mean of a segment after dropping values outside [Q1 - IQR, Q3 + IQR].
 *
 * The segment is sorted IN PLACE with qsort(), so the caller's buffer is
 * reordered. Q1 and Q3 are taken as the elements at positions len/4 and
 * 3*len/4 of the sorted segment, and IQR = Q3 - Q1.
 *
 * @param segment        Values to average; modified (sorted ascending).
 * @param segment_length Number of values in segment.
 * @return float Mean of the values that fall within the bounds, or 0 when the
 *         segment is empty or no value passes the filter.
 */
float calculate_mean_of_filtered_segment(float* segment,
										 const uint32_t segment_length)
{
	// An empty segment has no quartiles; indexing segment[0] would be out of bounds.
	if (!segment || segment_length == 0) return 0.0f;

	qsort(segment, segment_length, sizeof(float), compare_floats);

	// 64-bit intermediate so that 3 * segment_length cannot wrap around.
	const uint32_t q1_idx = segment_length / 4;
	const uint32_t q3_idx = (uint32_t)(((uint64_t)3 * segment_length) / 4);

	const float q1 = segment[q1_idx];
	const float q3 = segment[q3_idx];
	const float iqr = q3 - q1;
	const float lower_bound = q1 - iqr;
	const float upper_bound = q3 + iqr;

	float sum = 0.0f;
	uint32_t count = 0;
	for (uint32_t i = 0; i < segment_length; i++) {
		const float v = segment[i];
		if (v >= lower_bound && v <= upper_bound) {
			sum += v;
			++count;
		}
	}

	// count can be 0 (e.g. every value is NaN); avoid dividing by zero.
	return count > 0 ? sum / count : 0.0f;
}
181+
140182
/**
 * @brief Generates events from a signal and the peak positions found in it.
 *
 * Every peak p with 0 < p < s_len closes a segment that starts at the
 * previously accepted peak (or at 0) and ends just before p. The event value
 * for a segment is calculate_mean_of_filtered_segment() of its samples.
 * Samples after the last accepted peak do not form an event.
 *
 * @param km Pointer to memory manager.
 * @param sig Signal values; reordered in place, because
 *            calculate_mean_of_filtered_segment() sorts each segment.
 * @param peaks Array of peak positions (indices into sig).
 * @param peak_size Size of peaks array.
 * @param s_len Length of the signal.
 * @param n_events Pointer to the number of events generated.
 * @return float* Pointer to the array of generated events, allocated with
 *         ri_kmalloc().
 */
static inline float* gen_events(void *km,
								float* sig,
								const uint32_t *peaks,
								const uint32_t peak_size,
								const uint32_t s_len,
								uint32_t* n_events)
{
	// Upper bound on the number of events: one per in-range peak.
	uint32_t n_valid = 0;
	for (uint32_t pi = 0; pi < peak_size; ++pi)
		if (peaks[pi] > 0 && peaks[pi] < s_len) n_valid++;

	float* events = (float*)ri_kmalloc(km, n_valid*sizeof(float));

	uint32_t start_idx = 0, n_ev = 0;
	for (uint32_t pi = 0; pi < peak_size && n_ev < n_valid; pi++) {
		const uint32_t peak = peaks[pi];
		if (!(peak > 0 && peak < s_len)) continue;

		// A peak that does not advance past the previous boundary would make
		// the unsigned length (peak - start_idx) wrap around or be zero; skip it.
		if (peak <= start_idx) continue;

		events[n_ev++] = calculate_mean_of_filtered_segment(sig + start_idx, peak - start_idx);
		start_idx = peak;
	}

	// Report the number of events actually written (<= n_valid).
	(*n_events) = n_ev;
	return events;
}
193220

194-
float* detect_events(void *km, uint32_t s_len, const float* sig, uint32_t window_length1, uint32_t window_length2, float threshold1, float threshold2, float peak_height, uint32_t *n) // kt_for() callback
221+
/**
 * @brief Z-score normalizes a signal chunk with running statistics and drops outliers.
 *
 * The chunk is first folded into the caller-held running totals; the mean and
 * standard deviation are then derived from those totals (i.e. from every value
 * accumulated so far, not only this chunk). Each value of the chunk is
 * normalized as (value - mean) / std_dev and kept only if it lies strictly
 * within (-3, 3).
 *
 * @param km           Pointer to memory manager (passed to ri_kcalloc()).
 * @param sig          Signal values of this chunk.
 * @param s_len        Number of values in sig.
 * @param mean_sum     In/out: running SUM of signal values (despite the name, not a mean).
 * @param std_dev_sum  In/out: running SUM OF SQUARES of signal values (not a standard deviation).
 * @param n_events_sum In/out: running count of signal values; incremented by s_len.
 * @param n_sig        Out: number of normalized values written to the result.
 * @return float* Buffer with room for s_len floats obtained from ri_kcalloc();
 *         only the first *n_sig entries are normalized values.
 */
static inline float* normalize_signal(void *km,
									  const float* sig,
									  const uint32_t s_len,
									  double* mean_sum,
									  double* std_dev_sum,
									  uint32_t* n_events_sum,
									  uint32_t* n_sig)
{
	// Values at or beyond this many standard deviations are discarded.
	const float max_abs_z = 3.0f;

	float* events = (float*)ri_kcalloc(km, s_len, sizeof(float));

	double sum = (*mean_sum), sum2 = (*std_dev_sum);
	for (uint32_t i = 0; i < s_len; ++i) {
		sum += sig[i];
		sum2 += sig[i]*sig[i];
	}

	(*n_events_sum) += s_len;
	(*mean_sum) = sum;
	(*std_dev_sum) = sum2;

	const double mean = sum/(*n_events_sum);
	const double std_dev = sqrt(sum2/(*n_events_sum) - mean*mean);

	uint32_t n_kept = 0;
	// std_dev is 0 for a constant signal and NaN when nothing has been
	// accumulated or rounding makes the variance negative. Dividing by it
	// yields only NaN/inf, none of which would pass the filter, so skip the pass.
	if (std_dev > 0.0) {
		for (uint32_t i = 0; i < s_len; ++i) {
			const float norm_val = (sig[i]-mean)/std_dev;
			if (norm_val < max_abs_z && norm_val > -max_abs_z) events[n_kept++] = norm_val;
		}
	}

	(*n_sig) = n_kept;

	return events;
}
256+
257+
float* detect_events(void *km,
258+
const uint32_t s_len,
259+
const float* sig,
260+
const uint32_t window_length1,
261+
const uint32_t window_length2,
262+
const float threshold1,
263+
const float threshold2,
264+
const float peak_height,
265+
double* mean_sum,
266+
double* std_dev_sum,
267+
uint32_t* n_events_sum,
268+
uint32_t *n_events)
195269
{
196270
float* prefix_sum = (float*)ri_kcalloc(km, s_len+1, sizeof(float));
197271
float* prefix_sum_square = (float*)ri_kcalloc(km, s_len+1, sizeof(float));
198272

199-
comp_prefix_prefixsq(sig, s_len, prefix_sum, prefix_sum_square);
200-
float* tstat1 = comp_tstat(km, prefix_sum, prefix_sum_square, s_len, window_length1);
201-
float* tstat2 = comp_tstat(km, prefix_sum, prefix_sum_square, s_len, window_length2);
202-
ri_detect_t short_detector = {.DEF_PEAK_POS = -1, .DEF_PEAK_VAL = FLT_MAX, .sig = tstat1, .s_len = s_len, .threshold = threshold1,
203-
.window_length = window_length1, .masked_to = 0, .peak_pos = -1, .peak_value = FLT_MAX, .valid_peak = 0};
273+
//Normalize the signal
274+
uint32_t n_signals = 0;
275+
float* norm_signals = normalize_signal(km, sig, s_len, mean_sum, std_dev_sum, n_events_sum, &n_signals);
276+
if(n_signals == 0) return 0;
277+
comp_prefix_prefixsq(norm_signals, n_signals, prefix_sum, prefix_sum_square);
278+
279+
float* tstat1 = comp_tstat(km, prefix_sum, prefix_sum_square, n_signals, window_length1);
280+
float* tstat2 = comp_tstat(km, prefix_sum, prefix_sum_square, n_signals, window_length2);
281+
ri_detect_t short_detector = {.DEF_PEAK_POS = -1,
282+
.DEF_PEAK_VAL = FLT_MAX,
283+
.sig = tstat1,
284+
.s_len = n_signals,
285+
.threshold = threshold1,
286+
.window_length = window_length1,
287+
.masked_to = 0,
288+
.peak_pos = -1,
289+
.peak_value = FLT_MAX,
290+
.valid_peak = 0};
291+
292+
ri_detect_t long_detector = {.DEF_PEAK_POS = -1,
293+
.DEF_PEAK_VAL = FLT_MAX,
294+
.sig = tstat2,
295+
.s_len = n_signals,
296+
.threshold = threshold2,
297+
.window_length = window_length2,
298+
.masked_to = 0,
299+
.peak_pos = -1,
300+
.peak_value = FLT_MAX,
301+
.valid_peak = 0};
302+
303+
uint32_t* peaks = (uint32_t*)ri_kmalloc(km, n_signals * sizeof(uint32_t));
304+
ri_detect_t *detectors[2] = {&short_detector, &long_detector};
305+
uint32_t n_peaks = gen_peaks(detectors, 2, peak_height, prefix_sum, prefix_sum_square, peaks);
306+
ri_kfree(km, tstat1); ri_kfree(km, tstat2); ri_kfree(km, prefix_sum); ri_kfree(km, prefix_sum_square);
204307

205-
ri_detect_t long_detector = {.DEF_PEAK_POS = -1, .DEF_PEAK_VAL = FLT_MAX, .sig = tstat2, .s_len = s_len, .threshold = threshold2,
206-
.window_length = window_length2, .masked_to = 0, .peak_pos = -1, .peak_value = FLT_MAX, .valid_peak = 0};
207-
uint32_t* peaks = (uint32_t*)ri_kmalloc(km, s_len * sizeof(uint32_t));
208-
uint32_t n_peaks = gen_peaks(&short_detector, &long_detector, peak_height, peaks);
209308
float* events = 0;
210-
if(n_peaks > 0) events = gen_events(km, peaks, n_peaks, prefix_sum, prefix_sum_square, s_len, n);
211-
ri_kfree(km, tstat1); ri_kfree(km, tstat2); ri_kfree(km, prefix_sum); ri_kfree(km, prefix_sum_square); ri_kfree(km, peaks);
309+
if(n_peaks > 0) events = gen_events(km, norm_signals, peaks, n_peaks, n_signals, n_events);
310+
ri_kfree(km, norm_signals); ri_kfree(km, peaks);
212311

213312
return events;
214313
}

0 commit comments

Comments
 (0)