Skip to content

Commit 631b563

Browse files
committed
use shared cache for HNSW vertices
1 parent 440fe0a commit 631b563

2 files changed

Lines changed: 37 additions & 24 deletions

File tree

src/apps/hnsw/partition/search.rs

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::{
22
cell::RefCell,
33
collections::{BTreeMap, BTreeSet, BinaryHeap, HashSet, VecDeque},
4-
rc::Rc,
54
sync::Arc,
65
vec,
76
};
@@ -16,7 +15,7 @@ use dovahkiin::{
1615
use lightning::map::Map;
1716
use lightning::map::PtrHashMap;
1817
use neb::{
19-
index::{ranged::trees::Ordering, IndexerClients},
18+
index::{ranged::trees::Ordering},
2019
query::data_client::IndexedDataClient,
2120
ram::{
2221
chunk::Chunks,
@@ -27,7 +26,7 @@ use neb::{
2726
use crate::{
2827
apps::hnsw::measurements::{vector_distance, HnswMetric, Metric, MetricEncoding},
2928
graph::{
30-
edge::Edge, format_vertex_cell, id_list::IdList, partitioner::DefaultPartitioner,
29+
format_vertex_cell, id_list::IdList, partitioner::DefaultPartitioner,
3130
vertex::Vertex, EdgeDirection, GraphEngine, NeighbourhoodError,
3231
},
3332
job::{
@@ -76,6 +75,9 @@ pub struct HnswOnlinePartition {
7675
pub server_id: u64,
7776
pub chunks: Arc<Chunks>,
7877
pub engine: Arc<GraphEngine>,
78+
// Shared caches across searches in this partition
79+
pub shared_vertex_cache: PtrHashMap<Id, (Id, Id)>,
80+
pub shared_vertex_vector_cache: PtrHashMap<(Id, u64), Arc<OwnedPrimArray>>,
7981
}
8082

8183
impl HnswOnlinePartition {
@@ -168,6 +170,8 @@ impl HnswOnlinePartition {
168170
server_id,
169171
chunks: chunks.clone(),
170172
engine: engine.clone(),
173+
shared_vertex_cache: PtrHashMap::with_capacity(1 << 16),
174+
shared_vertex_vector_cache: PtrHashMap::with_capacity(1 << 18),
171175
})
172176
}
173177

@@ -226,14 +230,15 @@ impl HnswOnlinePartition {
226230
metadata: &SearchMetadata,
227231
job_id: JobId,
228232
) -> Option<(Id, Id)> {
229-
if let Some(v) = metadata.vertex_cache.borrow().get(vertex_id) {
233+
// shared cache only
234+
if let Some(v) = self.shared_vertex_cache.get(vertex_id) {
230235
append_job_log(
231236
logger,
232237
job_id,
233238
JobLogLevel::Trace,
234-
format!("Found vertex in cache for vertex {:?}", vertex_id),
239+
format!("Found vertex in SHARED cache for vertex {:?}", vertex_id),
235240
);
236-
return Some(*v);
241+
return Some(v);
237242
}
238243
let direction = EdgeDirection::Undirected;
239244
let edge_list = direction.as_field();
@@ -261,10 +266,8 @@ impl HnswOnlinePartition {
261266
(SharedValue::Id(cell_id), SharedValue::Id(type_list_id)) => {
262267
let cell_id = **cell_id;
263268
let type_list_id = **type_list_id;
264-
metadata
265-
.vertex_cache
266-
.borrow_mut()
267-
.insert(*vertex_id, (cell_id, type_list_id));
269+
// write-through to shared cache
270+
self.shared_vertex_cache.insert(*vertex_id, (cell_id, type_list_id));
268271
// println!("Inserted cell id in cache for vertex {:?}", vertex_id);
269272
Some((cell_id, type_list_id))
270273
}
@@ -300,11 +303,16 @@ impl HnswOnlinePartition {
300303
chunks: &'a Arc<Chunks>,
301304
logger: &Arc<JobLogger>,
302305
job_id: JobId,
303-
) -> Option<Rc<OwnedPrimArray>> {
306+
) -> Option<Arc<OwnedPrimArray>> {
307+
// per-search
304308
let mut vertex_vector_cache = metadata.vertex_vector_cache.borrow_mut();
305-
if let Some(vector) = vertex_vector_cache.get(vertex_id) {
309+
if let Some(vector) = vertex_vector_cache.get(&(*vertex_id, field_id)) {
306310
return Some(vector.clone());
307311
}
312+
// shared cache fallback
313+
if let Some(vector) = self.shared_vertex_vector_cache.get(&(*vertex_id, field_id)) {
314+
return Some(vector);
315+
}
308316
let (cell_id, _type_list_id) =
309317
self.get_vertex(vertex_id, chunks, logger, metadata, job_id)?;
310318
let cell = match chunks.read_selected(&cell_id, &[field_id], false) {
@@ -327,9 +335,11 @@ impl HnswOnlinePartition {
327335
match field {
328336
SharedValue::PrimArray(vector) => {
329337
let vector = vector.copy_into_owned();
330-
let rc_vector = Rc::new(vector);
331-
vertex_vector_cache.insert(*vertex_id, rc_vector.clone());
332-
Some(rc_vector)
338+
let arc_vector = Arc::new(vector);
339+
self.shared_vertex_vector_cache
340+
.insert((*vertex_id, field_id), arc_vector.clone());
341+
vertex_vector_cache.insert((*vertex_id, field_id), arc_vector.clone());
342+
Some(arc_vector)
333343
}
334344
_ => {
335345
append_job_log(
@@ -356,7 +366,8 @@ impl HnswOnlinePartition {
356366
) -> Option<Distance> {
357367
let mut vs = vec![vertex_a, vertex_b];
358368
vs.sort();
359-
let distance_key = (*vs[0], *vs[1]);
369+
let distance_key = (*vs[0], *vs[1], field_id);
370+
// per-search cache
360371
let mut distance_cache = metadata.vertex_distance_cache.borrow_mut();
361372
if let Some(distance) = distance_cache.get(&distance_key) {
362373
return Some(*distance);
@@ -550,7 +561,6 @@ impl HnswOnlinePartition {
550561
level_entries: HashMap::default(),
551562
vertex_distance_cache: RefCell::new(BTreeMap::default()),
552563
vertex_vector_cache: RefCell::new(BTreeMap::default()),
553-
vertex_cache: RefCell::new(BTreeMap::default()),
554564
ef,
555565
ef_construction,
556566
},
@@ -562,7 +572,7 @@ impl HnswOnlinePartition {
562572
metric,
563573
max_level,
564574
};
565-
partition_search.metadata.set_query_vector(query);
575+
partition_search.metadata.set_query_vector(query, field_id);
566576
Ok(partition_search)
567577
}
568578

@@ -1116,6 +1126,10 @@ impl HnswOnlinePartition {
11161126
}
11171127
}
11181128
}
1129+
// Invalidate only affected entries
1130+
// Note: leaving distance cache entries intact for stability; vector values are immutable
1131+
self.shared_vertex_cache.remove(&vertex_id);
1132+
self.shared_vertex_vector_cache.remove(&(vertex_id.clone(), field_id));
11191133
}
11201134
// If this is a high-level vertex, add it to the index's top_level_vertices
11211135
if random_level >= max_level {

src/apps/hnsw/partition/types.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::cell::RefCell;
22
use std::collections::{BTreeMap, HashSet};
3-
use std::rc::Rc;
3+
use std::sync::Arc;
44

55
use ahash::HashMap;
66
use bifrost::rpc::RPCError;
@@ -62,9 +62,8 @@ pub struct SearchMetadata {
6262
pub level_entries: ahash::HashMap<usize, Id>,
6363
pub remote_frontiers: HashSet<Id>,
6464
pub last_distance: Distance,
65-
pub vertex_distance_cache: RefCell<BTreeMap<(Id, Id), Distance>>,
66-
pub vertex_cache: RefCell<BTreeMap<Id, (Id, Id)>>,
67-
pub vertex_vector_cache: RefCell<BTreeMap<Id, Rc<OwnedPrimArray>>>,
65+
pub vertex_distance_cache: RefCell<BTreeMap<(Id, Id, u64), Distance>>,
66+
pub vertex_vector_cache: RefCell<BTreeMap<(Id, u64), Arc<OwnedPrimArray>>>,
6867
pub ef: usize, // Extension factor for search, controls exploration vs. exploitation
6968
pub ef_construction: usize, // Extension factor for construction, controls how many neighbors to consider when building the graph
7069
}
@@ -73,9 +72,9 @@ unsafe impl Sync for SearchMetadata {}
7372
unsafe impl Send for SearchMetadata {}
7473

7574
impl SearchMetadata {
76-
pub fn set_query_vector(&self, query: OwnedPrimArray) {
75+
pub fn set_query_vector(&self, query: OwnedPrimArray, field_id: u64) {
7776
self.vertex_vector_cache
7877
.borrow_mut()
79-
.insert(Id::unit_id(), Rc::new(query));
78+
.insert((Id::unit_id(), field_id), Arc::new(query));
8079
}
8180
}

0 commit comments

Comments
 (0)