Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion cmd/server-bot/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func main() {
return
}

ragEmbedder := gemini_embedding.GeminiEmbeddingFunc(aiClient, "text-embedding-004")
ragEmbedder := gemini_embedding.GeminiEmbeddingFunc(aiClient, "gemini-embedding-001")
ragL, err := rag.New(logger, "bot-context/", ragEmbedder)
if err != nil {
logger.Error("failed to create RAG logic", slog.String("err", err.Error()))
Expand All @@ -152,16 +152,26 @@ func main() {

stopCh := make(chan os.Signal, 1)
signal.Notify(stopCh, os.Interrupt, os.Kill)
finishedCh := make(chan struct{}) // Signal the end of the graceful shutdown
go func() {
<-stopCh
logger.Info("shutting down server")
err := srv.Stop(context.Background())
if err != nil {
logger.Error("failed to stop server", slog.String("err", err.Error()))
}

logger.Info("flushing RAG database")
err = ragL.Close()
if err != nil {
logger.Error("failed to persist the database", slog.String("err", err.Error()))
}
close(finishedCh)
}()

logger.Info("starting server", slog.String("addr", addr))
if err := srv.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Error("server failed", slog.String("err", err.Error()))
}
<-finishedCh
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/go-telegram/bot v1.17.0
github.com/mark3labs/mcp-go v0.29.1-0.20250521213157-f99e5472f312
github.com/philippgille/chromem-go v0.7.0
google.golang.org/genai v1.5.0
google.golang.org/genai v1.33.0
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genai v1.5.0 h1:6wB3MCW4JpCMHURJH2gBNxCU/9iN1YjKYQj362mDTbY=
google.golang.org/genai v1.5.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY=
google.golang.org/genai v1.33.0 h1:DExzJZbSbxSRmwX2gCsZ+V9vb6rjdmsOAy47ASBgKvg=
google.golang.org/genai v1.33.0/go.mod h1:7pAilaICJlQBonjKKJNhftDFv3SREhZcTe9F6nRcjbg=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo=
Expand Down
2 changes: 1 addition & 1 deletion internal/ai/gemini/logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func New(logger *slog.Logger, client *genai.Client, model string, history histor
// Conversion for schema
b, _ := t.InputSchema.MarshalJSON()
convSchema := &genai.Schema{}
schemaErr := convSchema.UnmarshalJSON(b)
schemaErr := json.Unmarshal(b, convSchema)
if schemaErr != nil {
slog.Error("Failed to unmarshal parameter schema", "error", schemaErr)

Expand Down
70 changes: 52 additions & 18 deletions internal/rag/rag.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ import (
"io/fs"
"log/slog"
"os"
"path"
"strings"

"github.com/philippgille/chromem-go"
)

const collectionKey = "rag-content"
const (
collectionKey = "rag-content"
dbSaveName = "database.db"
)

// Logic .
type Logic struct {
Expand All @@ -20,67 +25,96 @@ type Logic struct {

db *chromem.DB // Database for RAG content
embeddedDocs int
embedFn chromem.EmbeddingFunc
}

// New .
func New(logger *slog.Logger, ragPath string, embedder chromem.EmbeddingFunc) (*Logic, error) {

// TODO: we could add an export/import to persist the db
db := chromem.NewDB()
_, err := db.CreateCollection(collectionKey, nil, embedder) // Just to make sure the collection exists
logger.Info("try to load RAG db")
db, err := loadSavedDB(ragPath)
if err != nil {
logger.Error("failed to create RAG collection", slog.String("collection", collectionKey), slog.String("error", err.Error()))
// Doesn't exist or failed, recreate it
logger.Info("init new RAG db")
db = chromem.NewDB()
_, err := db.CreateCollection(collectionKey, nil, embedder) // Just to make sure the collection exists
if err != nil {
logger.Error("failed to create RAG collection", slog.String("collection", collectionKey), slog.String("error", err.Error()))

return nil, err
return nil, err
}
}

l := &Logic{
logger: logger,
ragPath: ragPath,

db: db,
db: db,
embedFn: embedder,
}

return l, l.loadContent()
}

func (l *Logic) Close() error {
return l.db.ExportToFile(path.Join(l.ragPath, dbSaveName), true, "")
}

func loadSavedDB(dbPath string) (*chromem.DB, error) {
db := chromem.NewDB()
err := db.ImportFromFile(path.Join(dbPath, dbSaveName), "")
if err != nil {
return nil, err
}

return db, nil
}

func (l *Logic) loadContent() error {
l.logger.Info("started loading rag content", slog.String("path", l.ragPath))
dir := os.DirFS(l.ragPath)

coll := l.db.GetCollection(collectionKey, nil) // we can leave the embeddingFunc since it was already set during creation
// we need to set embed function since we might have loaded an existing db
coll := l.db.GetCollection(collectionKey, l.embedFn)

ctx := context.Background()
id := 1
err := fs.WalkDir(dir, ".", func(path string, d fs.DirEntry, err error) error {
err := fs.WalkDir(dir, ".", func(fName string, d fs.DirEntry, err error) error {
if err != nil {
return fmt.Errorf("failed to walk dir %s%s: %w", dir, path, err)
return fmt.Errorf("failed to walk dir %s%s: %w", dir, fName, err)
}
if d.IsDir() || d.Name() == ".gitkeep" {
return nil // skip directories and .gitkeep files
if d.IsDir() || d.Name() == ".gitkeep" || strings.HasSuffix(fName, ".loaded") || d.Name() == dbSaveName {
return nil // skip directories, .gitkeep, loaded files and the db file itself
}

l.logger.Info("loading rag content", slog.String("path", path))
l.logger.Info("loading rag content", slog.String("file", fName))

fmt.Println(path)
f, err := dir.Open(path)
f, err := dir.Open(fName)
if err != nil {
l.logger.Error("failed to open rag content file", slog.String("path", path), slog.String("error", err.Error()))
l.logger.Error("failed to open rag content file", slog.String("file", fName), slog.String("error", err.Error()))

return err
}
defer func() { _ = f.Close() }()

b, err := io.ReadAll(f)
if err != nil {
l.logger.Error("failed to read rag content file", slog.String("path", path), slog.String("error", err.Error()))
l.logger.Error("failed to read rag content file", slog.String("file", fName), slog.String("error", err.Error()))

return err
}
err = coll.AddDocument(ctx, chromem.Document{
ID: fmt.Sprintf("%d", id),
Content: string(b),
})
if err != nil {
return err
}

fullPath := path.Join(l.ragPath, fName)
err = os.Rename(fullPath, fullPath+".loaded")
if err != nil {
return err
}

id += 1
return err
Expand Down