Skip to content

Commit f20e0a9

Browse files
authored
perf: compute HNSW level counts after build (#5590)
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
1 parent 781110a commit f20e0a9

File tree

1 file changed: 47 additions and 23 deletions.

rust/lance-index/src/vector/hnsw/builder.rs

Lines changed: 47 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@ use std::cmp::min;
1919
use std::collections::{BinaryHeap, HashMap, VecDeque};
2020
use std::fmt::Debug;
2121
use std::iter;
22-
use std::sync::atomic::{AtomicUsize, Ordering};
2322
use std::sync::Arc;
2423
use std::sync::RwLock;
2524
use tracing::instrument;
@@ -307,10 +306,10 @@ impl HNSW {
307306
.inner
308307
.level_count
309308
.iter()
310-
.chain(iter::once(&AtomicUsize::new(0)))
311-
.scan(0, |state, x| {
309+
.chain(iter::once(&0usize))
310+
.scan(0usize, |state, &count| {
312311
let start = *state;
313-
*state += x.load(Ordering::Relaxed);
312+
*state += count;
314313
Some(start)
315314
})
316315
.collect();
@@ -327,7 +326,7 @@ struct HnswBuilder {
327326
params: HnswBuildParams,
328327

329328
nodes: Arc<Vec<RwLock<GraphBuilderNode>>>,
330-
level_count: Vec<AtomicUsize>,
329+
level_count: Vec<usize>,
331330

332331
entry_point: u32,
333332

@@ -349,7 +348,7 @@ impl HnswBuilder {
349348
}
350349

351350
fn num_nodes(&self, level: usize) -> usize {
352-
self.level_count[level].load(Ordering::Relaxed)
351+
self.level_count[level]
353352
}
354353

355354
fn nodes(&self) -> Arc<Vec<RwLock<GraphBuilderNode>>> {
@@ -361,9 +360,7 @@ impl HnswBuilder {
361360
let len = storage.len();
362361
let max_level = params.max_level;
363362

364-
let level_count = (0..max_level)
365-
.map(|_| AtomicUsize::new(0))
366-
.collect::<Vec<_>>();
363+
let level_count = vec![0usize; max_level as usize];
367364

368365
let visited_generator_queue = Arc::new(ArrayQueue::new(get_num_compute_intensive_cpus()));
369366
for _ in 0..get_num_compute_intensive_cpus() {
@@ -445,8 +442,6 @@ impl HnswBuilder {
445442
{
446443
let mut current_node = nodes[node as usize].write().unwrap();
447444
for level in (0..=target_level).rev() {
448-
self.level_count[level as usize].fetch_add(1, Ordering::Relaxed);
449-
450445
let neighbors = self.search_level(&ep, level, &dist_calc, nodes, visited_generator);
451446
for neighbor in &neighbors {
452447
current_node.add_neighbor(neighbor.id, neighbor.dist, level);
@@ -525,6 +520,17 @@ impl HnswBuilder {
525520
*neighbors_ranked = select_neighbors_heuristic(storage, &level_neighbors, m_max);
526521
builder_node.update_from_ranked_neighbors(level);
527522
}
523+
524+
fn compute_level_count(&self) -> Vec<usize> {
525+
let mut level_count = vec![0usize; self.max_level() as usize];
526+
for node in self.nodes.iter() {
527+
let levels = node.read().unwrap().level_neighbors.len();
528+
for count in level_count.iter_mut().take(levels) {
529+
*count += 1;
530+
}
531+
}
532+
level_count
533+
}
528534
}
529535

530536
// View of a level in HNSW graph.
@@ -666,7 +672,7 @@ impl IvfSubIndex for HNSW {
666672
let inner = HnswBuilder {
667673
params: hnsw_metadata.params,
668674
nodes: Arc::new(nodes.into_iter().map(RwLock::new).collect()),
669-
level_count: level_count.into_iter().map(AtomicUsize::new).collect(),
675+
level_count,
670676
entry_point: hnsw_metadata.entry_point,
671677
visited_generator_queue,
672678
};
@@ -763,34 +769,37 @@ impl IvfSubIndex for HNSW {
763769
where
764770
Self: Sized,
765771
{
766-
let inner = HnswBuilder::with_params(params, storage);
767-
let hnsw = Self {
768-
inner: Arc::new(inner),
769-
};
772+
let mut inner = HnswBuilder::with_params(params, storage);
770773

771774
log::debug!(
772775
"Building HNSW graph: num={}, max_levels={}, m={}, ef_construction={}, distance_type:{}",
773776
storage.len(),
774-
hnsw.inner.params.max_level,
775-
hnsw.inner.params.m,
776-
hnsw.inner.params.ef_construction,
777+
inner.params.max_level,
778+
inner.params.m,
779+
inner.params.ef_construction,
777780
storage.distance_type(),
778781
);
779782

780783
if storage.is_empty() {
781-
return Ok(hnsw);
784+
return Ok(Self {
785+
inner: Arc::new(inner),
786+
});
782787
}
783788

784789
let len = storage.len();
785-
hnsw.inner.level_count[0].fetch_add(1, Ordering::Relaxed);
786790
(1..len).into_par_iter().for_each_init(
787791
|| VisitedGenerator::new(len),
788792
|visited_generator, node| {
789-
hnsw.inner.insert(node as u32, visited_generator, storage);
793+
inner.insert(node as u32, visited_generator, storage);
790794
},
791795
);
796+
inner.level_count = inner.compute_level_count();
792797

793-
assert_eq!(hnsw.inner.level_count[0].load(Ordering::Relaxed), len);
798+
let hnsw = Self {
799+
inner: Arc::new(inner),
800+
};
801+
802+
assert_eq!(hnsw.inner.level_count[0], len);
794803
Ok(hnsw)
795804
}
796805

@@ -945,4 +954,19 @@ mod tests {
945954
.unwrap();
946955
assert_eq!(builder_results, loaded_results);
947956
}
957+
958+
#[test]
959+
fn test_level_offsets_match_batch_rows() {
960+
const DIM: usize = 16;
961+
const TOTAL: usize = 512;
962+
let data = generate_random_array(TOTAL * DIM);
963+
let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap();
964+
let store = FlatFloatStorage::new(fsl, DistanceType::L2);
965+
let hnsw = HNSW::index_vectors(&store, HnswBuildParams::default()).unwrap();
966+
let metadata = hnsw.metadata();
967+
let batch = hnsw.to_batch().unwrap();
968+
969+
assert_eq!(metadata.level_offsets.len(), hnsw.max_level() as usize + 1);
970+
assert_eq!(*metadata.level_offsets.last().unwrap(), batch.num_rows());
971+
}
948972
}

Comments (0): this commit has no comments.