Skip to content

Commit d8508da

Browse files
committed
Limit interior mutability to the distance cache only
1 parent 935ca7c commit d8508da

3 files changed

Lines changed: 86 additions & 75 deletions

File tree

src/apps/hnsw/partition/search.rs

Lines changed: 80 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use std::{
2-
collections::{BinaryHeap, HashSet, VecDeque},
3-
sync::Arc,
2+
cell::RefCell, collections::{BinaryHeap, HashSet, VecDeque}, sync::Arc
43
};
54

65
use ahash::HashMap;
@@ -283,13 +282,19 @@ impl HnswOnlinePartition {
283282
query: &OwnedPrimArray,
284283
field_id: u64,
285284
metric: M,
285+
metadata: &SearchMetadata,
286286
chunks: &Arc<Chunks>,
287287
logger: &Arc<JobLogger>,
288288
job_id: JobId,
289289
) -> Option<(Id, Distance)> {
290+
let mut distance_cache = metadata.distance_cache.borrow_mut();
291+
if let Some(distance) = distance_cache.get(vertex_id) {
292+
return Some(*distance);
293+
}
290294
let cell_id = self.get_vertex_cell_id(vertex_id, chunks, logger, job_id)?;
291295
let distance =
292296
self.get_cell_distance(&cell_id, field_id, query, metric, chunks, logger, job_id)?;
297+
distance_cache.insert(*vertex_id, (cell_id, distance));
293298
Some((cell_id, distance))
294299
}
295300

