diff --git a/pkg/copilot/copilot.go b/pkg/copilot/copilot.go index 33118fc..2b84979 100644 --- a/pkg/copilot/copilot.go +++ b/pkg/copilot/copilot.go @@ -14,8 +14,8 @@ import ( ) type CopilotClient interface { - Type() (result model.ApiType) - Chat(ctx context.Context, msg interface{}, ch chan interface{}) (err error) + Type() model.ApiType + Chat(ctx context.Context, messages []Message, ch chan<- interface{}) error } func NewClient(config *gorm.CopilotConfig) (CopilotClient, error) { diff --git a/pkg/copilot/copilot_test.go b/pkg/copilot/copilot_test.go index fc9956c..ce0f4f8 100644 --- a/pkg/copilot/copilot_test.go +++ b/pkg/copilot/copilot_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/lw396/ChatCopilot/internal/repository/gorm" - ollama "github.com/ollama/ollama/api" ) func TestChat(t *testing.T) { @@ -19,12 +18,10 @@ func TestChat(t *testing.T) { t.Error("erorr:", err) } - messages := []ollama.Message{ - { - Role: "user", - Content: "你好", - }, - } + messages := []Message{{ + Role: RoleUser, + Content: "你好", + }} ch := make(chan interface{}) err = client.Chat(context.Background(), messages, ch) if err != nil { diff --git a/pkg/copilot/message.go b/pkg/copilot/message.go new file mode 100644 index 0000000..5de1079 --- /dev/null +++ b/pkg/copilot/message.go @@ -0,0 +1,14 @@ +package copilot + +const ( + RoleSystem = "system" + RoleUser = "user" + RoleAssistant = "assistant" +) + +// Message represents a role-based chat message that can be +// translated into provider-specific request structures. +type Message struct { + Role string + Content string +} diff --git a/pkg/copilot/ollama.go b/pkg/copilot/ollama.go index 67ce18c..b654410 100644 --- a/pkg/copilot/ollama.go +++ b/pkg/copilot/ollama.go @@ -2,7 +2,6 @@ package copilot import ( "context" - "fmt" "github.com/lw396/ChatCopilot/internal/model" ollama "github.com/ollama/ollama/api" @@ -15,17 +14,28 @@ type OllamaClient struct { topP float32 } -func (c *OllamaClient) Type() (result model.ApiType) { +func (c *OllamaClient) Type() model.ApiType { return model.Ollama } -func (c *OllamaClient) Chat(ctx context.Context, msg interface{}, ch chan interface{}) (err error) { - message := msg.([]ollama.Message) - stream := true +func (c *OllamaClient) Chat(ctx context.Context, messages []Message, ch chan<- interface{}) error { + if len(messages) == 0 { + close(ch) + return nil + } + + ollamaMessages := make([]ollama.Message, 0, len(messages)) + for _, message := range messages { + ollamaMessages = append(ollamaMessages, ollama.Message{ + Role: message.Role, + Content: message.Content, + }) + } + stream := true req := &ollama.ChatRequest{ Model: c.model, - Messages: message, + Messages: ollamaMessages, Stream: &stream, Options: map[string]interface{}{ "temperature": c.temperature, @@ -33,23 +43,10 @@ func (c *OllamaClient) Chat(ctx context.Context, msg interface{}, ch chan interf }, } - errCh := make(chan error) - respFunc := func(resp ollama.ChatResponse) error { - fmt.Print(resp.Message.Content) + defer close(ch) + + return c.client.Chat(ctx, req, func(resp ollama.ChatResponse) error { ch <- resp return nil - } - go func() { - errCh <- nil - if err = c.client.Chat(ctx, req, respFunc); err != nil { - errCh <- err - return - } - defer close(ch) - }() - if err := <-errCh; err != nil { - return err - } - - return + }) } diff --git a/pkg/copilot/openai.go b/pkg/copilot/openai.go index 5fc5954..8ba778d 100644 --- a/pkg/copilot/openai.go +++ b/pkg/copilot/openai.go @@ -17,43 +17,46 @@ type OpenaiClient struct { topP float32 } -func (c *OpenaiClient) Type() (result model.ApiType) { +func (c *OpenaiClient) Type() model.ApiType { return model.Openai } -func (c *OpenaiClient) Chat(ctx context.Context, msg interface{}, ch chan interface{}) (err error) { - // messages := msg.([]openai.ChatCompletionMessage) - messages := []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: "Hello!", - }, +func (c *OpenaiClient) Chat(ctx context.Context, messages []Message, ch chan<- interface{}) error { + if len(messages) == 0 { + close(ch) + return nil } - fmt.Println(messages) + + reqMessages := make([]openai.ChatCompletionMessage, 0, len(messages)) + for _, message := range messages { + reqMessages = append(reqMessages, openai.ChatCompletionMessage{ + Role: message.Role, + Content: message.Content, + }) + } + stream, err := c.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{ Model: c.model, - Messages: messages, + Messages: reqMessages, Temperature: c.temperature, TopP: c.topP, Stream: true, }) if err != nil { - return + return err } defer stream.Close() + defer close(ch) for { - var response openai.ChatCompletionStreamResponse - response, err = stream.Recv() + response, err := stream.Recv() if errors.Is(err, io.EOF) { - return + return nil } if err != nil { - err = errors.New("Stream error: " + err.Error()) - return + return fmt.Errorf("stream error: %w", err) } - fmt.Println(response.Choices[0].Delta.Content) ch <- response } } diff --git a/service/copilot.go b/service/copilot.go index 73ccd7f..882cedc 100644 --- a/service/copilot.go +++ b/service/copilot.go @@ -7,9 +7,8 @@ import ( "github.com/lw396/ChatCopilot/internal/errors" "github.com/lw396/ChatCopilot/internal/model" "github.com/lw396/ChatCopilot/internal/repository/gorm" + copilotpkg "github.com/lw396/ChatCopilot/pkg/copilot" "github.com/lw396/ChatCopilot/pkg/util" - ollama "github.com/ollama/ollama/api" - "github.com/sashabaranov/go-openai" ) func (a *Service) AddChatCopilot(ctx context.Context, req *gorm.ChatCopilot) (err error) { @@ -39,7 +38,7 @@ func (a *Service) AddChatCopilot(ctx context.Context, req *gorm.ChatCopilot) (er } func (a *Service) GetChatTips(ctx context.Context, usrname string, ch chan interface{}) (err error) { - copilot, err := a.rep.GetChatCopilotByUsrName(ctx, usrname) + cp, err := a.rep.GetChatCopilotByUsrName(ctx, usrname) if err != nil { return } @@ -49,37 +48,24 @@ func (a *Service) GetChatTips(ctx context.Context, usrname string, ch chan inter return } - message, err := a.HandleMessageFormat(ctx, messages, copilot) + formattedMessages, err := a.HandleMessageFormat(ctx, messages, cp) if err != nil { return } - err = a.copilot.Chat(ctx, message, ch) - if err != nil { - return - } - - return + return a.copilot.Chat(ctx, formattedMessages, ch) } -func (a *Service) HandleMessageFormat(ctx context.Context, messages []*gorm.MessageContent, copilot *gorm.ChatCopilot) ( - result interface{}, err error) { - switch a.copilot.Type() { - case model.Ollama: - result = a.HandleOllamaMessage(messages, copilot) - case model.Openai: - result = a.HandleOpanaiMessage(messages, copilot) - default: - err = errors.New(errors.CodeNotSupport, "not support") +func (a *Service) HandleMessageFormat(_ context.Context, messages []*gorm.MessageContent, cp *gorm.ChatCopilot) ( + []copilotpkg.Message, error) { + if cp == nil { + return nil, errors.New(errors.CodeNotSupport, "not support") } - return -} -func (a *Service) HandleOllamaMessage(messages []*gorm.MessageContent, copilot *gorm.ChatCopilot) ( - result []ollama.Message) { - result = append(result, ollama.Message{ - Role: "system", - Content: copilot.Prompt.Prompt, + formatted := make([]copilotpkg.Message, 0, len(messages)+1) + formatted = append(formatted, copilotpkg.Message{ + Role: copilotpkg.RoleSystem, + Content: cp.Prompt.Prompt, }) for _, msg := range messages { @@ -87,38 +73,15 @@ func (a *Service) HandleOllamaMessage(messages []*gorm.MessageContent, copilot * continue } - role := "user" + role := copilotpkg.RoleUser if !msg.Des { - role = "assistant" + role = copilotpkg.RoleAssistant } - result = append(result, ollama.Message{ + formatted = append(formatted, copilotpkg.Message{ Role: role, Content: msg.Content, }) } - return -} - -func (a *Service) HandleOpanaiMessage(messages []*gorm.MessageContent, copilot *gorm.ChatCopilot) ( - result []openai.ChatCompletionMessage) { - result = append(result, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleSystem, - Content: copilot.Prompt.Prompt, - }) - for _, msg := range messages { - if msg.MessageType != model.MsgTypeText { - continue - } - - role := openai.ChatMessageRoleUser - if !msg.Des { - role = openai.ChatMessageRoleAssistant - } - result = append(result, openai.ChatCompletionMessage{ - Role: role, - Content: msg.Content, - }) - } - return + return formatted, nil }