@@ -22,10 +22,12 @@ import (
2222 "io"
2323 "os"
2424 "path/filepath"
25+ "strconv"
2526 "strings"
2627 "sync"
2728 "sync/atomic"
2829 "time"
30+ "unsafe"
2931
3032 "github.com/RoaringBitmap/roaring/v2"
3133 "github.com/blevesearch/bleve/v2/registry"
@@ -104,9 +106,10 @@ func (t ScorchErrorType) Error() string {
104106}
105107
106108type trainRequest struct {
107- sample segment.Segment
108- vecCount int
109- ackCh chan error
109+ sample segment.Segment
110+ vecCount int
111+ train_complete bool
112+ ackCh chan error
110113}
111114
112115// ErrType values for ScorchError
@@ -563,9 +566,27 @@ func (s *Scorch) getInternal(key []byte) ([]byte, error) {
563566 defer s .rootLock .RUnlock ()
564567 // todo: return the total number of vectors that have been processed so far in training
565568 // in cbft use that as a checkpoint to resume training for n-x samples.
566- if string (key ) == "_centroid_index_complete" {
567- return []byte (fmt .Sprintf ("%t" , s .centroidIndex != nil )), nil
569+ if s .centroidIndex != nil {
570+ switch string (key ) {
571+ case string (util .BoltTrainCompleteKey ):
572+ val := s .centroidIndex .cachedMeta .fetchMeta (string (util .BoltTrainCompleteKey ))
573+ if val != nil {
574+ return []byte (fmt .Sprintf ("%t" , val )), nil
575+ }
576+
577+ case string (util .BoltVecSamplesProcessedKey ):
578+ val := s .centroidIndex .cachedMeta .fetchMeta (string (util .BoltVecSamplesProcessedKey ))
579+ if val != nil {
580+ trainingSampleProcessed , ok := val .([]byte )
581+ if ok {
582+ return trainingSampleProcessed , nil
583+ }
584+ }
585+ }
586+ } else {
587+ fmt .Println ("centroid index is nil" )
568588 }
589+
569590 return nil , nil
570591}
571592
@@ -578,6 +599,10 @@ func moveFile(sourcePath, destPath string) error {
578599 return nil
579600}
580601
602+ func boolToByte (b bool ) byte {
603+ return * (* byte )(unsafe .Pointer (& b ))
604+ }
605+
581606// this is not a routine that will be running throughout the lifetime of the index. It's purpose
582607// is to only train the vector index before the data ingestion starts.
583608func (s * Scorch ) trainerLoop () {
@@ -614,9 +639,7 @@ func (s *Scorch) trainerLoop() {
614639 } else {
615640 // merge the new segment with the existing one, no need to persist?
616641 // persist in a tmp file and then rename - is that a fair strategy?
617- fmt .Println ("merging centroid index" )
618642 s .segmentConfig ["training" ] = true
619- fmt .Println ("version while merging" , s .segPlugin .Version ())
620643 _ , _ , err := s .segPlugin .MergeEx ([]segment.Segment {s .centroidIndex .segment , sampleSeg },
621644 []* roaring.Bitmap {nil , nil }, filepath .Join (s .path , "centroid_index.tmp" ), s .closeCh , nil , s .segmentConfig )
622645 if err != nil {
@@ -670,6 +693,13 @@ func (s *Scorch) trainerLoop() {
670693 return
671694 }
672695
696+ err = centroidBucket .Put (util .BoltTrainCompleteKey , []byte {boolToByte (trainReq .train_complete )})
697+ if err != nil {
698+ trainReq .ackCh <- fmt .Errorf ("error updating centroid index complete: %v" , err )
699+ close (trainReq .ackCh )
700+ return
701+ }
702+
673703 // total number of vectors that have been processed so far for the training
674704 n := binary .PutUvarint (buf , uint64 (totalSamplesProcessed ))
675705 err = centroidBucket .Put (util .BoltVecSamplesProcessedKey , buf [:n ])
@@ -686,8 +716,14 @@ func (s *Scorch) trainerLoop() {
686716 return
687717 }
688718
719+ err = s .rootBolt .Sync ()
720+ if err != nil {
721+ trainReq .ackCh <- fmt .Errorf ("error committing bolt transaction: %v" , err )
722+ close (trainReq .ackCh )
723+ return
724+ }
725+
689726 // update the centroid index pointer
690- fmt .Println ("version" , s .segPlugin .Version ())
691727 centroidIndex , err := s .segPlugin .OpenEx (filepath .Join (s .path , "centroid_index" ), s .segmentConfig )
692728 if err != nil {
693729 trainReq .ackCh <- fmt .Errorf ("error opening centroid index: %v" , err )
@@ -706,6 +742,8 @@ func (s *Scorch) Train(batch *index.Batch) error {
706742 // regulate the Train function
707743 s .FireIndexEvent ()
708744
745+ // batch.InternalOps
746+
709747 var trainData []index.Document
710748 for key , doc := range batch .IndexOps {
711749 if doc != nil {
@@ -732,24 +770,26 @@ func (s *Scorch) Train(batch *index.Batch) error {
732770 return err
733771 }
734772
773+ train_complete := false
774+ if batch .InternalOps [string (util .BoltTrainCompleteKey )] != nil {
775+ train_complete , err = strconv .ParseBool (string (batch .InternalOps [string (util .BoltTrainCompleteKey )]))
776+ if err != nil {
777+ return err
778+ }
779+ }
735780 trainReq := & trainRequest {
736- sample : seg ,
737- vecCount : len (trainData ), // todo: multivector support
738- ackCh : make (chan error ),
781+ sample : seg ,
782+ vecCount : len (trainData ), // todo: multivector support
783+ train_complete : train_complete ,
784+ ackCh : make (chan error ),
739785 }
740786
741787 s .train <- trainReq
742788 err = <- trainReq .ackCh
743789 if err != nil {
744- fmt .Println ("error training" , err )
745790 return err
746791 }
747- fmt .Println ("got centroid index" )
748792
749- _ , err = s .getCentroidIndex ("emb" )
750- if err != nil {
751- return err
752- }
753793 fmt .Println ("number of bytes written to centroid index" , n )
754794 return err
755795}
@@ -761,7 +801,6 @@ func (s *Scorch) getCentroidIndex(field string) (*faiss.IndexImpl, error) {
761801 return nil , fmt .Errorf ("segment is not a centroid index segment" , s .centroidIndex .segment != nil )
762802 }
763803
764- fmt .Println ("getting coarse quantizer" , field )
765804 coarseQuantizer , err := centroidIndexSegment .GetCoarseQuantizer (field )
766805 if err != nil {
767806 return nil , err
0 commit comments