diff --git a/rag/collection.go b/rag/collection.go index 797a8ab..30771e9 100644 --- a/rag/collection.go +++ b/rag/collection.go @@ -26,7 +26,7 @@ func NewPersistentChromeCollection(llmClient *openai.Client, collectionName, dbP filepath.Join(dbPath, fmt.Sprintf("%s%s.json", collectionPrefix, collectionName)), filePath, chromemDB, - maxChunkSize) + maxChunkSize, llmClient, embeddingModel) if err != nil { xlog.Error("Failed to create PersistentKB", err) os.Exit(1) @@ -44,7 +44,7 @@ func NewPersistentLocalAICollection(llmClient *openai.Client, apiURL, apiKey, co filepath.Join(dbPath, fmt.Sprintf("%s%s.json", collectionPrefix, collectionName)), filePath, ragDB, - maxChunkSize) + maxChunkSize, llmClient, embeddingModel) if err != nil { xlog.Error("Failed to create PersistentKB", err) os.Exit(1) diff --git a/rag/engine.go b/rag/engine.go index ae34077..6887aa1 100644 --- a/rag/engine.go +++ b/rag/engine.go @@ -8,6 +8,7 @@ import ( type Engine interface { Store(s string, metadata map[string]string) (engine.Result, error) StoreDocuments(s []string, metadata map[string]string) ([]engine.Result, error) + GetEmbeddingDimensions() (int, error) Reset() error Search(s string, similarEntries int) ([]types.Result, error) Count() int diff --git a/rag/engine/chromem.go b/rag/engine/chromem.go index 015834e..a80311f 100644 --- a/rag/engine/chromem.go +++ b/rag/engine/chromem.go @@ -39,6 +39,11 @@ func NewChromemDBCollection(collection, path string, openaiClient *openai.Client } chromem.collection = c + count := c.Count() + if count > 0 { + chromem.index = count + 1 + } + return chromem, nil } @@ -59,6 +64,20 @@ func (c *ChromemDB) Reset() error { return nil } +func (c *ChromemDB) GetEmbeddingDimensions() (int, error) { + count := c.collection.Count() + if count == 0 { + return 0, fmt.Errorf("no documents in collection") + } + + doc, err := c.collection.GetByID(context.Background(), fmt.Sprint(count)) + if err != nil { + return 0, fmt.Errorf("error getting document: %v", err) + } + + return len(doc.Embedding), nil +} + func (c *ChromemDB) embedding() chromem.EmbeddingFunc { return chromem.EmbeddingFunc( func(ctx context.Context, text string) ([]float32, error) { diff --git a/rag/engine/localai.go b/rag/engine/localai.go index ae45d10..0fafe0f 100644 --- a/rag/engine/localai.go +++ b/rag/engine/localai.go @@ -31,6 +31,10 @@ func (db *LocalAIRAGDB) Count() int { return 0 } +func (db *LocalAIRAGDB) GetEmbeddingDimensions() (int, error) { + return 0, fmt.Errorf("not implemented") +} + func (db *LocalAIRAGDB) StoreDocuments(s []string, metadata map[string]string) ([]Result, error) { results := []Result{} for _, content := range s { diff --git a/rag/persistency.go b/rag/persistency.go index 9462fe8..5ea2a20 100644 --- a/rag/persistency.go +++ b/rag/persistency.go @@ -2,6 +2,7 @@ package rag import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -15,6 +16,7 @@ import ( "github.com/mudler/localrecall/pkg/xlog" "github.com/mudler/localrecall/rag/engine" "github.com/mudler/localrecall/rag/types" + "github.com/sashabaranov/go-openai" ) // CollectionState represents the persistent state of a collection @@ -55,7 +57,7 @@ func loadDB(path string) (*CollectionState, error) { return state, nil } -func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChunkSize int) (*PersistentKB, error) { +func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChunkSize int, llmClient *openai.Client, embeddingModel string) (*PersistentKB, error) { // if file exists, try to load an existing state // if file does not exist, create a new state if err := os.MkdirAll(assetDir, 0755); err != nil { @@ -89,6 +91,24 @@ func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChun index: state.Index, } + // TODO: Automatically repopulate if embeddings dimensions are mismatching. + // To check if dimensions are mismatching, we can check the number of dimensions of the first embedding in the index if is the same as the + // dimension that the embedding model returns. + resp, err := llmClient.CreateEmbeddings(context.Background(), + openai.EmbeddingRequestStrings{ + Input: []string{"test"}, + Model: openai.EmbeddingModel(embeddingModel), + }, + ) + if err == nil && len(resp.Data) > 0 { + embedding := resp.Data[0].Embedding + embeddingDimensions, err := db.Engine.GetEmbeddingDimensions() + if err == nil && len(embedding) != embeddingDimensions { + xlog.Info("Embedding dimensions mismatch, repopulating", "embeddingDimensions", embeddingDimensions, "embedding", embedding) + return db, db.Repopulate() + } + } + return db, nil } diff --git a/test/e2e/persistency_test.go b/test/e2e/persistency_test.go index 82d8773..1395173 100644 --- a/test/e2e/persistency_test.go +++ b/test/e2e/persistency_test.go @@ -40,7 +40,7 @@ var _ = Describe("Persistency", func() { Expect(err).To(BeNil()) // Create new PersistentKB - kb, err = rag.NewPersistentCollectionKB(stateFile, assetDir, chromemEngine, DefaultChunkSize) + kb, err = rag.NewPersistentCollectionKB(stateFile, assetDir, chromemEngine, DefaultChunkSize, localAI, EmbeddingModel) Expect(err).To(BeNil()) }) diff --git a/test/e2e/source_manager_test.go b/test/e2e/source_manager_test.go index 57bbf31..5db3dd7 100644 --- a/test/e2e/source_manager_test.go +++ b/test/e2e/source_manager_test.go @@ -43,7 +43,7 @@ var _ = Describe("SourceManager", func() { Expect(err).To(BeNil()) // Create new PersistentKB - kb, err = rag.NewPersistentCollectionKB(stateFile, assetDir, chromemEngine, DefaultChunkSize) + kb, err = rag.NewPersistentCollectionKB(stateFile, assetDir, chromemEngine, DefaultChunkSize, localAI, EmbeddingModel) Expect(err).To(BeNil()) // Create source manager