Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/copilot/copilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
)

type CopilotClient interface {
Type() (result model.ApiType)
Chat(ctx context.Context, msg interface{}, ch chan interface{}) (err error)
Type() model.ApiType
Chat(ctx context.Context, messages []Message, ch chan<- interface{}) error
}

func NewClient(config *gorm.CopilotConfig) (CopilotClient, error) {
Expand Down
11 changes: 4 additions & 7 deletions pkg/copilot/copilot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"testing"

"github.com/lw396/ChatCopilot/internal/repository/gorm"
ollama "github.com/ollama/ollama/api"
)

func TestChat(t *testing.T) {
Expand All @@ -19,12 +18,10 @@ func TestChat(t *testing.T) {
t.Error("erorr:", err)
}

messages := []ollama.Message{
{
Role: "user",
Content: "你好",
},
}
messages := []Message{{
Role: RoleUser,
Content: "你好",
}}
ch := make(chan interface{})
err = client.Chat(context.Background(), messages, ch)
if err != nil {
Expand Down
14 changes: 14 additions & 0 deletions pkg/copilot/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package copilot

// Role values for Message.Role. These string literals are passed
// through unchanged into the provider-specific request structs
// (ollama.Message.Role / openai.ChatCompletionMessage.Role — see
// ollama.go and openai.go in this package).
const (
RoleSystem = "system"
RoleUser = "user"
RoleAssistant = "assistant"
)

// Message represents a role-based chat message that can be
// translated into provider-specific request structures.
type Message struct {
// Role is the speaker of the message; expected to be one of
// RoleSystem, RoleUser, or RoleAssistant.
Role string
// Content is the plain-text body of the message.
Content string
}
43 changes: 20 additions & 23 deletions pkg/copilot/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package copilot

import (
"context"
"fmt"

"github.com/lw396/ChatCopilot/internal/model"
ollama "github.com/ollama/ollama/api"
Expand All @@ -15,41 +14,39 @@ type OllamaClient struct {
topP float32
}

func (c *OllamaClient) Type() (result model.ApiType) {
func (c *OllamaClient) Type() model.ApiType {
return model.Ollama
}

func (c *OllamaClient) Chat(ctx context.Context, msg interface{}, ch chan interface{}) (err error) {
message := msg.([]ollama.Message)
stream := true
func (c *OllamaClient) Chat(ctx context.Context, messages []Message, ch chan<- interface{}) error {
if len(messages) == 0 {
close(ch)
return nil
}

ollamaMessages := make([]ollama.Message, 0, len(messages))
for _, message := range messages {
ollamaMessages = append(ollamaMessages, ollama.Message{
Role: message.Role,
Content: message.Content,
})
}

stream := true
req := &ollama.ChatRequest{
Model: c.model,
Messages: message,
Messages: ollamaMessages,
Stream: &stream,
Options: map[string]interface{}{
"temperature": c.temperature,
"top_p": c.topP,
},
}

errCh := make(chan error)
respFunc := func(resp ollama.ChatResponse) error {
fmt.Print(resp.Message.Content)
defer close(ch)

return c.client.Chat(ctx, req, func(resp ollama.ChatResponse) error {
ch <- resp
return nil
}
go func() {
errCh <- nil
if err = c.client.Chat(ctx, req, respFunc); err != nil {
errCh <- err
return
}
defer close(ch)
}()
if err := <-errCh; err != nil {
return err
}

return
})
}
37 changes: 20 additions & 17 deletions pkg/copilot/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,46 @@ type OpenaiClient struct {
topP float32
}

func (c *OpenaiClient) Type() (result model.ApiType) {
func (c *OpenaiClient) Type() model.ApiType {
return model.Openai
}

func (c *OpenaiClient) Chat(ctx context.Context, msg interface{}, ch chan interface{}) (err error) {
// messages := msg.([]openai.ChatCompletionMessage)
messages := []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: "Hello!",
},
func (c *OpenaiClient) Chat(ctx context.Context, messages []Message, ch chan<- interface{}) error {
if len(messages) == 0 {
close(ch)
return nil
}
fmt.Println(messages)

reqMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
for _, message := range messages {
reqMessages = append(reqMessages, openai.ChatCompletionMessage{
Role: message.Role,
Content: message.Content,
})
}

stream, err := c.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
Model: c.model,
Messages: messages,
Messages: reqMessages,
Temperature: c.temperature,
TopP: c.topP,
Stream: true,
})
if err != nil {
return
return err
}
defer stream.Close()
defer close(ch)

for {
var response openai.ChatCompletionStreamResponse
response, err = stream.Recv()
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
return
return nil
}
if err != nil {
err = errors.New("Stream error: " + err.Error())
return
return fmt.Errorf("stream error: %w", err)
}

fmt.Println(response.Choices[0].Delta.Content)
ch <- response
}
}
69 changes: 16 additions & 53 deletions service/copilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ import (
"github.com/lw396/ChatCopilot/internal/errors"
"github.com/lw396/ChatCopilot/internal/model"
"github.com/lw396/ChatCopilot/internal/repository/gorm"
copilotpkg "github.com/lw396/ChatCopilot/pkg/copilot"
"github.com/lw396/ChatCopilot/pkg/util"
ollama "github.com/ollama/ollama/api"
"github.com/sashabaranov/go-openai"
)

func (a *Service) AddChatCopilot(ctx context.Context, req *gorm.ChatCopilot) (err error) {
Expand Down Expand Up @@ -39,7 +38,7 @@ func (a *Service) AddChatCopilot(ctx context.Context, req *gorm.ChatCopilot) (er
}

func (a *Service) GetChatTips(ctx context.Context, usrname string, ch chan interface{}) (err error) {
copilot, err := a.rep.GetChatCopilotByUsrName(ctx, usrname)
cp, err := a.rep.GetChatCopilotByUsrName(ctx, usrname)
if err != nil {
return
}
Expand All @@ -49,76 +48,40 @@ func (a *Service) GetChatTips(ctx context.Context, usrname string, ch chan inter
return
}

message, err := a.HandleMessageFormat(ctx, messages, copilot)
formattedMessages, err := a.HandleMessageFormat(ctx, messages, cp)
if err != nil {
return
}

err = a.copilot.Chat(ctx, message, ch)
if err != nil {
return
}

return
return a.copilot.Chat(ctx, formattedMessages, ch)
}

func (a *Service) HandleMessageFormat(ctx context.Context, messages []*gorm.MessageContent, copilot *gorm.ChatCopilot) (
result interface{}, err error) {
switch a.copilot.Type() {
case model.Ollama:
result = a.HandleOllamaMessage(messages, copilot)
case model.Openai:
result = a.HandleOpanaiMessage(messages, copilot)
default:
err = errors.New(errors.CodeNotSupport, "not support")
func (a *Service) HandleMessageFormat(_ context.Context, messages []*gorm.MessageContent, cp *gorm.ChatCopilot) (
[]copilotpkg.Message, error) {
if cp == nil {
return nil, errors.New(errors.CodeNotSupport, "not support")
}
return
}

func (a *Service) HandleOllamaMessage(messages []*gorm.MessageContent, copilot *gorm.ChatCopilot) (
result []ollama.Message) {
result = append(result, ollama.Message{
Role: "system",
Content: copilot.Prompt.Prompt,
formatted := make([]copilotpkg.Message, 0, len(messages)+1)
formatted = append(formatted, copilotpkg.Message{
Role: copilotpkg.RoleSystem,
Content: cp.Prompt.Prompt,
})

for _, msg := range messages {
if msg.MessageType != model.MsgTypeText {
continue
}

role := "user"
role := copilotpkg.RoleUser
if !msg.Des {
role = "assistant"
role = copilotpkg.RoleAssistant
}
result = append(result, ollama.Message{
formatted = append(formatted, copilotpkg.Message{
Role: role,
Content: msg.Content,
})
}
return
}

func (a *Service) HandleOpanaiMessage(messages []*gorm.MessageContent, copilot *gorm.ChatCopilot) (
result []openai.ChatCompletionMessage) {
result = append(result, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem,
Content: copilot.Prompt.Prompt,
})

for _, msg := range messages {
if msg.MessageType != model.MsgTypeText {
continue
}

role := openai.ChatMessageRoleUser
if !msg.Des {
role = openai.ChatMessageRoleAssistant
}
result = append(result, openai.ChatCompletionMessage{
Role: role,
Content: msg.Content,
})
}
return
return formatted, nil
}