Commit 5352673

impl update rtree
Change-Id: I27e30cfa84def7193a0df6aceefb5bae8eec5b3a
1 parent 5afb6d0 commit 5352673

File tree

3 files changed, +163 -86 lines

rust/lance-index/src/scalar.rs

Lines changed: 15 additions & 10 deletions
```diff
@@ -209,18 +209,22 @@ pub trait IndexReader: Send + Sync {
 pub struct IndexReaderStream {
     reader: Arc<dyn IndexReader>,
     batch_size: u64,
-    num_batches: u32,
-    batch_idx: u32,
+    offset: u64,
+    limit: u64,
 }
 
 impl IndexReaderStream {
     async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
-        let num_batches = reader.num_batches(batch_size).await;
+        let limit = reader.num_rows() as u64;
+        Self::new_with_limit(reader, batch_size, limit).await
+    }
+
+    async fn new_with_limit(reader: Arc<dyn IndexReader>, batch_size: u64, limit: u64) -> Self {
         Self {
             reader,
             batch_size,
-            num_batches,
-            batch_idx: 0,
+            offset: 0,
+            limit,
         }
     }
 }
@@ -233,16 +237,17 @@ impl Stream for IndexReaderStream {
         _cx: &mut std::task::Context<'_>,
     ) -> std::task::Poll<Option<Self::Item>> {
         let this = self.get_mut();
-        if this.batch_idx >= this.num_batches {
+        if this.offset >= this.limit {
            return std::task::Poll::Ready(None);
        }
-        let batch_num = this.batch_idx;
-        this.batch_idx += 1;
+        let read_start = this.offset;
+        let read_end = this.limit.min(this.offset + this.batch_size);
+        this.offset = read_end;
         let reader_copy = this.reader.clone();
-        let batch_size = this.batch_size;
+
         let read_task = async move {
             reader_copy
-                .read_record_batch(batch_num as u64, batch_size)
+                .read_range(read_start as usize..read_end as usize, None)
                 .await
         }
         .boxed();
```
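
Note: the rewritten stream paginates by row range rather than by batch index, so the final read is clipped to the actual row count instead of assuming a full trailing batch. A minimal, dependency-free sketch of that clipping logic (`batch_ranges` is a hypothetical helper, not part of this commit):

```rust
/// Hypothetical helper mirroring the offset/limit loop above: each step
/// covers [offset, min(offset + batch_size, limit)), so the final range
/// is truncated to the actual number of rows.
fn batch_ranges(num_rows: u64, batch_size: u64) -> Vec<std::ops::Range<u64>> {
    let mut ranges = Vec::new();
    let mut offset = 0;
    while offset < num_rows {
        let end = num_rows.min(offset + batch_size);
        ranges.push(offset..end);
        offset = end;
    }
    ranges
}

fn main() {
    // 10 rows with a batch size of 4 yield 0..4, 4..8, and a short 8..10.
    assert_eq!(batch_ranges(10, 4), vec![0..4, 4..8, 8..10]);
}
```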

rust/lance-index/src/scalar/btree.rs

Lines changed: 4 additions & 51 deletions
```diff
@@ -11,8 +11,9 @@ use std::{
 };
 
 use super::{
-    flat::FlatIndexMetadata, AnyQuery, BuiltinIndexType, IndexReader, IndexStore, IndexWriter,
-    MetricsCollector, SargableQuery, ScalarIndex, ScalarIndexParams, SearchResult,
+    flat::FlatIndexMetadata, AnyQuery, BuiltinIndexType, IndexReader, IndexReaderStream,
+    IndexStore, IndexWriter, MetricsCollector, SargableQuery, ScalarIndex, ScalarIndexParams,
+    SearchResult,
 };
 use crate::pbold;
 use crate::{
@@ -36,9 +37,8 @@ use datafusion_common::{DataFusionError, ScalarValue};
 use datafusion_physical_expr::{expressions::Column, PhysicalSortExpr};
 use deepsize::DeepSizeOf;
 use futures::{
-    future::BoxFuture,
     stream::{self},
-    FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt,
+    FutureExt, StreamExt, TryFutureExt, TryStreamExt,
 };
 use lance_core::{
     cache::{CacheKey, LanceCache, WeakLanceCache},
@@ -1837,53 +1837,6 @@ pub(crate) fn part_lookup_file_path(partition_id: u64) -> String {
     format!("part_{}_{}", partition_id, BTREE_LOOKUP_NAME)
 }
 
-/// A stream that reads the original training data back out of the index
-///
-/// This is used for updating the index
-struct IndexReaderStream {
-    reader: Arc<dyn IndexReader>,
-    batch_size: u64,
-    num_batches: u32,
-    batch_idx: u32,
-}
-
-impl IndexReaderStream {
-    async fn new(reader: Arc<dyn IndexReader>, batch_size: u64) -> Self {
-        let num_batches = reader.num_batches(batch_size).await;
-        Self {
-            reader,
-            batch_size,
-            num_batches,
-            batch_idx: 0,
-        }
-    }
-}
-
-impl Stream for IndexReaderStream {
-    type Item = BoxFuture<'static, Result<RecordBatch>>;
-
-    fn poll_next(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        let this = self.get_mut();
-        if this.batch_idx >= this.num_batches {
-            return std::task::Poll::Ready(None);
-        }
-        let batch_num = this.batch_idx;
-        this.batch_idx += 1;
-        let reader_copy = this.reader.clone();
-        let batch_size = this.batch_size;
-        let read_task = async move {
-            reader_copy
-                .read_record_batch(batch_num as u64, batch_size)
-                .await
-        }
-        .boxed();
-        std::task::Poll::Ready(Some(read_task))
-    }
-}
-
 /// Parameters for a btree index
 #[derive(Debug, Serialize, Deserialize)]
 pub struct BTreeParameters {
```

rust/lance-index/src/scalar/rtree.rs

Lines changed: 144 additions & 25 deletions
```diff
@@ -324,6 +324,43 @@ impl RTreeIndex {
         };
         Ok(null_map)
     }
+
+    /// Create a stream of all the data in the index, in the same format used to train the index
+    async fn into_data_stream(self) -> Result<SendableRecordBatchStream> {
+        let reader = self.store.open_index_file(RTREE_PAGES_NAME).await?;
+        let reader_stream = IndexReaderStream::new_with_limit(
+            reader,
+            self.metadata.page_size as u64,
+            self.metadata.num_items as u64,
+        )
+        .await;
+        let batches = reader_stream
+            .map(|fut| fut.map_err(DataFusionError::from))
+            .buffered(self.store.io_parallelism())
+            .boxed();
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            RTREE_PAGE_SCHEMA.clone(),
+            batches,
+        )))
+    }
+
+    async fn combine_old_new(
+        self,
+        new_input: SendableRecordBatchStream,
+    ) -> Result<SendableRecordBatchStream> {
+        let old_input = self.into_data_stream().await?;
+        debug_assert_eq!(
+            old_input.schema().flattened_fields().len(),
+            new_input.schema().flattened_fields().len()
+        );
+
+        let merged = futures::stream::select(old_input, new_input);
+
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            RTREE_PAGE_SCHEMA.clone(),
+            merged,
+        )))
+    }
 }
 
 impl DeepSizeOf for RTreeIndex {
```
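
Note: `combine_old_new` relies on `futures::stream::select` to interleave the re-read old index data with the incoming updates. A small, self-contained sketch of that combinator's behavior (assumes only the `futures` crate; the string items are placeholders):

```rust
use futures::{executor::block_on, stream, StreamExt};

fn main() {
    // Two finite streams standing in for the old index data and the new input.
    let old = stream::iter(["old-0", "old-1"]);
    let new = stream::iter(["new-0", "new-1"]);

    // select polls both sides and yields items as they become ready; every
    // item from both streams is produced, but no relative order is promised.
    let merged: Vec<_> = block_on(stream::select(old, new).collect::<Vec<_>>());
    assert_eq!(merged.len(), 4);
    println!("{merged:?}");
}
```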
```diff
@@ -412,11 +449,7 @@ impl Index for RTreeIndex {
     async fn calculate_included_frags(&self) -> Result<RoaringBitmap> {
         let mut frag_ids = RoaringBitmap::default();
 
-        let sub_index_reader = self.store.open_index_file(RTREE_PAGES_NAME).await?;
-        let mut reader_stream =
-            IndexReaderStream::new(sub_index_reader, self.metadata.page_size as u64)
-                .await
-                .buffered(self.store.io_parallelism());
+        let mut reader_stream = self.clone().into_data_stream().await?;
         let mut read_rows = 0;
         while let Some(page) = reader_stream.try_next().await? {
             let mut page_frag_ids = page
@@ -501,13 +534,15 @@ impl ScalarIndex for RTreeIndex {
             tmpdir.obj_path(),
             Arc::new(LanceCache::no_cache()),
         ));
-        let (bbox_data, analyze) = RTreeIndexPlugin::process_and_analyze_bbox_stream(
+        let (new_bbox_data, analyze) = RTreeIndexPlugin::process_and_analyze_bbox_stream(
            bbox_data,
            self.metadata.page_size,
            spill_store.clone(),
        )
        .await?;
 
+        let merged_bbox_data = self.clone().combine_old_new(new_bbox_data).await?;
+
         let null_map = self.search_null(&NoOpMetricsCollector).await?;
 
         let mut new_bbox = BoundingBox::new();
@@ -521,7 +556,7 @@
         };
 
         RTreeIndexPlugin::train_rtree_index(
-            bbox_data,
+            merged_bbox_data,
             merge_analyze,
             self.metadata.page_size,
             dest_store,
@@ -946,6 +981,8 @@ async fn train_rtree_page(
 mod tests {
     use super::*;
     use crate::metrics::NoOpMetricsCollector;
+    use crate::scalar::registry::VALUE_COLUMN_NAME;
+    use arrow_array::ArrayRef;
     use arrow_schema::Schema;
     use geo_types::{coord, Rect};
     use geoarrow_array::builder::{PointBuilder, RectBuilder};
```
```diff
@@ -975,6 +1012,23 @@
         expected_page_offsets(num_items, page_size).len() as u64
     }
 
+    fn convert_bbox_rowid_batch_stream(
+        geo_array: &dyn GeoArrowArray,
+        row_id_array: ArrayRef,
+    ) -> SendableRecordBatchStream {
+        let schema = Arc::new(Schema::new(vec![
+            geo_array.data_type().to_field(VALUE_COLUMN_NAME, true),
+            ArrowField::new(ROW_ID, DataType::UInt64, false),
+        ]));
+
+        let batch =
+            RecordBatch::try_new(schema.clone(), vec![geo_array.to_array_ref(), row_id_array])
+                .unwrap();
+
+        let stream = stream::once(async move { Ok(batch) });
+        Box::pin(RecordBatchStreamAdapter::new(schema, stream))
+    }
+
     async fn train_index(
         geo_array: &dyn GeoArrowArray,
         page_size: Option<u32>,
```
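
Note: the `convert_bbox_rowid_batch_stream` helper wraps a single in-memory `RecordBatch` as a `SendableRecordBatchStream`. A standalone sketch of the same pattern, assuming the `datafusion` and `futures` crates (the schema and values here are placeholders, not the crate's geo types):

```rust
use std::sync::Arc;

use datafusion::arrow::array::Int64Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
use futures::stream;

// Wrap one in-memory batch so APIs expecting a record-batch stream can
// consume it; the schema and values are arbitrary placeholders.
fn one_batch_stream() -> SendableRecordBatchStream {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int64, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
    )
    .unwrap();
    // stream::once yields the batch exactly once and then ends.
    Box::pin(RecordBatchStreamAdapter::new(
        schema,
        stream::once(async move { Ok(batch) }),
    ))
}

fn main() {
    let stream = one_batch_stream();
    println!("schema: {:?}", stream.schema());
}
```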
```diff
@@ -994,24 +1048,12 @@
             Arc::new(LanceCache::no_cache()),
         ));
 
-        let schema = Arc::new(Schema::new(vec![
-            geo_array.data_type().to_field("value", true),
-            Field::new(ROW_ID, DataType::UInt64, false),
-        ]));
-
-        let row_ids = (0..geo_array.len() as u64).collect::<Vec<_>>();
-
-        let batch = RecordBatch::try_new(
-            schema.clone(),
-            vec![
-                geo_array.to_array_ref(),
-                Arc::new(UInt64Array::from(row_ids.clone())),
-            ],
-        )
-        .unwrap();
-
-        let stream = stream::once(async move { Ok(batch) });
-        let stream = Box::pin(RecordBatchStreamAdapter::new(schema, stream));
+        let stream = convert_bbox_rowid_batch_stream(
+            geo_array,
+            Arc::new(UInt64Array::from(
+                (0..geo_array.len() as u64).collect::<Vec<_>>(),
+            )),
+        );
 
         let plugin = RTreeIndexPlugin;
         plugin
```
```diff
@@ -1107,4 +1149,81 @@
             ]
         );
     }
+
+    #[tokio::test]
+    async fn test_update_and_search() {
+        let bbox_type = RectType::new(Dimension::XY, Default::default());
+
+        let page_size = 16;
+        let mut rect_builder = RectBuilder::new(bbox_type.clone());
+        let num_items = 10000;
+        for i in 0..num_items {
+            let i = i as f64;
+            rect_builder.push_rect(Some(&Rect::new(
+                coord! { x: i, y: i },
+                coord! { x: i + 1.0, y: i + 1.0 },
+            )));
+        }
+        let rect_arr = rect_builder.finish();
+        let (rtree_index, _tmpdir) = train_index(&rect_arr, Some(page_size)).await;
+
+        let tmpdir = TempObjDir::default();
+        let new_store = Arc::new(LanceIndexStore::new(
+            Arc::new(ObjectStore::local()),
+            tmpdir.clone(),
+            Arc::new(LanceCache::no_cache()),
+        ));
+
+        let mut rect_builder = RectBuilder::new(bbox_type.clone());
+        let num_items = 10000;
+        for i in 0..num_items {
+            let i = i as f64;
+            rect_builder.push_rect(Some(&Rect::new(
+                coord! { x: i + 0.5, y: i + 0.5 },
+                coord! { x: i + 1.5, y: i + 1.5 },
+            )));
+        }
+        let new_rect_arr = rect_builder.finish();
+        let new_rowid_arr = (rect_arr.len() as u64..(rect_arr.len() + new_rect_arr.len()) as u64)
+            .collect::<Vec<_>>();
+        let stream = convert_bbox_rowid_batch_stream(
+            &new_rect_arr,
+            Arc::new(UInt64Array::from(new_rowid_arr.clone())),
+        );
+        rtree_index
+            .update(stream, new_store.as_ref())
+            .await
+            .unwrap();
+
+        let new_rtree_index = RTreeIndex::load(new_store.clone(), None, &LanceCache::no_cache())
+            .await
+            .unwrap();
+
+        let mut search_bbox = BoundingBox::new();
+        search_bbox.add_rect(&Rect::new(
+            coord! { x: 10.5, y: 1.5 },
+            coord! { x: 99.5, y: 200.5 },
+        ));
+        let row_ids = new_rtree_index
+            .search_bbox(search_bbox, &NoOpMetricsCollector)
+            .await
+            .unwrap();
+
+        let mut expected_row_ids = RowIdTreeMap::new();
+        for i in 0..rect_arr.len() {
+            let bbox = BoundingBox::new_with_rect(&rect_arr.value(i).unwrap());
+            if search_bbox.rect_intersects(&bbox) {
+                expected_row_ids.insert(i as u64);
+            }
+        }
+        for i in 0..new_rect_arr.len() {
+            let bbox = BoundingBox::new_with_rect(&new_rect_arr.value(i).unwrap());
+            if search_bbox.rect_intersects(&bbox) {
+                expected_row_ids.insert(new_rowid_arr.get(i).copied().unwrap());
+            }
+        }
+
+        println!("row_ids: {:?}", row_ids);
+        assert_eq!(row_ids, expected_row_ids);
+    }
 }
```
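
Note: the expected-result loops in `test_update_and_search` hinge on an axis-aligned bounding-box intersection check (`rect_intersects`). A minimal, dependency-free sketch of that predicate (the `Bbox` type is a hypothetical stand-in, not the crate's `BoundingBox`):

```rust
#[derive(Clone, Copy, Debug)]
struct Bbox {
    min_x: f64,
    min_y: f64,
    max_x: f64,
    max_y: f64,
}

/// Two axis-aligned boxes intersect iff their projections overlap on
/// both the x and y axes (closed intervals, so touching edges count).
fn intersects(a: &Bbox, b: &Bbox) -> bool {
    a.min_x <= b.max_x && b.min_x <= a.max_x && a.min_y <= b.max_y && b.min_y <= a.max_y
}

fn main() {
    let query = Bbox { min_x: 10.5, min_y: 1.5, max_x: 99.5, max_y: 200.5 };
    let hit = Bbox { min_x: 11.0, min_y: 11.0, max_x: 12.0, max_y: 12.0 };
    let miss = Bbox { min_x: 9.0, min_y: 9.0, max_x: 10.0, max_y: 10.0 };
    assert!(intersects(&query, &hit));
    assert!(!intersects(&query, &miss)); // 10.0 < 10.5 on the x axis
}
```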
