Skip to content

Commit aa9898f

Browse files
committed
order distance
1 parent 2bce30c commit aa9898f

2 files changed

Lines changed: 34 additions & 24 deletions

File tree

src/apps/hnsw/partition/search.rs

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,11 @@ impl HnswOnlinePartition {
601601
Ok(())
602602
}
603603

604+
605+
/// Performs a breadth-first search to find the top K closest vertices to the query
606+
///
607+
/// Returns a vector of tuples (vertex_id, cell_id, distance) containing the k vertices
608+
/// nearest to the query vector, sorted by distance in ascending order.
604609
async fn bfs_top_k<M: HnswMetric>(
605610
&self,
606611
starting_vertices: &[Id],
@@ -614,15 +619,20 @@ impl HnswOnlinePartition {
614619
job_id: JobId,
615620
) -> Result<Vec<(Id, Id, Distance)>, HNSWIndexError> {
616621
// Priority queue to maintain vertices ordered by distance
617-
let mut top_k = BinaryHeap::<ReverseOrd<(OrderedFloat, Id, Id)>>::with_capacity(k + 1);
622+
let mut top_k = BinaryHeap::<(OrderedFloat, Id, Id)>::with_capacity(k + 1);
618623

619624
// Queue for BFS traversal
620625
let mut queue = VecDeque::new();
621-
queue.extend(starting_vertices.iter().cloned());
622-
626+
623627
// Track visited vertices to avoid cycles
624628
let mut visited = HashSet::<Id>::default();
625-
visited.extend(starting_vertices.iter().cloned());
629+
630+
// Push starting vertices and mark them as visited at enqueue time
631+
for &vertex in starting_vertices {
632+
if visited.insert(vertex) { // Returns true if the value was not present
633+
queue.push_back(vertex);
634+
}
635+
}
626636

627637
while let Some(current_vertex) = queue.pop_front() {
628638
// Get distance for the current vertex
@@ -637,40 +647,34 @@ impl HnswOnlinePartition {
637647
job_id,
638648
) {
639649
// Max heap keyed on distance: popping evicts the farthest vertex, so the k nearest remain
640-
top_k.push(ReverseOrd((
650+
// top_k.push(ReverseOrd((
651+
// OrderedFloat(distance),
652+
// current_vertex,
653+
// cell_id,
654+
// )));
655+
top_k.push((
641656
OrderedFloat(distance),
642657
current_vertex,
643658
cell_id,
644-
)));
659+
));
645660

646661
// If we have more than k elements, remove the largest one
647662
if top_k.len() > k {
648663
top_k.pop();
649664
}
650665

651-
visited.insert(current_vertex);
652666
// Get all neighbors of the current vertex
653667
match self.engine.node_local_neighbour_id_and_edge(
654668
&current_vertex,
655669
level_schema,
656670
EdgeDirection::Undirected,
657671
) {
658672
Ok(neighbors) => {
659-
append_job_log(
660-
logger,
661-
job_id,
662-
JobLogLevel::Info,
663-
format!(
664-
"Found {} neighbors for vertex {:?} with level schema {}",
665-
neighbors.len(),
666-
current_vertex,
667-
level_schema
668-
),
669-
);
670673
for (neighbor_id, _) in neighbors {
671674
let is_local = self.conshash.get_server_id(neighbor_id.higher)
672675
== Some(self.server_id);
673-
if is_local && !visited.contains(&neighbor_id) {
676+
// Only enqueue if local and not yet visited
677+
if is_local && visited.insert(neighbor_id) {
674678
queue.push_back(neighbor_id);
675679
}
676680
}
@@ -704,11 +708,11 @@ impl HnswOnlinePartition {
704708

705709
// Convert the max heap into a result list sorted by ascending distance
706710
let mut result: Vec<_> = top_k.into_iter().collect();
707-
result.sort_by(|a, b| b.0.cmp(&a.0));
711+
result.sort_by(|a, b| a.0.cmp(&b.0));
708712

709713
Ok(result
710714
.into_iter()
711-
.map(|ReverseOrd((dist, vid, cid))| (vid, cid, dist.0))
715+
.map(|(dist, vid, cid)| (vid, cid, dist.0))
712716
.collect())
713717
}
714718

@@ -917,11 +921,12 @@ impl HnswOnlinePartition {
917921
job_id,
918922
JobLogLevel::Info,
919923
format!(
920-
"Found {} nearest neighbors at level {} for vertex {:?}, with entry points {:?}",
924+
"Found {} nearest neighbors at level {} for vertex {:?}, with entry points {:?}. List of vertices: {}",
921925
nearest_neighbors.len(),
922926
level,
923927
vertex_id,
924-
entry_points
928+
entry_points,
929+
nearest_neighbors.iter().map(|(id, _, d)| format!("{:?}: {}", id, d)).collect::<Vec<_>>().join(", ")
925930
),
926931
);
927932
// Apply the neighbor selection heuristic to choose diverse neighbors

src/graph/id_list.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,10 +968,11 @@ impl<'a> Iterator for LocalIdListSegmentIdIterator<'a> {
968968
if !self.next.is_unit_id() {
969969
let cell = self.chunks.read_selected(&self.next, &*NEXT_KEY_ID_VEC);
970970
if let Ok(fields) = cell {
971+
let current_id = self.next;
971972
if let SharedValue::Id(ref id) = fields.data[0usize] {
972973
self.next = **id;
973974
self.level += 1;
974-
return Some(self.next);
975+
return Some(current_id);
975976
} else {
976977
error!("Expecting Id but got: {:?}", fields.data[0usize]);
977978
}
@@ -1060,6 +1061,9 @@ impl<'a> LocalIdList<'a> {
10601061
return Ok(Err(IdListError::ContainerCellNotReady));
10611062
}
10621063
}
1064+
SharedValue::Null | SharedValue::NA => {
1065+
return Ok(Err(IdListError::ContainerCellNotReady));
1066+
}
10631067
_ => {
10641068
error!(
10651069
"Cannot find type list id for container id {:?}",
@@ -1137,6 +1141,7 @@ impl<'a> LocalIdList<'a> {
11371141

11381142
/// Returns an iterator over all IDs in the list
11391143
pub fn iter(&mut self) -> Result<Result<LocalIdListIterator, IdListError>, ReadError> {
1144+
// Get the root list id
11401145
let list_root_id = match self.get_root_list_id()? {
11411146
Ok(id) => id,
11421147
Err(e) => return Ok(Err(e)),

0 commit comments

Comments
 (0)