Commit 2e2be3e

tracking vector sampling progress for better recoverability
1 parent 9612070 commit 2e2be3e

File tree

4 files changed, +71 -18 lines

4 files changed

+71
-18
lines changed
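
This commit persists two pieces of vector-training progress in the centroid bucket: a train-complete flag (util.BoltTrainCompleteKey, 't') and the number of training samples processed so far (util.BoltVecSamplesProcessedKey, 'v', written as a uvarint). getInternal then exposes both values so that a caller such as cbft can checkpoint and resume training after a restart. The sketch below shows one way a caller might decode those two values once it has fetched them; decodeTrainingCheckpoint is a hypothetical helper, not part of this change, but the encodings mirror the diffs that follow ("%t"-formatted text for the flag, binary.PutUvarint bytes for the counter).

package main

import (
	"encoding/binary"
	"fmt"
	"strconv"
)

// decodeTrainingCheckpoint interprets the raw values returned for the
// 't' and 'v' internal keys. Missing values fall back to "not complete"
// and zero samples processed.
func decodeTrainingCheckpoint(trainCompleteVal, samplesProcessedVal []byte) (bool, uint64, error) {
	trainComplete := false
	if len(trainCompleteVal) > 0 {
		var err error
		trainComplete, err = strconv.ParseBool(string(trainCompleteVal))
		if err != nil {
			return false, 0, err
		}
	}

	var samplesProcessed uint64
	if len(samplesProcessedVal) > 0 {
		v, n := binary.Uvarint(samplesProcessedVal)
		if n <= 0 {
			return false, 0, fmt.Errorf("malformed samples-processed value")
		}
		samplesProcessed = v
	}
	return trainComplete, samplesProcessed, nil
}

func main() {
	// Simulate the stored counter: trainerLoop writes it with binary.PutUvarint.
	buf := make([]byte, binary.MaxVarintLen64)
	n := binary.PutUvarint(buf, 1500)

	complete, processed, err := decodeTrainingCheckpoint([]byte("false"), buf[:n])
	fmt.Println(complete, processed, err) // false 1500 <nil>
}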

index/scorch/merge.go

Lines changed: 3 additions & 0 deletions
@@ -335,6 +335,7 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
     docsToDrop := make([]*roaring.Bitmap, 0, len(task.Segments))
     mergedSegHistory := make(map[uint64]*mergedSegmentHistory, len(task.Segments))

+    var files []string
     for _, planSegment := range task.Segments {
         if segSnapshot, ok := planSegment.(*SegmentSnapshot); ok {
             oldMap[segSnapshot.id] = segSnapshot
@@ -350,6 +351,7 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
             } else {
                 segmentsToMerge = append(segmentsToMerge, segSnapshot.segment)
                 docsToDrop = append(docsToDrop, segSnapshot.deleted)
+                files = append(files, persistedSeg.Path())
             }
             // track the files getting merged for unsetting the
             // removal ineligibility. This helps to unflip files
@@ -373,6 +375,7 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context,
     atomic.AddUint64(&s.stats.TotFileMergeZapBeg, 1)
     prevBytesReadTotal := cumulateBytesRead(segmentsToMerge)

+    fmt.Println("files", files)
     newDocNums, _, err := s.segPlugin.MergeEx(segmentsToMerge, docsToDrop, path,
         cw.cancelCh, s, s.segmentConfig)
     atomic.AddUint64(&s.stats.TotFileMergeZapEnd, 1)

index/scorch/persister.go

Lines changed: 10 additions & 0 deletions
@@ -866,6 +866,16 @@ func (s *Scorch) updateCentroidIndex(bucket *bolt.Bucket) error {
     defer s.rootLock.Unlock()
     fmt.Println("updateCentroidIndex", segmentSnapshot.segment != nil)
     s.centroidIndex = segmentSnapshot
+
+    trainBytes := bucket.Get(util.BoltTrainCompleteKey)
+    trainComplete, err := strconv.ParseBool(string(trainBytes))
+    if err != nil {
+        return err
+    }
+    s.centroidIndex.cachedMeta.updateMeta(string(util.BoltTrainCompleteKey), trainComplete)
+
+    vecSamplesProcessedBytes := bucket.Get(util.BoltVecSamplesProcessedKey)
+    s.centroidIndex.cachedMeta.updateMeta(string(util.BoltVecSamplesProcessedKey), vecSamplesProcessedBytes)
     return nil
 }
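
For context on where these bucket values come from: trainerLoop (in scorch.go below) persists the flag under 't' as a single 0/1 byte via boolToByte, and the running sample count under 'v' as a uvarint. The write-side sketch below mirrors those Put calls against a bolt bucket; writeTrainingCheckpoint and the hard-coded key bytes are illustrative stand-ins, not code from this commit.

package training

import (
	"encoding/binary"

	bolt "go.etcd.io/bbolt"
)

// Key bytes matching util.BoltTrainCompleteKey ('t') and
// util.BoltVecSamplesProcessedKey ('v'), hard-coded to keep the sketch
// self-contained.
var (
	trainCompleteKey       = []byte{'t'}
	vecSamplesProcessedKey = []byte{'v'}
)

// writeTrainingCheckpoint mirrors the Put calls in trainerLoop: the flag is
// stored as a single 0/1 byte and the counter as a uvarint.
func writeTrainingCheckpoint(bucket *bolt.Bucket, trainComplete bool, samplesProcessed uint64) error {
	flag := byte(0)
	if trainComplete {
		flag = 1
	}
	if err := bucket.Put(trainCompleteKey, []byte{flag}); err != nil {
		return err
	}

	buf := make([]byte, binary.MaxVarintLen64)
	n := binary.PutUvarint(buf, samplesProcessed)
	return bucket.Put(vecSamplesProcessedKey, buf[:n])
}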

index/scorch/scorch.go

Lines changed: 57 additions & 18 deletions
@@ -22,10 +22,12 @@ import (
     "io"
     "os"
     "path/filepath"
+    "strconv"
     "strings"
     "sync"
     "sync/atomic"
     "time"
+    "unsafe"

     "github.com/RoaringBitmap/roaring/v2"
     "github.com/blevesearch/bleve/v2/registry"
@@ -104,9 +106,10 @@ func (t ScorchErrorType) Error() string {
 }

 type trainRequest struct {
-    sample   segment.Segment
-    vecCount int
-    ackCh    chan error
+    sample         segment.Segment
+    vecCount       int
+    train_complete bool
+    ackCh          chan error
 }

 // ErrType values for ScorchError
@@ -563,9 +566,27 @@ func (s *Scorch) getInternal(key []byte) ([]byte, error) {
     defer s.rootLock.RUnlock()
     // todo: return the total number of vectors that have been processed so far in training
     // in cbft use that as a checkpoint to resume training for n-x samples.
-    if string(key) == "_centroid_index_complete" {
-        return []byte(fmt.Sprintf("%t", s.centroidIndex != nil)), nil
+    if s.centroidIndex != nil {
+        switch string(key) {
+        case string(util.BoltTrainCompleteKey):
+            val := s.centroidIndex.cachedMeta.fetchMeta(string(util.BoltTrainCompleteKey))
+            if val != nil {
+                return []byte(fmt.Sprintf("%t", val)), nil
+            }
+
+        case string(util.BoltVecSamplesProcessedKey):
+            val := s.centroidIndex.cachedMeta.fetchMeta(string(util.BoltVecSamplesProcessedKey))
+            if val != nil {
+                trainingSampleProcessed, ok := val.([]byte)
+                if ok {
+                    return trainingSampleProcessed, nil
+                }
+            }
+        }
+    } else {
+        fmt.Println("centroid index is nil")
     }
+
     return nil, nil
 }

@@ -578,6 +599,10 @@ func moveFile(sourcePath, destPath string) error {
     return nil
 }

+func boolToByte(b bool) byte {
+    return *(*byte)(unsafe.Pointer(&b))
+}
+
 // this is not a routine that will be running throughout the lifetime of the index. It's purpose
 // is to only train the vector index before the data ingestion starts.
 func (s *Scorch) trainerLoop() {
@@ -614,9 +639,7 @@ func (s *Scorch) trainerLoop() {
     } else {
         // merge the new segment with the existing one, no need to persist?
         // persist in a tmp file and then rename - is that a fair strategy?
-        fmt.Println("merging centroid index")
         s.segmentConfig["training"] = true
-        fmt.Println("version while merging", s.segPlugin.Version())
         _, _, err := s.segPlugin.MergeEx([]segment.Segment{s.centroidIndex.segment, sampleSeg},
             []*roaring.Bitmap{nil, nil}, filepath.Join(s.path, "centroid_index.tmp"), s.closeCh, nil, s.segmentConfig)
         if err != nil {
@@ -670,6 +693,13 @@ func (s *Scorch) trainerLoop() {
         return
     }

+    err = centroidBucket.Put(util.BoltTrainCompleteKey, []byte{boolToByte(trainReq.train_complete)})
+    if err != nil {
+        trainReq.ackCh <- fmt.Errorf("error updating centroid index complete: %v", err)
+        close(trainReq.ackCh)
+        return
+    }
+
     // total number of vectors that have been processed so far for the training
     n := binary.PutUvarint(buf, uint64(totalSamplesProcessed))
     err = centroidBucket.Put(util.BoltVecSamplesProcessedKey, buf[:n])
@@ -686,8 +716,14 @@ func (s *Scorch) trainerLoop() {
         return
     }

+    err = s.rootBolt.Sync()
+    if err != nil {
+        trainReq.ackCh <- fmt.Errorf("error committing bolt transaction: %v", err)
+        close(trainReq.ackCh)
+        return
+    }
+
     // update the centroid index pointer
-    fmt.Println("version", s.segPlugin.Version())
     centroidIndex, err := s.segPlugin.OpenEx(filepath.Join(s.path, "centroid_index"), s.segmentConfig)
     if err != nil {
         trainReq.ackCh <- fmt.Errorf("error opening centroid index: %v", err)
@@ -706,6 +742,8 @@ func (s *Scorch) Train(batch *index.Batch) error {
     // regulate the Train function
     s.FireIndexEvent()

+    // batch.InternalOps
+
     var trainData []index.Document
     for key, doc := range batch.IndexOps {
         if doc != nil {
@@ -732,24 +770,26 @@ func (s *Scorch) Train(batch *index.Batch) error {
         return err
     }

+    train_complete := false
+    if batch.InternalOps[string(util.BoltTrainCompleteKey)] != nil {
+        train_complete, err = strconv.ParseBool(string(batch.InternalOps[string(util.BoltTrainCompleteKey)]))
+        if err != nil {
+            return err
+        }
+    }
     trainReq := &trainRequest{
-        sample:   seg,
-        vecCount: len(trainData), // todo: multivector support
-        ackCh:    make(chan error),
+        sample:         seg,
+        vecCount:       len(trainData), // todo: multivector support
+        train_complete: train_complete,
+        ackCh:          make(chan error),
     }

     s.train <- trainReq
     err = <-trainReq.ackCh
     if err != nil {
-        fmt.Println("error training", err)
         return err
     }
-    fmt.Println("got centroid index")

-    _, err = s.getCentroidIndex("emb")
-    if err != nil {
-        return err
-    }
     fmt.Println("number of bytes written to centroid index", n)
     return err
 }
@@ -761,7 +801,6 @@ func (s *Scorch) getCentroidIndex(field string) (*faiss.IndexImpl, error) {
         return nil, fmt.Errorf("segment is not a centroid index segment", s.centroidIndex.segment != nil)
     }

-    fmt.Println("getting coarse quantizer", field)
     coarseQuantizer, err := centroidIndexSegment.GetCoarseQuantizer(field)
     if err != nil {
         return nil, err
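
Based on the Train() change above, a caller opts a batch into the new checkpointing by setting the train-complete internal op, which Train() parses with strconv.ParseBool and forwards to the trainer loop as train_complete. The sketch below shows one way that call site could look; vectorTrainer, sendTrainingBatch, and sampleDocs are illustrative names rather than anything in this commit, the batch helpers (NewBatch, Update, SetInternal) are the stock bleve_index_api ones, and the util import path is assumed to match this development branch.

package training

import (
	index "github.com/blevesearch/bleve_index_api"

	"github.com/blevesearch/bleve/v2/util"
)

// vectorTrainer is a stand-in for the *Scorch receiver of Train().
type vectorTrainer interface {
	Train(batch *index.Batch) error
}

func sendTrainingBatch(t vectorTrainer, sampleDocs []index.Document, lastBatch bool) error {
	batch := index.NewBatch()
	for _, doc := range sampleDocs {
		// Each indexed document contributes its vectors to the training sample.
		batch.Update(doc)
	}
	if lastBatch {
		// Marks training as complete so the trainer loop persists the flag
		// alongside the samples-processed counter, letting a restart skip
		// the already-finished training phase.
		batch.SetInternal(util.BoltTrainCompleteKey, []byte("true"))
	}
	return t.Train(batch)
}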

util/keys.go

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ var (
     BoltSnapshotsBucket        = []byte{'s'}
     BoltCentroidIndexKey       = []byte{'c'}
     BoltVecSamplesProcessedKey = []byte{'v'}
+    BoltTrainCompleteKey       = []byte{'t'}
     BoltPathKey                = []byte{'p'}
     BoltDeletedKey             = []byte{'d'}
     BoltInternalKey            = []byte{'i'}
