11use std:: {
2- collections:: { BinaryHeap , HashSet , VecDeque } ,
3- sync:: Arc ,
2+ cell:: RefCell , collections:: { BinaryHeap , HashSet , VecDeque } , sync:: Arc
43} ;
54
65use ahash:: HashMap ;
@@ -283,13 +282,19 @@ impl HnswOnlinePartition {
283282 query : & OwnedPrimArray ,
284283 field_id : u64 ,
285284 metric : M ,
285+ metadata : & SearchMetadata ,
286286 chunks : & Arc < Chunks > ,
287287 logger : & Arc < JobLogger > ,
288288 job_id : JobId ,
289289 ) -> Option < ( Id , Distance ) > {
290+ let mut distance_cache = metadata. distance_cache . borrow_mut ( ) ;
291+ if let Some ( distance) = distance_cache. get ( vertex_id) {
292+ return Some ( * distance) ;
293+ }
290294 let cell_id = self . get_vertex_cell_id ( vertex_id, chunks, logger, job_id) ?;
291295 let distance =
292296 self . get_cell_distance ( & cell_id, field_id, query, metric, chunks, logger, job_id) ?;
297+ distance_cache. insert ( * vertex_id, ( cell_id, distance) ) ;
293298 Some ( ( cell_id, distance) )
294299 }
295300
@@ -307,7 +312,7 @@ impl HnswOnlinePartition {
307312 let selected_vertex = frontier
308313 . iter ( )
309314 . filter_map ( |vertex_id| {
310- self . get_vertex_distance ( vertex_id, query, field_id, metric, chunks, logger, job_id)
315+ self . get_vertex_distance ( vertex_id, query, field_id, metric, metadata , chunks, logger, job_id)
311316 } )
312317 . filter ( |( id, _) | !metadata. visited . contains ( id) )
313318 // Greedy select the closest vertex
@@ -327,7 +332,8 @@ impl HnswOnlinePartition {
327332 // Make the selected vertex the last element in the history
328333 metadata. history . push ( selected_vertex) ;
329334 // Need to remember all the visited vertices
330- metadata. visited . extend ( frontier. iter ( ) . cloned ( ) ) ;
335+ // metadata.visited.extend(frontier.iter().cloned());
336+ metadata. visited . insert ( selected_vertex. 0 ) ;
331337 Ok ( selected_vertex)
332338 }
333339
@@ -420,6 +426,7 @@ impl HnswOnlinePartition {
420426 remote_frontiers : HashSet :: default ( ) ,
421427 last_distance : Distance :: INFINITY ,
422428 level_entries : HashMap :: default ( ) ,
429+ distance_cache : RefCell :: new ( HashMap :: default ( ) ) ,
423430 } ,
424431 frontier : starting_vertices,
425432 server_id : self . server_id ,
@@ -465,13 +472,13 @@ impl HnswOnlinePartition {
465472 } ;
466473 if new_distance >= search. metadata . last_distance {
467474 // Should quit the loop or decendent to the next level
468- if search. metadata . level > 0 && !search. metadata . history . is_empty ( ) {
475+ let metadata_level = search. metadata . level ;
476+ if metadata_level > 0 && !search. metadata . history . is_empty ( ) {
469477 // Decendent to the next level
470478 let last_vertex = search. metadata . history . last ( ) . unwrap ( ) . 0 ;
471- search
472- . metadata
479+ search. metadata
473480 . level_entries
474- . insert ( search . metadata . level , last_vertex) ;
481+ . insert ( metadata_level , last_vertex) ;
475482 search. metadata . level -= 1 ;
476483 continue ;
477484 } else {
@@ -504,13 +511,13 @@ impl HnswOnlinePartition {
504511 }
505512 } ;
506513 if next_frontier. is_empty ( ) {
507- if search. metadata . level > 0 {
514+ let metadata_level = search. metadata . level ;
515+ if metadata_level > 0 {
508516 // Decendent to the next level
509517 let last_vertex = search. metadata . history . last ( ) . unwrap ( ) . 0 ;
510- search
511- . metadata
518+ search. metadata
512519 . level_entries
513- . insert ( search . metadata . level , last_vertex) ;
520+ . insert ( metadata_level , last_vertex) ;
514521 search. metadata . level -= 1 ;
515522 continue ;
516523 } else {
@@ -536,6 +543,7 @@ impl HnswOnlinePartition {
536543 field_id : u64 ,
537544 level_schema : u32 ,
538545 metric : M ,
546+ metadata : & SearchMetadata ,
539547 logger : & Arc < JobLogger > ,
540548 job_id : JobId ,
541549 ) -> Result < Vec < ( Id , Id , Distance ) > , HNSWIndexError > {
@@ -552,70 +560,64 @@ impl HnswOnlinePartition {
552560
553561 while let Some ( current_vertex) = queue. pop_front ( ) {
554562 // Get distance for the current vertex
555- if let Some ( cell_id) =
556- self . get_vertex_cell_id ( & current_vertex, & self . chunks , logger, job_id)
557- {
558- if let Some ( distance) = self . get_cell_distance (
559- & cell_id,
560- field_id,
561- query,
562- metric,
563- & self . chunks ,
564- logger,
565- job_id,
566- ) {
567- // Use max heap with negative distance for top-k minimum
568- top_k. push ( ReverseOrd ( ( OrderedFloat ( distance) , current_vertex, cell_id) ) ) ;
563+ if let Some ( ( cell_id, distance) ) = self . get_vertex_distance (
564+ & current_vertex,
565+ query,
566+ field_id,
567+ metric,
568+ metadata,
569+ & self . chunks ,
570+ logger,
571+ job_id,
572+ ) {
573+ // Use max heap with negative distance for top-k minimum
574+ top_k. push ( ReverseOrd ( (
575+ OrderedFloat ( distance) ,
576+ current_vertex,
577+ cell_id,
578+ ) ) ) ;
569579
570- // If we have more than k elements, remove the largest one
571- if top_k. len ( ) > k {
572- top_k. pop ( ) ;
573- }
580+ // If we have more than k elements, remove the largest one
581+ if top_k. len ( ) > k {
582+ top_k. pop ( ) ;
583+ }
574584
575- // Get all neighbors of the current vertex
576- match self
577- . engine
578- . neighbour_id_and_edge (
579- & current_vertex,
580- level_schema,
581- EdgeDirection :: Undirected ,
582- )
583- . await
584- {
585- Ok ( Ok ( neighbors) ) => {
586- for ( neighbor_id, _) in neighbors {
587- let is_local = self . conshash . get_server_id ( neighbor_id. higher )
588- == Some ( self . server_id ) ;
589- if is_local && !visited. contains ( & neighbor_id) {
590- visited. insert ( neighbor_id) ;
591- queue. push_back ( neighbor_id) ;
592- }
585+ // Get all neighbors of the current vertex
586+ match self
587+ . engine
588+ . neighbour_id_and_edge ( & current_vertex, level_schema, EdgeDirection :: Undirected )
589+ . await
590+ {
591+ Ok ( Ok ( neighbors) ) => {
592+ for ( neighbor_id, _) in neighbors {
593+ let is_local = self . conshash . get_server_id ( neighbor_id. higher )
594+ == Some ( self . server_id ) ;
595+ if is_local && !visited. contains ( & neighbor_id) {
596+ visited. insert ( neighbor_id) ;
597+ queue. push_back ( neighbor_id) ;
593598 }
594599 }
595- Ok ( Err ( e) ) => {
596- append_job_log (
597- logger,
598- job_id,
599- JobLogLevel :: Error ,
600- format ! (
601- "Neighborhood error for vertex {:?}: {:?}" ,
602- current_vertex, e
603- ) ,
604- ) ;
605- continue ;
606- }
607- Err ( e) => {
608- append_job_log (
609- logger,
610- job_id,
611- JobLogLevel :: Error ,
612- format ! (
613- "Transaction error for vertex {:?}: {:?}" ,
614- current_vertex, e
615- ) ,
616- ) ;
617- continue ;
618- }
600+ }
601+ Ok ( Err ( e) ) => {
602+ append_job_log (
603+ logger,
604+ job_id,
605+ JobLogLevel :: Error ,
606+ format ! (
607+ "Neighborhood error for vertex {:?}: {:?}" ,
608+ current_vertex, e
609+ ) ,
610+ ) ;
611+ continue ;
612+ }
613+ Err ( e) => {
614+ append_job_log (
615+ logger,
616+ job_id,
617+ JobLogLevel :: Error ,
618+ format ! ( "Transaction error for vertex {:?}: {:?}" , current_vertex, e) ,
619+ ) ;
620+ continue ;
619621 }
620622 }
621623 }
@@ -660,6 +662,7 @@ impl HnswOnlinePartition {
660662 field_id,
661663 level_schema,
662664 metric,
665+ metadata,
663666 logger,
664667 job_id,
665668 )
@@ -673,6 +676,7 @@ impl HnswOnlinePartition {
673676 field_id,
674677 level_schema,
675678 metric,
679+ metadata,
676680 logger,
677681 job_id,
678682 )
@@ -686,6 +690,7 @@ impl HnswOnlinePartition {
686690 field_id,
687691 level_schema,
688692 metric,
693+ metadata,
689694 logger,
690695 job_id,
691696 )
@@ -699,6 +704,7 @@ impl HnswOnlinePartition {
699704 field_id,
700705 level_schema,
701706 metric,
707+ metadata,
702708 logger,
703709 job_id,
704710 )
@@ -742,7 +748,7 @@ impl HnswOnlinePartition {
742748 . get_index ( schema, field_id)
743749 . ok_or ( HNSWIndexError :: IndexNotFound ) ?;
744750
745- let mut random_level = generate_random_level ( PROB , MAX_LEVEL_CAP ) ;
751+ let random_level = generate_random_level ( PROB , MAX_LEVEL_CAP ) ;
746752
747753 append_job_log (
748754 logger,
@@ -794,6 +800,7 @@ impl HnswOnlinePartition {
794800 field_id,
795801 level_schema,
796802 metric,
803+ metadata,
797804 logger,
798805 job_id,
799806 )
0 commit comments