Skip to content

Commit 34e5a19

Browse files
authored
perf: materialize the tokens after WAND done (#5572)
this avoids copying tokens for each candidate during WAND search, these small strings could lead high memory usage if we have many partitions and the queries are with large limit. verified on a 1M dataset before: Peak RSS delta: 497.09 MiB after: Peak RSS delta: 222.34 MiB this also improves the FTS performance by ~10% for such queries --------- Signed-off-by: BubbleCal <[email protected]>
1 parent 341a599 commit 34e5a19

File tree

2 files changed

+60
-15
lines changed

2 files changed

+60
-15
lines changed

rust/lance-index/src/scalar/inverted/index.rs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,21 @@ pub static FTS_SCHEMA: LazyLock<SchemaRef> =
111111
static ROW_ID_SCHEMA: LazyLock<SchemaRef> =
112112
LazyLock::new(|| Arc::new(Schema::new(vec![ROW_ID_FIELD.clone()])));
113113

114+
#[derive(Debug)]
115+
struct PartitionCandidates {
116+
tokens_by_position: Vec<String>,
117+
candidates: Vec<DocCandidate>,
118+
}
119+
120+
impl PartitionCandidates {
121+
fn empty() -> Self {
122+
Self {
123+
tokens_by_position: Vec::new(),
124+
candidates: Vec::new(),
125+
}
126+
}
127+
}
128+
114129
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Default)]
115130
pub enum TokenSetFormat {
116131
Arrow,
@@ -256,19 +271,28 @@ impl InvertedIndex {
256271
.load_posting_lists(tokens.as_ref(), params.as_ref(), metrics.as_ref())
257272
.await?;
258273
if postings.is_empty() {
259-
return Ok(Vec::new());
274+
return Ok(PartitionCandidates::empty());
275+
}
276+
let mut tokens_by_position = vec![String::new(); postings.len()];
277+
for posting in &postings {
278+
let idx = posting.term_index() as usize;
279+
tokens_by_position[idx] = posting.token().to_owned();
260280
}
261281
let params = params.clone();
262282
let mask = mask.clone();
263283
let metrics = metrics.clone();
264284
spawn_cpu(move || {
265-
part.bm25_search(
285+
let candidates = part.bm25_search(
266286
params.as_ref(),
267287
operator,
268288
mask,
269289
postings,
270290
metrics.as_ref(),
271-
)
291+
)?;
292+
Ok(PartitionCandidates {
293+
tokens_by_position,
294+
candidates,
295+
})
272296
})
273297
.await
274298
}
@@ -277,14 +301,20 @@ impl InvertedIndex {
277301
let mut parts = stream::iter(parts).buffer_unordered(get_num_compute_intensive_cpus());
278302
let scorer = IndexBM25Scorer::new(self.partitions.iter().map(|part| part.as_ref()));
279303
while let Some(res) = parts.try_next().await? {
304+
if res.candidates.is_empty() {
305+
continue;
306+
}
307+
let tokens_by_position = &res.tokens_by_position;
280308
for DocCandidate {
281309
row_id,
282310
freqs,
283311
doc_length,
284-
} in res
312+
} in res.candidates
285313
{
286314
let mut score = 0.0;
287-
for (token, freq) in freqs.into_iter() {
315+
for (term_index, freq) in freqs.into_iter() {
316+
debug_assert!((term_index as usize) < tokens_by_position.len());
317+
let token = &tokens_by_position[term_index as usize];
288318
score += scorer.score(token.as_str(), freq, doc_length);
289319
}
290320
if candidates.len() < limit {

rust/lance-index/src/scalar/inverted/wand.rs

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,16 @@ impl PostingIterator {
166166
}
167167
}
168168

169+
#[inline]
170+
pub(crate) fn term_index(&self) -> u32 {
171+
self.position
172+
}
173+
174+
#[inline]
175+
pub(crate) fn token(&self) -> &str {
176+
&self.token
177+
}
178+
169179
#[inline]
170180
fn approximate_upper_bound(&self) -> f32 {
171181
self.approximate_upper_bound
@@ -293,9 +303,11 @@ impl PostingIterator {
293303
}
294304
}
295305

306+
#[derive(Debug)]
296307
pub struct DocCandidate {
297308
pub row_id: u64,
298-
pub freqs: Vec<(String, u32)>,
309+
/// (term_index, freq)
310+
pub freqs: Vec<(u32, u32)>,
299311
pub doc_length: u32,
300312
}
301313

@@ -359,7 +371,7 @@ impl<'a, S: Scorer> Wand<'a, S> {
359371
_ => {}
360372
}
361373

362-
let mut candidates = BinaryHeap::new();
374+
let mut candidates = BinaryHeap::with_capacity(std::cmp::min(limit, BLOCK_SIZE * 10));
363375
let mut num_comparisons = 0;
364376
while let Some((pivot, doc)) = self.next()? {
365377
if let Some(cur_doc) = self.cur_doc {
@@ -394,10 +406,7 @@ impl<'a, S: Scorer> Wand<'a, S> {
394406
DocInfo::Located(doc) => self.docs.num_tokens_by_row_id(doc.row_id),
395407
};
396408
let score = self.score(pivot, doc_length);
397-
let freqs = self
398-
.iter_token_freqs(pivot)
399-
.map(|(token, freq)| (token.to_owned(), freq))
400-
.collect();
409+
let freqs = self.iter_term_freqs(pivot).collect();
401410
if candidates.len() < limit {
402411
candidates.push(Reverse((ScoredDoc::new(row_id, score), freqs, doc_length)));
403412
if candidates.len() == limit {
@@ -522,10 +531,7 @@ impl<'a, S: Scorer> Wand<'a, S> {
522531
};
523532

524533
let score = self.score(max_pivot, doc_length);
525-
let freqs = self
526-
.iter_token_freqs(max_pivot)
527-
.map(|(token, freq)| (token.to_owned(), freq))
528-
.collect();
534+
let freqs = self.iter_term_freqs(max_pivot).collect();
529535

530536
if candidates.len() < limit {
531537
candidates.push(Reverse((ScoredDoc::new(row_id, score), freqs, doc_length)));
@@ -568,6 +574,15 @@ impl<'a, S: Scorer> Wand<'a, S> {
568574
})
569575
}
570576

577+
// iterate over all the preceding terms and collect the term index and frequency
578+
fn iter_term_freqs(&self, pivot: usize) -> impl Iterator<Item = (u32, u32)> + '_ {
579+
self.postings[..=pivot].iter().filter_map(|posting| {
580+
posting
581+
.doc()
582+
.map(|doc| (posting.term_index(), doc.frequency()))
583+
})
584+
}
585+
571586
// find the next doc candidate
572587
fn next(&mut self) -> Result<Option<(usize, DocInfo)>> {
573588
while let Some((pivot, max_pivot)) = self.find_pivot_term() {

0 commit comments

Comments
 (0)