@@ -307,7 +312,7 @@ impl HnswOnlinePartition {
307312
let selected_vertex = frontier
308313
.iter()
309314
.filter_map(|vertex_id| {
310-
self.get_vertex_distance(vertex_id, query, field_id, metric, chunks, logger, job_id)
315+
self.get_vertex_distance(vertex_id, query, field_id, metric, metadata, chunks, logger, job_id)
311316
})
312317
.filter(|(id, _)| !metadata.visited.contains(id))
313318
// Greedy select the closest vertex
@@ -327,7 +332,8 @@ impl HnswOnlinePartition {
327332
// Make the selected vertex the last element in the history
328333
metadata.history.push(selected_vertex);
329334
// Need to remember all the visited vertices
330-
metadata.visited.extend(frontier.iter().cloned());
335+
// metadata.visited.extend(frontier.iter().cloned());
336+
metadata.visited.insert(selected_vertex.0);
331337
Ok(selected_vertex)
332338
}
333339

@@ -420,6 +426,7 @@ impl HnswOnlinePartition {
420426
remote_frontiers: HashSet::default(),
421427
last_distance: Distance::INFINITY,
422428
level_entries: HashMap::default(),
429+
distance_cache: RefCell::new(HashMap::default()),
423430
},
424431
frontier: starting_vertices,
425432
server_id: self.server_id,
@@ -465,13 +472,13 @@ impl HnswOnlinePartition {
465472
};
466473
if new_distance >= search.metadata.last_distance {
467474
// Should quit the loop or descend to the next level
468-
if search.metadata.level > 0 && !search.metadata.history.is_empty() {
475+
let metadata_level = search.metadata.level;
476+
if metadata_level > 0 && !search.metadata.history.is_empty() {
469477
// Descend to the next level
470478
let last_vertex = search.metadata.history.last().unwrap().0;
471-
search
472-
.metadata
479+
search.metadata
473480
.level_entries
474-
.insert(search.metadata.level, last_vertex);
481+
.insert(metadata_level, last_vertex);
475482
search.metadata.level -= 1;
476483
continue;
477484
} else {
@@ -504,13 +511,13 @@ impl HnswOnlinePartition {
504511
}
505512
};
506513
if next_frontier.is_empty() {
507-
if search.metadata.level > 0 {
514+
let metadata_level = search.metadata.level;
515+
if metadata_level > 0 {
508516
// Descend to the next level
509517
let last_vertex = search.metadata.history.last().unwrap().0;
510-
search
511-
.metadata
518+
search.metadata
512519
.level_entries
513-
.insert(search.metadata.level, last_vertex);
520+
.insert(metadata_level, last_vertex);
514521
search.metadata.level -= 1;
515522
continue;
516523
} else {
@@ -536,6 +543,7 @@ impl HnswOnlinePartition {
536543
field_id: u64,
537544
level_schema: u32,
538545
metric: M,
546+
metadata: &SearchMetadata,
539547
logger: &Arc<JobLogger>,
540548
job_id: JobId,
541549
) -> Result<Vec<(Id, Id, Distance)>, HNSWIndexError> {
@@ -552,70 +560,64 @@ impl HnswOnlinePartition {
552560

553561
while let Some(current_vertex) = queue.pop_front() {
554562
// Get distance for the current vertex
555-
if let Some(cell_id) =
556-
self.get_vertex_cell_id(&current_vertex, &self.chunks, logger, job_id)
557-
{
558-
if let Some(distance) = self.get_cell_distance(
559-
&cell_id,
560-
field_id,
561-
query,
562-
metric,
563-
&self.chunks,
564-
logger,
565-
job_id,
566-
) {
567-
// Use max heap with negative distance for top-k minimum
568-
top_k.push(ReverseOrd((OrderedFloat(distance), current_vertex, cell_id)));
563+
if let Some((cell_id, distance)) = self.get_vertex_distance(
564+
&current_vertex,
565+
query,
566+
field_id,
567+
metric,
568+
metadata,
569+
&self.chunks,
570+
logger,
571+
job_id,
572+
) {
573+
// Use max heap with negative distance for top-k minimum
574+
top_k.push(ReverseOrd((
575+
OrderedFloat(distance),
576+
current_vertex,
577+
cell_id,
578+
)));
569579

570-
// If we have more than k elements, remove the largest one
571-
if top_k.len() > k {
572-
top_k.pop();
573-
}
580+
// If we have more than k elements, remove the largest one
581+
if top_k.len() > k {
582+
top_k.pop();
583+
}
574584

575-
// Get all neighbors of the current vertex
576-
match self
577-
.engine
578-
.neighbour_id_and_edge(
579-
&current_vertex,
580-
level_schema,
581-
EdgeDirection::Undirected,
582-
)
583-
.await
584-
{
585-
Ok(Ok(neighbors)) => {
586-
for (neighbor_id, _) in neighbors {
587-
let is_local = self.conshash.get_server_id(neighbor_id.higher)
588-
== Some(self.server_id);
589-
if is_local && !visited.contains(&neighbor_id) {
590-
visited.insert(neighbor_id);
591-
queue.push_back(neighbor_id);
592-
}
585+
// Get all neighbors of the current vertex
586+
match self
587+
.engine
588+
.neighbour_id_and_edge(&current_vertex, level_schema, EdgeDirection::Undirected)
589+
.await
590+
{
591+
Ok(Ok(neighbors)) => {
592+
for (neighbor_id, _) in neighbors {
593+
let is_local = self.conshash.get_server_id(neighbor_id.higher)
594+
== Some(self.server_id);
595+
if is_local && !visited.contains(&neighbor_id) {
596+
visited.insert(neighbor_id);
597+
queue.push_back(neighbor_id);
593598
}
594599
}
595-
Ok(Err(e)) => {
596-
append_job_log(
597-
logger,
598-
job_id,
599-
JobLogLevel::Error,
600-
format!(
601-
"Neighborhood error for vertex {:?}: {:?}",
602-
current_vertex, e
603-
),
604-
);
605-
continue;
606-
}
607-
Err(e) => {
608-
append_job_log(
609-
logger,
610-
job_id,
611-
JobLogLevel::Error,
612-
format!(
613-
"Transaction error for vertex {:?}: {:?}",
614-
current_vertex, e
615-
),
616-
);
617-
continue;
618-
}
600+
}
601+
Ok(Err(e)) => {
602+
append_job_log(
603+
logger,
604+
job_id,
605+
JobLogLevel::Error,
606+
format!(
607+
"Neighborhood error for vertex {:?}: {:?}",
608+
current_vertex, e
609+
),
610+
);
611+
continue;
612+
}
613+
Err(e) => {
614+
append_job_log(
615+
logger,
616+
job_id,
617+
JobLogLevel::Error,
618+
format!("Transaction error for vertex {:?}: {:?}", current_vertex, e),
619+
);
620+
continue;
619621
}
620622
}
621623
}
@@ -660,6 +662,7 @@ impl HnswOnlinePartition {
660662
field_id,
661663
level_schema,
662664
metric,
665+
metadata,
663666
logger,
664667
job_id,
665668
)
@@ -673,6 +676,7 @@ impl HnswOnlinePartition {
673676
field_id,
674677
level_schema,
675678
metric,
679+
metadata,
676680
logger,
677681
job_id,
678682
)
@@ -686,6 +690,7 @@ impl HnswOnlinePartition {
686690
field_id,
687691
level_schema,
688692
metric,
693+
metadata,
689694
logger,
690695
job_id,
691696
)
@@ -699,6 +704,7 @@ impl HnswOnlinePartition {
699704
field_id,
700705
level_schema,
701706
metric,
707+
metadata,
702708
logger,
703709
job_id,
704710
)
@@ -742,7 +748,7 @@ impl HnswOnlinePartition {
742748
.get_index(schema, field_id)
743749
.ok_or(HNSWIndexError::IndexNotFound)?;
744750

745-
let mut random_level = generate_random_level(PROB, MAX_LEVEL_CAP);
751+
let random_level = generate_random_level(PROB, MAX_LEVEL_CAP);
746752

747753
append_job_log(
748754
logger,
@@ -794,6 +800,7 @@ impl HnswOnlinePartition {
794800
field_id,
795801
level_schema,
796802
metric,
803+
metadata,
797804
logger,
798805
job_id,
799806
)

src/apps/hnsw/partition/service.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ impl service::Service for service::HnswPartitionService {
195195
JobLogLevel::Info,
196196
format!("Starting BFS search for job {:?}", job_id),
197197
);
198-
199198
match self
200199
.partition
201200
.search_top_k(

src/apps/hnsw/partition/types.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
use std::cell::RefCell;
12
use std::collections::HashSet;
23

4+
use ahash::HashMap;
35
use bifrost::rpc::RPCError;
46
use dovahkiin::types::{Id, OwnedPrimArray};
57
use neb::{client::transaction::TxnError, ram::cell::WriteError};
@@ -23,7 +25,7 @@ impl Ord for OrderedFloat {
2325
pub struct ReverseOrd<T: Ord>(pub T);
2426
impl<T: Ord> Ord for ReverseOrd<T> {
2527
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
26-
other.0.cmp(&self.0) // Reversed comparison
28+
other.0.cmp(&self.0) // Reversed comparison
2729
}
2830
}
2931
impl<T: Ord> PartialOrd for ReverseOrd<T> {
@@ -57,4 +59,7 @@ pub struct SearchMetadata {
5759
pub level_entries: ahash::HashMap<usize, Id>,
5860
pub remote_frontiers: HashSet<Id>,
5961
pub last_distance: Distance,
62+
pub distance_cache: RefCell<HashMap<Id, (Id, Distance)>>,
6063
}
64+
65+
// SAFETY / NOTE(review): `SearchMetadata::distance_cache` is a `RefCell`, which is not `Sync`,
// so this impl manually asserts a guarantee the compiler cannot check. It can only be sound if
// a given `SearchMetadata` is never accessed from more than one thread at the same time —
// TODO confirm against every caller, or replace the `RefCell` with a lock and drop this impl.
unsafe impl Sync for SearchMetadata {}

0 commit comments

Comments
 (0)