@@ -19,7 +19,6 @@ use std::cmp::min;
1919use std:: collections:: { BinaryHeap , HashMap , VecDeque } ;
2020use std:: fmt:: Debug ;
2121use std:: iter;
22- use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
2322use std:: sync:: Arc ;
2423use std:: sync:: RwLock ;
2524use tracing:: instrument;
@@ -307,10 +306,10 @@ impl HNSW {
307306 . inner
308307 . level_count
309308 . iter ( )
310- . chain ( iter:: once ( & AtomicUsize :: new ( 0 ) ) )
311- . scan ( 0 , |state, x | {
309+ . chain ( iter:: once ( & 0usize ) )
310+ . scan ( 0usize , |state, & count | {
312311 let start = * state;
313- * state += x . load ( Ordering :: Relaxed ) ;
312+ * state += count ;
314313 Some ( start)
315314 } )
316315 . collect ( ) ;
@@ -327,7 +326,7 @@ struct HnswBuilder {
327326 params : HnswBuildParams ,
328327
329328 nodes : Arc < Vec < RwLock < GraphBuilderNode > > > ,
330- level_count : Vec < AtomicUsize > ,
329+ level_count : Vec < usize > ,
331330
332331 entry_point : u32 ,
333332
@@ -349,7 +348,7 @@ impl HnswBuilder {
349348 }
350349
351350 fn num_nodes ( & self , level : usize ) -> usize {
352- self . level_count [ level] . load ( Ordering :: Relaxed )
351+ self . level_count [ level]
353352 }
354353
355354 fn nodes ( & self ) -> Arc < Vec < RwLock < GraphBuilderNode > > > {
@@ -361,9 +360,7 @@ impl HnswBuilder {
361360 let len = storage. len ( ) ;
362361 let max_level = params. max_level ;
363362
364- let level_count = ( 0 ..max_level)
365- . map ( |_| AtomicUsize :: new ( 0 ) )
366- . collect :: < Vec < _ > > ( ) ;
363+ let level_count = vec ! [ 0usize ; max_level as usize ] ;
367364
368365 let visited_generator_queue = Arc :: new ( ArrayQueue :: new ( get_num_compute_intensive_cpus ( ) ) ) ;
369366 for _ in 0 ..get_num_compute_intensive_cpus ( ) {
@@ -445,8 +442,6 @@ impl HnswBuilder {
445442 {
446443 let mut current_node = nodes[ node as usize ] . write ( ) . unwrap ( ) ;
447444 for level in ( 0 ..=target_level) . rev ( ) {
448- self . level_count [ level as usize ] . fetch_add ( 1 , Ordering :: Relaxed ) ;
449-
450445 let neighbors = self . search_level ( & ep, level, & dist_calc, nodes, visited_generator) ;
451446 for neighbor in & neighbors {
452447 current_node. add_neighbor ( neighbor. id , neighbor. dist , level) ;
@@ -525,6 +520,17 @@ impl HnswBuilder {
525520 * neighbors_ranked = select_neighbors_heuristic ( storage, & level_neighbors, m_max) ;
526521 builder_node. update_from_ranked_neighbors ( level) ;
527522 }
523+
524+ fn compute_level_count ( & self ) -> Vec < usize > {
525+ let mut level_count = vec ! [ 0usize ; self . max_level( ) as usize ] ;
526+ for node in self . nodes . iter ( ) {
527+ let levels = node. read ( ) . unwrap ( ) . level_neighbors . len ( ) ;
528+ for count in level_count. iter_mut ( ) . take ( levels) {
529+ * count += 1 ;
530+ }
531+ }
532+ level_count
533+ }
528534}
529535
530536// View of a level in HNSW graph.
@@ -666,7 +672,7 @@ impl IvfSubIndex for HNSW {
666672 let inner = HnswBuilder {
667673 params : hnsw_metadata. params ,
668674 nodes : Arc :: new ( nodes. into_iter ( ) . map ( RwLock :: new) . collect ( ) ) ,
669- level_count : level_count . into_iter ( ) . map ( AtomicUsize :: new ) . collect ( ) ,
675+ level_count,
670676 entry_point : hnsw_metadata. entry_point ,
671677 visited_generator_queue,
672678 } ;
@@ -763,34 +769,37 @@ impl IvfSubIndex for HNSW {
763769 where
764770 Self : Sized ,
765771 {
766- let inner = HnswBuilder :: with_params ( params, storage) ;
767- let hnsw = Self {
768- inner : Arc :: new ( inner) ,
769- } ;
772+ let mut inner = HnswBuilder :: with_params ( params, storage) ;
770773
771774 log:: debug!(
772775 "Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}" ,
773776 storage. len( ) ,
774- hnsw . inner. params. max_level,
775- hnsw . inner. params. m,
776- hnsw . inner. params. ef_construction,
777+ inner. params. max_level,
778+ inner. params. m,
779+ inner. params. ef_construction,
777780 storage. distance_type( ) ,
778781 ) ;
779782
780783 if storage. is_empty ( ) {
781- return Ok ( hnsw) ;
784+ return Ok ( Self {
785+ inner : Arc :: new ( inner) ,
786+ } ) ;
782787 }
783788
784789 let len = storage. len ( ) ;
785- hnsw. inner . level_count [ 0 ] . fetch_add ( 1 , Ordering :: Relaxed ) ;
786790 ( 1 ..len) . into_par_iter ( ) . for_each_init (
787791 || VisitedGenerator :: new ( len) ,
788792 |visited_generator, node| {
789- hnsw . inner . insert ( node as u32 , visited_generator, storage) ;
793+ inner. insert ( node as u32 , visited_generator, storage) ;
790794 } ,
791795 ) ;
796+ inner. level_count = inner. compute_level_count ( ) ;
792797
793- assert_eq ! ( hnsw. inner. level_count[ 0 ] . load( Ordering :: Relaxed ) , len) ;
798+ let hnsw = Self {
799+ inner : Arc :: new ( inner) ,
800+ } ;
801+
802+ assert_eq ! ( hnsw. inner. level_count[ 0 ] , len) ;
794803 Ok ( hnsw)
795804 }
796805
@@ -945,4 +954,19 @@ mod tests {
945954 . unwrap ( ) ;
946955 assert_eq ! ( builder_results, loaded_results) ;
947956 }
957+
958+ #[ test]
959+ fn test_level_offsets_match_batch_rows ( ) {
960+ const DIM : usize = 16 ;
961+ const TOTAL : usize = 512 ;
962+ let data = generate_random_array ( TOTAL * DIM ) ;
963+ let fsl = FixedSizeListArray :: try_new_from_values ( data, DIM as i32 ) . unwrap ( ) ;
964+ let store = FlatFloatStorage :: new ( fsl, DistanceType :: L2 ) ;
965+ let hnsw = HNSW :: index_vectors ( & store, HnswBuildParams :: default ( ) ) . unwrap ( ) ;
966+ let metadata = hnsw. metadata ( ) ;
967+ let batch = hnsw. to_batch ( ) . unwrap ( ) ;
968+
969+ assert_eq ! ( metadata. level_offsets. len( ) , hnsw. max_level( ) as usize + 1 ) ;
970+ assert_eq ! ( * metadata. level_offsets. last( ) . unwrap( ) , batch. num_rows( ) ) ;
971+ }
948972}
0 commit comments