Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rag/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func NewPersistentChromeCollection(llmClient *openai.Client, collectionName, dbP
filepath.Join(dbPath, fmt.Sprintf("%s%s.json", collectionPrefix, collectionName)),
filePath,
chromemDB,
maxChunkSize)
maxChunkSize, llmClient, embeddingModel)
if err != nil {
xlog.Error("Failed to create PersistentKB", err)
os.Exit(1)
Expand All @@ -44,7 +44,7 @@ func NewPersistentLocalAICollection(llmClient *openai.Client, apiURL, apiKey, co
filepath.Join(dbPath, fmt.Sprintf("%s%s.json", collectionPrefix, collectionName)),
filePath,
ragDB,
maxChunkSize)
maxChunkSize, llmClient, embeddingModel)
if err != nil {
xlog.Error("Failed to create PersistentKB", err)
os.Exit(1)
Expand Down
1 change: 1 addition & 0 deletions rag/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
type Engine interface {
Store(s string, metadata map[string]string) (engine.Result, error)
StoreDocuments(s []string, metadata map[string]string) ([]engine.Result, error)
GetEmbeddingDimensions() (int, error)
Reset() error
Search(s string, similarEntries int) ([]types.Result, error)
Count() int
Expand Down
19 changes: 19 additions & 0 deletions rag/engine/chromem.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ func NewChromemDBCollection(collection, path string, openaiClient *openai.Client
}
chromem.collection = c

count := c.Count()
if count > 0 {
chromem.index = count + 1
}

return chromem, nil
}

Expand All @@ -59,6 +64,20 @@ func (c *ChromemDB) Reset() error {
return nil
}

// GetEmbeddingDimensions reports the length of the embedding vector of a
// document already stored in the collection.
//
// It returns an error when the collection holds no documents, or when the
// document whose ID equals the current document count cannot be fetched. The
// fetch error is wrapped, so callers can inspect it with errors.Is/errors.As.
func (c *ChromemDB) GetEmbeddingDimensions() (int, error) {
	count := c.collection.Count()
	if count == 0 {
		return 0, fmt.Errorf("no documents in collection")
	}

	// The probe document is looked up by the decimal form of the current count.
	// NOTE(review): this presumably relies on IDs being sequential integers
	// with no gaps, so that ID == count always exists — confirm against the
	// code that assigns IDs when storing documents.
	doc, err := c.collection.GetByID(context.Background(), fmt.Sprint(count))
	if err != nil {
		// %w (rather than %v) keeps the underlying error in the chain.
		return 0, fmt.Errorf("error getting document: %w", err)
	}

	return len(doc.Embedding), nil
}

func (c *ChromemDB) embedding() chromem.EmbeddingFunc {
return chromem.EmbeddingFunc(
func(ctx context.Context, text string) ([]float32, error) {
Expand Down
4 changes: 4 additions & 0 deletions rag/engine/localai.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ func (db *LocalAIRAGDB) Count() int {
return 0
}

// GetEmbeddingDimensions is a stub for this backend: it always returns zero
// dimensions together with a "not implemented" error.
func (db *LocalAIRAGDB) GetEmbeddingDimensions() (int, error) {
	errNotImplemented := fmt.Errorf("not implemented")
	return 0, errNotImplemented
}

func (db *LocalAIRAGDB) StoreDocuments(s []string, metadata map[string]string) ([]Result, error) {
results := []Result{}
for _, content := range s {
Expand Down
22 changes: 21 additions & 1 deletion rag/persistency.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rag

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -15,6 +16,7 @@ import (
"github.com/mudler/localrecall/pkg/xlog"
"github.com/mudler/localrecall/rag/engine"
"github.com/mudler/localrecall/rag/types"
"github.com/sashabaranov/go-openai"
)

// CollectionState represents the persistent state of a collection
Expand Down Expand Up @@ -55,7 +57,7 @@ func loadDB(path string) (*CollectionState, error) {
return state, nil
}

func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChunkSize int) (*PersistentKB, error) {
func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChunkSize int, llmClient *openai.Client, embeddingModel string) (*PersistentKB, error) {
// if file exists, try to load an existing state
// if file does not exist, create a new state
if err := os.MkdirAll(assetDir, 0755); err != nil {
Expand Down Expand Up @@ -89,6 +91,24 @@ func NewPersistentCollectionKB(stateFile, assetDir string, store Engine, maxChun
index: state.Index,
}

// TODO: Automatically repopulate if the embedding dimensions are mismatching.
// To check whether the dimensions are mismatching, we compare the number of dimensions of an embedding stored in the index with the
// number of dimensions that the embedding model returns.
resp, err := llmClient.CreateEmbeddings(context.Background(),
openai.EmbeddingRequestStrings{
Input: []string{"test"},
Model: openai.EmbeddingModel(embeddingModel),
},
)
if err == nil && len(resp.Data) > 0 {
embedding := resp.Data[0].Embedding
embeddingDimensions, err := db.Engine.GetEmbeddingDimensions()
if err == nil && len(embedding) != embeddingDimensions {
xlog.Info("Embedding dimensions mismatch, repopulating", "embeddingDimensions", embeddingDimensions, "embedding", embedding)
return db, db.Repopulate()
}
}

return db, nil
}

Expand Down
2 changes: 1 addition & 1 deletion test/e2e/persistency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ var _ = Describe("Persistency", func() {
Expect(err).To(BeNil())

// Create new PersistentKB
kb, err = rag.NewPersistentCollectionKB(stateFile, assetDir, chromemEngine, DefaultChunkSize)
kb, err = rag.NewPersistentCollectionKB(stateFile, assetDir, chromemEngine, DefaultChunkSize, localAI, EmbeddingModel)
Expect(err).To(BeNil())
})

Expand Down
2 changes: 1 addition & 1 deletion test/e2e/source_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var _ = Describe("SourceManager", func() {
Expect(err).To(BeNil())

// Create new PersistentKB
kb, err = rag.NewPersistentCollectionKB(stateFile, assetDir, chromemEngine, DefaultChunkSize)
kb, err = rag.NewPersistentCollectionKB(stateFile, assetDir, chromemEngine, DefaultChunkSize, localAI, EmbeddingModel)
Expect(err).To(BeNil())

// Create source manager
Expand Down