Skip to content

Commit b32fb76

Browse files
committed
feat: Introduce a visit limit and optimize HNSW search by refining candidate heap management and conditional result set population.
1 parent e2f2b0d commit b32fb76

1 file changed

Lines changed: 52 additions & 44 deletions

File tree

src/apps/hnsw/partition/search.rs

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,11 +1396,15 @@ impl HnswOnlinePartition {
13961396
// Ensure ef is at least as large as k
13971397
let ef = ef.max(k);
13981398

1399+
// Safety cap to prevent unbounded exploration in pathological cases
1400+
let max_visits = ef * 20;
1401+
let mut visit_count = 0;
1402+
13991403
// Max-heap for results (to maintain closest ef elements, with worst at top for easy removal)
14001404
let mut result_set: BinaryHeap<(OrderedFloat, Id, Id)> = BinaryHeap::with_capacity(ef + 1);
14011405

1402-
// Min-heap for candidates (sort by distance, closest first)
1403-
let mut candidates: BinaryHeap<ReverseOrd<(OrderedFloat, Id, Id)>> =
1406+
// Min-heap for candidates (sort by distance, closest first) - only stores (distance, vertex_id)
1407+
let mut candidates: BinaryHeap<ReverseOrd<(OrderedFloat, Id)>> =
14041408
BinaryHeap::with_capacity(ef + 1);
14051409

14061410
// Track visited vertices
@@ -1426,12 +1430,13 @@ impl HnswOnlinePartition {
14261430
logger,
14271431
job_id,
14281432
) {
1429-
let cell_id = self
1430-
.get_vertex_cell_id(&start_id, &self.chunks, logger, metadata, job_id)
1431-
.unwrap();
1432-
// Add to both candidates and result_set
1433-
candidates.push(ReverseOrd((OrderedFloat(distance), start_id, cell_id)));
1434-
result_set.push((OrderedFloat(distance), start_id, cell_id));
1433+
// Only add to result_set if we can get the cell_id
1434+
if let Some(cell_id) =
1435+
self.get_vertex_cell_id(&start_id, &self.chunks, logger, metadata, job_id)
1436+
{
1437+
candidates.push(ReverseOrd((OrderedFloat(distance), start_id)));
1438+
result_set.push((OrderedFloat(distance), start_id, cell_id));
1439+
}
14351440
}
14361441
}
14371442

@@ -1441,19 +1446,25 @@ impl HnswOnlinePartition {
14411446
}
14421447

14431448
// Start with the worst distance in our result set
1444-
let mut dist_threshold = if result_set.len() > 0 {
1449+
let mut dist_threshold = if !result_set.is_empty() {
14451450
result_set.peek().unwrap().0
14461451
} else {
14471452
OrderedFloat(Distance::INFINITY)
14481453
};
14491454

14501455
// Main search loop
1451-
while let Some(ReverseOrd((dist, current_vertex, _))) = candidates.pop() {
1456+
while let Some(ReverseOrd((dist, current_vertex))) = candidates.pop() {
14521457
// Stop if minimum distance to candidates is worse than our worst result
14531458
if dist > dist_threshold {
14541459
break;
14551460
}
14561461

1462+
// Safety cap on visits
1463+
visit_count += 1;
1464+
if visit_count >= max_visits {
1465+
break;
1466+
}
1467+
14571468
// Get neighbors of current vertex
14581469
match self.visit_vertex_and_edge(&current_vertex, level, metadata, logger, job_id) {
14591470
Ok(neighbors) => {
@@ -1478,41 +1489,38 @@ impl HnswOnlinePartition {
14781489
logger,
14791490
job_id,
14801491
) {
1481-
// Only process if better than our worst result or if we need more results
1482-
if result_set.len() < ef
1483-
|| OrderedFloat(distance) < dist_threshold
1484-
{
1485-
let cell_id = self
1486-
.get_vertex_cell_id(
1487-
&neighbor_id,
1488-
&self.chunks,
1489-
logger,
1490-
metadata,
1491-
job_id,
1492-
)
1493-
.unwrap();
1494-
// Add to candidates queue
1495-
candidates.push(ReverseOrd((
1496-
OrderedFloat(distance),
1497-
neighbor_id,
1498-
cell_id,
1499-
)));
1500-
1501-
// Add to results
1502-
result_set.push((
1503-
OrderedFloat(distance),
1504-
neighbor_id,
1505-
cell_id,
1506-
));
1507-
1508-
// If we exceed capacity, remove the worst element
1509-
if result_set.len() > ef {
1510-
result_set.pop();
1511-
}
1492+
let ord_distance = OrderedFloat(distance);
15121493

1513-
// Update threshold to be the distance of the worst element
1514-
if let Some(worst) = result_set.peek() {
1515-
dist_threshold = worst.0;
1494+
// Only process if better than our worst result or if we need more results
1495+
if result_set.len() < ef || ord_distance < dist_threshold {
1496+
// Only add to results if we can get the cell_id
1497+
if let Some(cell_id) = self.get_vertex_cell_id(
1498+
&neighbor_id,
1499+
&self.chunks,
1500+
logger,
1501+
metadata,
1502+
job_id,
1503+
) {
1504+
// Only add to candidates if strictly better (avoids redundant exploration)
1505+
if result_set.len() < ef
1506+
|| ord_distance < dist_threshold
1507+
{
1508+
candidates
1509+
.push(ReverseOrd((ord_distance, neighbor_id)));
1510+
}
1511+
1512+
// Add to results
1513+
result_set.push((ord_distance, neighbor_id, cell_id));
1514+
1515+
// If we exceed capacity, remove the worst element
1516+
if result_set.len() > ef {
1517+
result_set.pop();
1518+
}
1519+
1520+
// Update threshold to be the distance of the worst element
1521+
if let Some(worst) = result_set.peek() {
1522+
dist_threshold = worst.0;
1523+
}
15161524
}
15171525
}
15181526
}

0 commit comments

Comments
 (0)