Skip to content

Commit e8d61db

Browse files
committed
fix: improve HNSW search quality and tests
1 parent ea2f707 commit e8d61db

6 files changed

Lines changed: 517 additions & 143 deletions

File tree

src/apps/embedding/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,10 @@ impl EmbeddingIndexerCore for EmbeddingIndexer {
627627

628628
// Search HNSW (on the embedding schema's vector field)
629629
let query_vector = OwnedPrimArray::F32(result.vector);
630-
let ef = (limit * 2) as u64;
630+
let ef_min = 32usize;
631+
let ef_max = 200usize;
632+
let ef_scaled = (limit * 4).clamp(ef_min, ef_max);
633+
let ef = std::cmp::max(limit, ef_scaled) as u64;
631634

632635
let hits = self
633636
.hnsw_coordinator

src/apps/hnsw/coordinator/tests.rs

Lines changed: 108 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,59 @@ impl TestVector {
4646
}
4747
}
4848

49+
fn compute_l2_distance(query: &[f32], vector: &OwnedPrimArray) -> Distance {
50+
match vector {
51+
OwnedPrimArray::F32(values) => values
52+
.iter()
53+
.zip(query.iter())
54+
.map(|(a, b)| {
55+
let diff = a - b;
56+
diff * diff
57+
})
58+
.sum(),
59+
_ => Distance::INFINITY,
60+
}
61+
}
62+
63+
fn exact_top_k(query: &[f32], vectors: &[TestVector], k: usize) -> Vec<(Id, Distance)> {
64+
let mut scored: Vec<(Id, Distance)> = vectors
65+
.iter()
66+
.map(|v| (v.cell_id, compute_l2_distance(query, &v.vector)))
67+
.collect();
68+
scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
69+
scored.truncate(k);
70+
scored
71+
}
72+
73+
/// Fraction of the distinct ids in `exact` that also appear in `results`.
///
/// Only ids are compared; the distances in both slices are ignored. Ids are
/// de-duplicated on both sides, so a result list that repeats the same id
/// cannot push the value above `1.0`.
///
/// Returns `0.0` when `exact` is empty.
fn recall_at_k(results: &[(Id, Distance)], exact: &[(Id, Distance)]) -> f64 {
    if exact.is_empty() {
        return 0.0;
    }
    let expected: std::collections::HashSet<Id> = exact.iter().map(|(id, _)| *id).collect();
    let returned: std::collections::HashSet<Id> = results.iter().map(|(id, _)| *id).collect();
    // Set intersection counts each matching id exactly once.
    let hit_count = expected.intersection(&returned).count();
    // `expected` is non-empty here because `exact` is non-empty.
    hit_count as f64 / expected.len() as f64
}
84+
85+
fn top1_distance_ratio(results: &[(Id, Distance)], exact: &[(Id, Distance)]) -> Distance {
86+
if results.is_empty() || exact.is_empty() {
87+
return Distance::INFINITY;
88+
}
89+
let exact_top = exact[0].1.max(1e-9);
90+
results[0].1 / exact_top
91+
}
92+
93+
fn overlap_at_k(a: &[(Id, Distance)], b: &[(Id, Distance)]) -> f64 {
94+
if b.is_empty() {
95+
return 0.0;
96+
}
97+
let ids_a: std::collections::HashSet<Id> = a.iter().map(|(id, _)| *id).collect();
98+
let hits = b.iter().filter(|(id, _)| ids_a.contains(id)).count();
99+
hits as f64 / b.len() as f64
100+
}
101+
49102
/// Test environment structure to hold all components needed for HNSW coordinator tests
50103
pub struct TestEnvironment {
51104
pub job_counter: RefCell<u64>,
@@ -348,10 +401,24 @@ mod tests {
348401
env.index_cells().await.unwrap();
349402

350403
let query = vec![1.0, 2.0, 3.0];
351-
let results = env.top_k(query, 1).await.unwrap();
404+
let results = env.top_k(query.clone(), 1).await.unwrap();
352405
assert_eq!(results.len(), 1);
353-
assert_eq!(results[0].0.lower, 1000); // First cell
354-
assert_eq!(results[0].1, 0.0); // Exact match
406+
407+
let exact = exact_top_k(&query, &env.test_vectors, 1);
408+
let recall = recall_at_k(&results, &exact);
409+
assert!(recall >= 1.0, "Recall@1 should be 1.0, got {:.2}", recall);
410+
let ratio = top1_distance_ratio(&results, &exact);
411+
assert!(
412+
ratio <= 1.0 + 1e-6,
413+
"Top-1 distance ratio should be 1.0, got {:.6}",
414+
ratio
415+
);
416+
let overlap = overlap_at_k(&results, &exact);
417+
assert!(
418+
overlap >= 1.0,
419+
"Overlap@1 should be 1.0, got {:.2}",
420+
overlap
421+
);
355422
env.cleanup().await;
356423
}
357424

@@ -382,11 +449,24 @@ mod tests {
382449
env.index_cells().await.unwrap();
383450

384451
let query = vec![4.0, 5.0, 6.0];
385-
let results = env.top_k(query, 2).await.unwrap();
452+
let results = env.top_k(query.clone(), 2).await.unwrap();
386453
assert_eq!(results.len(), 2);
387-
assert_eq!(results[0].0.lower, 1001); // Second cell
388-
assert_eq!(results[0].1, 0.0); // Exact match
389-
assert!(results[1].0.lower == 1000 || results[1].0.lower == 1002); // First or third cell
454+
455+
let exact = exact_top_k(&query, &env.test_vectors, 2);
456+
let recall = recall_at_k(&results, &exact);
457+
assert!(recall >= 1.0, "Recall@2 should be 1.0, got {:.2}", recall);
458+
let ratio = top1_distance_ratio(&results, &exact);
459+
assert!(
460+
ratio <= 1.0 + 1e-6,
461+
"Top-1 distance ratio should be 1.0, got {:.6}",
462+
ratio
463+
);
464+
let overlap = overlap_at_k(&results, &exact);
465+
assert!(
466+
overlap >= 1.0,
467+
"Overlap@2 should be 1.0, got {:.2}",
468+
overlap
469+
);
390470
env.cleanup().await;
391471
}
392472

@@ -420,11 +500,28 @@ mod tests {
420500

421501
let query = vec![10.0, 11.0, 12.0];
422502
let k = 5;
423-
let results = env.top_k(query, k).await.unwrap();
503+
let results = env.top_k(query.clone(), k).await.unwrap();
424504
assert_eq!(results.len(), k as usize);
425-
assert_eq!(results[0].0.lower, 1010); // Tenth cell
426-
assert_eq!(results[0].1, 0.0); // Exact match
427-
assert!(results[1].0.lower == 1009 || results[1].0.lower == 1011); // Neighboring cells
505+
506+
let exact = exact_top_k(&query, &env.test_vectors, k as usize);
507+
let recall = recall_at_k(&results, &exact);
508+
assert!(
509+
recall >= 0.9,
510+
"Recall@5 should be >= 0.9, got {:.2}",
511+
recall
512+
);
513+
let ratio = top1_distance_ratio(&results, &exact);
514+
assert!(
515+
ratio <= 1.1,
516+
"Top-1 distance ratio should be <= 1.1, got {:.3}",
517+
ratio
518+
);
519+
let overlap = overlap_at_k(&results, &exact);
520+
assert!(
521+
overlap >= 0.8,
522+
"Overlap@5 should be >= 0.8, got {:.2}",
523+
overlap
524+
);
428525
env.cleanup().await;
429526
}
430527

src/apps/hnsw/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ impl VectorIndexerCore for VectorIndexer {
121121
) -> BoxFuture<Result<Vec<VectorHit>, IndexError>> {
122122
// Convert the query vector slice to OwnedPrimArray
123123
let query = OwnedPrimArray::F32(query_vector.to_vec());
124-
// Use L2 metric as default, ef = limit * 2 for better recall
125-
let ef = (limit * 2) as u64;
124+
let ef_min = 32usize;
125+
let ef_max = 200usize;
126+
let ef = (limit * 4).clamp(ef_min, ef_max) as u64;
126127
self.coordinator
127128
.query_top_k(
128129
schema_id,

src/apps/hnsw/partition/search.rs

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use bifrost::{conshash::ConsistentHashing, raft::client::RaftClient};
1111
use dovahkiin::types::Map as DovMap;
1212
use dovahkiin::{
1313
data_map,
14-
types::{Id, OwnedPrimArray, SharedValue},
14+
types::{Id, OwnedPrimArray, SharedPrimArray, SharedValue},
1515
};
1616
use futures::stream::FuturesUnordered;
1717
use lightning::map::Map;
@@ -820,13 +820,30 @@ impl HnswOnlinePartition {
820820
_logger: &Arc<JobLogger>,
821821
_job_id: JobId,
822822
) -> Result<Vec<Id>, NeighbourhoodError> {
823-
// Use edge cache for unified read path (includes pending writes)
824823
if let Some(neighbors) = self.edge_cache.get_neighbors(*vertex_id, level) {
825-
Ok(neighbors)
826-
} else {
827-
// Vertex not found in cache - return empty neighbors (new vertex or not yet indexed)
828-
Ok(vec![])
824+
return Ok(neighbors);
825+
}
826+
827+
if level >= MAX_LEVELS {
828+
return Ok(vec![]);
829829
}
830+
831+
let cell_field_id = super::schema::CELL_FIELD_ID;
832+
let mut field_ids = vec![cell_field_id];
833+
field_ids.extend(NEIGHBORS_FIELD_IDS.iter().copied());
834+
835+
let cell = match self.chunks.read_selected(vertex_id, &field_ids, false) {
836+
Ok(cell) => cell,
837+
Err(_) => return Ok(vec![]),
838+
};
839+
840+
let neighbors = match &cell.data[level + 1] {
841+
SharedValue::PrimArray(SharedPrimArray::Id(list)) => list.to_vec(),
842+
SharedValue::Null | SharedValue::NA => vec![],
843+
_ => vec![],
844+
};
845+
846+
Ok(neighbors)
830847
}
831848

832849
fn bfs_top_k_pruned(
@@ -907,12 +924,7 @@ impl HnswOnlinePartition {
907924
}
908925
};
909926

910-
// Deterministic neighbor sampling: limit neighbors explored per vertex
911-
// HNSW neighbor lists are often ordered by distance, so taking the first N
912-
// prioritizes closer (more relevant) neighbors while reducing overhead
913-
const MAX_NEIGHBORS_TO_EXPLORE: usize = 64;
914-
915-
let sample_limit = MAX_NEIGHBORS_TO_EXPLORE.min(neighbors.len());
927+
let sample_limit = neighbors.len();
916928
for nid in neighbors.iter().take(sample_limit) {
917929
let nid_index = metadata.registry.get_or_insert(*nid);
918930
if !visited.contains(nid_index) {
@@ -1977,8 +1989,8 @@ impl HnswOnlinePartition {
19771989
}
19781990
last_best_distance = current_best;
19791991

1980-
// Early termination: if best distance is stable for 3 iterations and we have enough candidates
1981-
if stable_count >= 3 && search.history.len() >= ef && priority_frontier.len() >= ef {
1992+
if stable_count >= 5 && search.history.len() >= ef * 2 && priority_frontier.len() >= ef
1993+
{
19821994
break;
19831995
}
19841996

0 commit comments

Comments
 (0)