diff --git a/clients/localai_client.go b/clients/localai_client.go index cab9ef7..622eb61 100644 --- a/clients/localai_client.go +++ b/clients/localai_client.go @@ -81,17 +81,17 @@ func (m *localAICompletionMessage) UnmarshalJSON(data []byte) error { // CreateChatCompletion sends the chat completion request and parses the response, // including LocalAI's optional "reasoning" field, into LLMReply.ReasoningContent. -func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (cogito.LLMReply, error) { +func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (cogito.LLMReply, cogito.LLMUsage, error) { request.Model = llm.model body, err := json.Marshal(request) if err != nil { - return cogito.LLMReply{}, fmt.Errorf("localai: marshal request: %w", err) + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: marshal request: %w", err) } url := llm.baseURL + "/chat/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { - return cogito.LLMReply{}, fmt.Errorf("localai: new request: %w", err) + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: new request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") @@ -101,21 +101,21 @@ func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request open resp, err := llm.client.Do(req) if err != nil { - return cogito.LLMReply{}, fmt.Errorf("localai: request: %w", err) + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: request: %w", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return cogito.LLMReply{}, fmt.Errorf("localai: read response: %w", err) + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: read response: %w", err) } if resp.StatusCode != http.StatusOK { var errRes openai.ErrorResponse if 
json.Unmarshal(respBody, &errRes) == nil && errRes.Error != nil { - return cogito.LLMReply{}, errRes.Error + return cogito.LLMReply{}, cogito.LLMUsage{}, errRes.Error } - return cogito.LLMReply{}, &openai.RequestError{ + return cogito.LLMReply{}, cogito.LLMUsage{}, &openai.RequestError{ HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: fmt.Errorf("localai: %s", string(respBody)), @@ -125,11 +125,11 @@ func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request open var localResp localAIChatCompletionResponse if err := json.Unmarshal(respBody, &localResp); err != nil { - return cogito.LLMReply{}, fmt.Errorf("localai: unmarshal response: %w", err) + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: unmarshal response: %w", err) } if len(localResp.Choices) == 0 { - return cogito.LLMReply{}, fmt.Errorf("localai: no choices in response") + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: no choices in response") } choice := localResp.Choices[0] @@ -157,30 +157,36 @@ func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request open // Ensure ReasoningContent is set for downstream (e.g. tools.go). response.Choices[0].Message.ReasoningContent = reasoning + usage := cogito.LLMUsage{ + PromptTokens: localResp.Usage.PromptTokens, + CompletionTokens: localResp.Usage.CompletionTokens, + TotalTokens: localResp.Usage.TotalTokens, + } + return cogito.LLMReply{ ChatCompletionResponse: response, ReasoningContent: reasoning, - }, nil + }, usage, nil } // Ask prompts the LLM with the provided messages and returns a Fragment // containing the response. Uses CreateChatCompletion so reasoning is preserved. 
-func (llm *LocalAIClient) Ask(ctx context.Context, f cogito.Fragment) (cogito.Fragment, error) {
+func (llm *LocalAIClient) Ask(ctx context.Context, f cogito.Fragment) (cogito.Fragment, cogito.LLMUsage, error) {
 	messages := f.GetMessages()
 	request := openai.ChatCompletionRequest{
 		Model:    llm.model,
 		Messages: messages,
 	}
-	reply, err := llm.CreateChatCompletion(ctx, request)
+	// Capture the usage reported by CreateChatCompletion so callers receive
+	// real token counts instead of an always-empty LLMUsage{}.
+	reply, usage, err := llm.CreateChatCompletion(ctx, request)
 	if err != nil {
-		return cogito.Fragment{}, err
+		return cogito.Fragment{}, cogito.LLMUsage{}, err
 	}
 	if len(reply.ChatCompletionResponse.Choices) == 0 {
-		return cogito.Fragment{}, fmt.Errorf("localai: no choices in response")
+		return cogito.Fragment{}, cogito.LLMUsage{}, fmt.Errorf("localai: no choices in response")
 	}
 	return cogito.Fragment{
 		Messages:       append(f.Messages, reply.ChatCompletionResponse.Choices[0].Message),
 		ParentFragment: &f,
 		Status:         &cogito.Status{},
-	}, nil
+	}, usage, nil
 }
diff --git a/clients/openai_client.go b/clients/openai_client.go
index 4dbc69e..e1585f3 100644
--- a/clients/openai_client.go
+++ b/clients/openai_client.go
@@ -27,7 +27,7 @@ func NewOpenAILLM(model, apiKey, baseURL string) *OpenAIClient {
 // and returns a Fragment containing the response.
 // The Fragment.GetMessages() method automatically handles force-text-reply
 // when tool calls are present in the conversation history.
-func (llm *OpenAIClient) Ask(ctx context.Context, f cogito.Fragment) (cogito.Fragment, error) {
+func (llm *OpenAIClient) Ask(ctx context.Context, f cogito.Fragment) (cogito.Fragment, cogito.LLMUsage, error) {
 	// Use Fragment.GetMessages() which automatically adds force-text-reply
 	// system message when tool calls are detected in the conversation
 	messages := f.GetMessages()
@@ -40,27 +40,46 @@
 		},
 	)
 
-	if err == nil && len(resp.Choices) > 0 {
+	if err != nil {
+		return cogito.Fragment{}, cogito.LLMUsage{}, err
+	}
+
+	if len(resp.Choices) > 0 {
+		usage := cogito.LLMUsage{
+			PromptTokens:     resp.Usage.PromptTokens,
+			CompletionTokens: resp.Usage.CompletionTokens,
+			TotalTokens:      resp.Usage.TotalTokens,
+		}
 		return cogito.Fragment{
 			Messages:       append(f.Messages, resp.Choices[0].Message),
 			ParentFragment: &f,
 			Status:         &cogito.Status{},
-		}, nil
+		}, usage, nil
 	}
 
-	return cogito.Fragment{}, err
+	// An empty choices list is a failure, not a success: surface an error
+	// instead of returning an empty Fragment with a nil error.
+	return cogito.Fragment{}, cogito.LLMUsage{}, fmt.Errorf("openai: no choices in response")
 }
 
-func (llm *OpenAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (cogito.LLMReply, error) {
+func (llm *OpenAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (cogito.LLMReply, cogito.LLMUsage, error) {
 	request.Model = llm.model
 	response, err := llm.client.CreateChatCompletion(ctx, request)
 	if err != nil {
-		return cogito.LLMReply{}, err
+		return cogito.LLMReply{}, cogito.LLMUsage{}, err
 	}
+	// Guard before indexing response.Choices[0] below.
+	if len(response.Choices) == 0 {
+		return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("openai: no choices in response")
+	}
+
+	usage := cogito.LLMUsage{
+		PromptTokens:     response.Usage.PromptTokens,
+		CompletionTokens: response.Usage.CompletionTokens,
+		TotalTokens:      response.Usage.TotalTokens,
+	}
+
 	return cogito.LLMReply{
 		ChatCompletionResponse: response,
 		ReasoningContent:       response.Choices[0].Message.ReasoningContent,
-	}, nil
+	}, usage, nil
 }
 
 // NewOpenAIService creates a new OpenAI service instance
@@ -68,7 +68,7 @@ func ExtractKnowledgeGaps(llm LLM, f Fragment, opts ...Option) ([]string, error) xlog.Debug("Analyzing knowledge gaps", "prompt", prompt) newFragment := NewEmptyFragment().AddMessage("system", prompt) - f, err = llm.Ask(o.context, newFragment) + f, _, err = llm.Ask(o.context, newFragment) if err != nil { return nil, err } diff --git a/fragment.go b/fragment.go index d136d6a..fd1619b 100644 --- a/fragment.go +++ b/fragment.go @@ -32,6 +32,7 @@ type InjectedMessage struct { } type Status struct { + LastUsage LLMUsage // Track token usage from the last LLM call Iterations int ToolsCalled Tools ToolResults []ToolStatus @@ -210,7 +211,7 @@ func (r Fragment) ExtractStructure(ctx context.Context, llm LLM, s structures.St }, } - resp, err := llm.CreateChatCompletion(ctx, decision) + resp, _, err := llm.CreateChatCompletion(ctx, decision) if err != nil { return err } @@ -271,7 +272,7 @@ func (f Fragment) SelectTool(ctx context.Context, llm LLM, availableTools Tools, } } - resp, err := llm.CreateChatCompletion(ctx, decision) + resp, _, err := llm.CreateChatCompletion(ctx, decision) if err != nil { return Fragment{}, nil, err } diff --git a/fragment_e2e_test.go b/fragment_e2e_test.go index c810241..474acaa 100644 --- a/fragment_e2e_test.go +++ b/fragment_e2e_test.go @@ -120,7 +120,7 @@ var _ = Describe("Result test", Label("e2e"), func() { fragment := NewEmptyFragment().AddMessage("user", "Write a short poem about the sea in less than 20 words.") - result, err := defaultLLM.Ask(context.TODO(), fragment) + result, _, err := defaultLLM.Ask(context.TODO(), fragment) Expect(err).ToNot(HaveOccurred()) @@ -156,7 +156,7 @@ var _ = Describe("Result test", Label("e2e"), func() { Content: "What's the weather today in San Francisco?", }) - newFragment, result, err := fragment.SelectTool(context.TODO(), *defaultLLM, Tools{ + newFragment, result, err := fragment.SelectTool(context.TODO(), defaultLLM, Tools{ NewToolDefinition( (&GetWeatherTool{}), WeatherArgs{}, diff 
--git a/goal.go b/goal.go index 833ca0c..3336a23 100644 --- a/goal.go +++ b/goal.go @@ -33,7 +33,7 @@ func ExtractGoal(llm LLM, f Fragment, opts ...Option) (*structures.Goal, error) goalConv := NewEmptyFragment().AddMessage("user", prompt) - reasoningGoal, err := llm.Ask(o.context, goalConv) + reasoningGoal, _, err := llm.Ask(o.context, goalConv) if err != nil { return nil, fmt.Errorf("failed to ask LLM for goal identification: %w", err) } @@ -91,7 +91,7 @@ func IsGoalAchieved(llm LLM, f Fragment, goal *structures.Goal, opts ...Option) } goalAchievedConv := NewEmptyFragment().AddMessage("user", prompt, multimedias...) - reasoningGoal, err := llm.Ask(o.context, goalAchievedConv) + reasoningGoal, _, err := llm.Ask(o.context, goalAchievedConv) if err != nil { return nil, fmt.Errorf("failed to ask LLM for goal identification: %w", err) } diff --git a/guidelines.go b/guidelines.go index 350485c..5f02433 100644 --- a/guidelines.go +++ b/guidelines.go @@ -70,7 +70,7 @@ func GetRelevantGuidelines(llm LLM, guidelines Guidelines, fragment Fragment, op guidelineConv := NewEmptyFragment().AddMessage("user", guidelinePrompt) - guidelineResult, err := llm.Ask(o.context, guidelineConv) + guidelineResult, _, err := llm.Ask(o.context, guidelineConv) if err != nil { return Guidelines{}, fmt.Errorf("failed to ask LLM for guidelines: %w", err) } diff --git a/llm.go b/llm.go index d2b4193..5443c3e 100644 --- a/llm.go +++ b/llm.go @@ -6,9 +6,16 @@ import ( "github.com/sashabaranov/go-openai" ) +// LLMUsage represents token usage information from an LLM response +type LLMUsage struct { + PromptTokens int + CompletionTokens int + TotalTokens int +} + type LLM interface { - Ask(ctx context.Context, f Fragment) (Fragment, error) - CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, error) + Ask(ctx context.Context, f Fragment) (Fragment, LLMUsage, error) + CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, 
LLMUsage, error) } type LLMReply struct { diff --git a/options.go b/options.go index d9c5157..76907b5 100644 --- a/options.go +++ b/options.go @@ -63,23 +63,29 @@ type Options struct { todos *structures.TODOList messagesManipulator func([]openai.ChatCompletionMessage) []openai.ChatCompletionMessage + + // Compaction options - automatic conversation compaction based on token count + compactionThreshold int // Token count threshold that triggers compaction (0 = disabled) + compactionKeepMessages int // Number of recent messages to keep after compaction } type Option func(*Options) func defaultOptions() *Options { return &Options{ - maxIterations: 1, - maxAttempts: 1, - maxRetries: 5, - loopDetectionSteps: 0, - forceReasoning: false, - maxAdjustmentAttempts: 5, - sinkStateTool: &defaultSinkStateTool{}, - sinkState: true, - context: context.Background(), - statusCallback: func(s string) {}, - reasoningCallback: func(s string) {}, + maxIterations: 1, + maxAttempts: 1, + maxRetries: 5, + loopDetectionSteps: 0, + forceReasoning: false, + maxAdjustmentAttempts: 5, + sinkStateTool: &defaultSinkStateTool{}, + sinkState: true, + context: context.Background(), + statusCallback: func(s string) {}, + reasoningCallback: func(s string) {}, + compactionThreshold: 0, // Disabled by default + compactionKeepMessages: 10, // Keep 10 recent messages by default } } @@ -367,6 +373,24 @@ func WithMessageInjectionResultChan(ch chan MessageInjectionResult) func(o *Opti } } +// WithCompactionThreshold sets the token count threshold that triggers automatic +// conversation compaction. When total tokens in the response >= threshold, +// the conversation will be compacted to stay within the limit. +// Set to 0 (default) to disable automatic compaction. +func WithCompactionThreshold(threshold int) func(o *Options) { + return func(o *Options) { + o.compactionThreshold = threshold + } +} + +// WithCompactionKeepMessages sets the number of recent messages to keep after +// compaction. Default is 10. 
This only applies when WithCompactionThreshold is set. +func WithCompactionKeepMessages(count int) func(o *Options) { + return func(o *Options) { + o.compactionKeepMessages = count + } +} + type defaultSinkStateTool struct{} func (d *defaultSinkStateTool) Execute(args map[string]any) (string, any, error) { diff --git a/plan.go b/plan.go index fd11d00..dc194da 100644 --- a/plan.go +++ b/plan.go @@ -111,7 +111,7 @@ func applyPlanFromPrompt(llm LLM, o *Options, planPrompt string, feedbackConv *F multimedias = feedbackConv.Multimedia } planConv := NewEmptyFragment().AddMessage("user", planPrompt, multimedias...) - reasoningPlan, err := llm.Ask(o.context, planConv) + reasoningPlan, _, err := llm.Ask(o.context, planConv) if err != nil { return nil, fmt.Errorf("failed to ask LLM for plan identification: %w", err) } @@ -165,7 +165,7 @@ func ExtractTODOs(llm LLM, plan *structures.Plan, goal *structures.Goal, opts .. } todoConv := NewEmptyFragment().AddMessage("user", promptStr) - reasoningTodo, err := llm.Ask(o.context, todoConv) + reasoningTodo, _, err := llm.Ask(o.context, todoConv) if err != nil { return nil, fmt.Errorf("failed to ask LLM for TODO generation: %w", err) } @@ -518,7 +518,7 @@ func executeReviewPhase(reviewerLLMs []LLM, workFragment Fragment, goal *structu } // Get the reasoning from the review - reviewResult, err := reviewerLLM.Ask(o.context, reviewFragment) + reviewResult, _, err := reviewerLLM.Ask(o.context, reviewFragment) if err != nil { return NewEmptyFragment(), false, fmt.Errorf("failed to get review result: %w", err) } diff --git a/prompt/prompt.go b/prompt/prompt.go index 6006065..aa7e712 100644 --- a/prompt/prompt.go +++ b/prompt/prompt.go @@ -3,44 +3,46 @@ package prompt type PromptType uint const ( - GapAnalysisType PromptType = iota - ContentImproverType PromptType = iota - PromptBooleanType PromptType = iota - PromptIdentifyGoalType PromptType = iota - PromptGoalAchievedType PromptType = iota - PromptPlanType PromptType = iota - 
PromptReEvaluatePlanType PromptType = iota - PromptSubtaskExtractionType PromptType = iota - PromptPlanExecutionType PromptType = iota - PromptGuidelinesType PromptType = iota - PromptGuidelinesExtractionType PromptType = iota - PromptPlanDecisionType PromptType = iota - PromptParameterReasoningType PromptType = iota - PromptTODOGenerationType PromptType = iota - PromptTODOWorkType PromptType = iota - PromptTODOReviewType PromptType = iota - PromptTODOTrackingType PromptType = iota + GapAnalysisType PromptType = iota + ContentImproverType PromptType = iota + PromptBooleanType PromptType = iota + PromptIdentifyGoalType PromptType = iota + PromptGoalAchievedType PromptType = iota + PromptPlanType PromptType = iota + PromptReEvaluatePlanType PromptType = iota + PromptSubtaskExtractionType PromptType = iota + PromptPlanExecutionType PromptType = iota + PromptGuidelinesType PromptType = iota + PromptGuidelinesExtractionType PromptType = iota + PromptPlanDecisionType PromptType = iota + PromptParameterReasoningType PromptType = iota + PromptTODOGenerationType PromptType = iota + PromptTODOWorkType PromptType = iota + PromptTODOReviewType PromptType = iota + PromptTODOTrackingType PromptType = iota + PromptConversationCompactionType PromptType = iota ) var ( defaultPromptMap PromptMap = map[PromptType]Prompt{ - GapAnalysisType: PromptGapsAnalysis, - ContentImproverType: PromptContentImprover, - PromptBooleanType: PromptExtractBoolean, - PromptIdentifyGoalType: PromptIdentifyGoal, - PromptGoalAchievedType: PromptGoalAchieved, - PromptPlanType: PromptPlan, - PromptReEvaluatePlanType: PromptReEvaluatePlan, - PromptSubtaskExtractionType: PromptSubtaskExtraction, - PromptPlanExecutionType: PromptPlanExecution, - PromptGuidelinesType: PromptGuidelines, - PromptGuidelinesExtractionType: PromptGuidelinesExtraction, - PromptPlanDecisionType: DecideIfPlanningIsNeeded, - PromptParameterReasoningType: PromptParameterReasoning, - PromptTODOGenerationType: PromptTODOGeneration, - 
PromptTODOWorkType: PromptTODOWork, - PromptTODOReviewType: PromptTODOReview, - PromptTODOTrackingType: PromptTODOTracking, + GapAnalysisType: PromptGapsAnalysis, + ContentImproverType: PromptContentImprover, + PromptBooleanType: PromptExtractBoolean, + PromptIdentifyGoalType: PromptIdentifyGoal, + PromptGoalAchievedType: PromptGoalAchieved, + PromptPlanType: PromptPlan, + PromptReEvaluatePlanType: PromptReEvaluatePlan, + PromptSubtaskExtractionType: PromptSubtaskExtraction, + PromptPlanExecutionType: PromptPlanExecution, + PromptGuidelinesType: PromptGuidelines, + PromptGuidelinesExtractionType: PromptGuidelinesExtraction, + PromptPlanDecisionType: DecideIfPlanningIsNeeded, + PromptParameterReasoningType: PromptParameterReasoning, + PromptTODOGenerationType: PromptTODOGeneration, + PromptTODOWorkType: PromptTODOWork, + PromptTODOReviewType: PromptTODOReview, + PromptTODOTrackingType: PromptTODOTracking, + PromptConversationCompactionType: PromptConversationCompaction, } PromptGuidelinesExtraction = NewPrompt("What guidelines should be applied? return only the numbers of the guidelines by using the json tool with a list of integers corresponding to the guidelines.") @@ -328,4 +330,20 @@ Use the "json" tool to return an updated TODO list with: - Completed TODOs marked as completed - Any new TODOs that were identified - Updated feedback for TODOs if provided`) + + PromptConversationCompaction = NewPrompt(`You are an AI assistant that summarizes a conversation history to preserve important context while reducing token count. + +Analyze the conversation history and create a concise summary that preserves: +1. The original user request/goal +2. Key decisions and reasoning +3. Important tool results +4. Current state of the task + +Conversation History: +{{.Context}} + +Tool Results: +{{.ToolResults}} + +Provide a summary that allows continuing the task without losing critical context. 
Be concise but comprehensive.`)
 )
diff --git a/reviewer.go b/reviewer.go
index 3392271..62257f3 100644
--- a/reviewer.go
+++ b/reviewer.go
@@ -97,5 +97,9 @@ func improveContent(llm LLM, f Fragment, refinedMessage string, gaps []string, o
 	newFragment.ParentFragment = f.ParentFragment
 
-	return llm.Ask(o.context, newFragment)
+	result, _, err := llm.Ask(o.context, newFragment)
+	if err != nil {
+		return Fragment{}, err
+	}
+	return result, nil
 }
diff --git a/reviewer_e2e_test.go b/reviewer_e2e_test.go
index 0c86d9f..3df1461 100644
--- a/reviewer_e2e_test.go
+++ b/reviewer_e2e_test.go
@@ -16,7 +16,7 @@ var _ = Describe("cogito test", Label("e2e"), func() {
 
 		conv := NewEmptyFragment().AddMessage("user", "Explain how a combustion engine works in less than 100 words.")
 
-		result, err := defaultLLM.Ask(context.TODO(), conv)
+		result, _, err := defaultLLM.Ask(context.TODO(), conv)
 
 		Expect(err).ToNot(HaveOccurred())
 
@@ -30,7 +30,7 @@
 
 		conv := NewEmptyFragment().AddMessage("user", "What are the latest news today?")
 
-		result, err := defaultLLM.Ask(context.TODO(), conv)
+		result, _, err := defaultLLM.Ask(context.TODO(), conv)
 
 		Expect(err).ToNot(HaveOccurred())
 		Expect(result.String()).ToNot(BeEmpty())
diff --git a/tests/mock/client.go b/tests/mock/client.go
index 13183d6..a6b1df2 100644
--- a/tests/mock/client.go
+++ b/tests/mock/client.go
@@ -19,23 +19,31 @@ type MockOpenAIClient struct {
 	AskError                   error
 	CreateChatCompletionError  error
 	FragmentHistory            []Fragment
+
+	// Token usage for responses
+	AskUsage                       []LLMUsage
+	AskUsageIndex                  int
+	CreateChatCompletionUsage      []LLMUsage
+	CreateChatCompletionUsageIndex int
 }
 
 func NewMockOpenAIClient() *MockOpenAIClient {
 	return &MockOpenAIClient{
 		AskResponses:                  []Fragment{},
 		CreateChatCompletionResponses: []openai.ChatCompletionResponse{},
+		AskUsage:                      []LLMUsage{},
+		CreateChatCompletionUsage:     []LLMUsage{},
 	}
 }
 
-func (m *MockOpenAIClient) Ask(ctx context.Context, f Fragment) (Fragment, error) {
+func (m *MockOpenAIClient) Ask(ctx context.Context, f Fragment) (Fragment, LLMUsage, error) { m.FragmentHistory = append(m.FragmentHistory, f) if m.AskError != nil { - return Fragment{}, m.AskError + return Fragment{}, LLMUsage{}, m.AskError } if m.AskResponseIndex >= len(m.AskResponses) { - return Fragment{}, fmt.Errorf("no more Ask responses configured") + return Fragment{}, LLMUsage{}, fmt.Errorf("no more Ask responses configured") } response := m.AskResponses[m.AskResponseIndex] @@ -48,26 +56,41 @@ func (m *MockOpenAIClient) Ask(ctx context.Context, f Fragment) (Fragment, error response.Messages = append(f.Messages, response.Messages...) response.ParentFragment = &f - return response, nil + // Get usage if available + var usage LLMUsage + if m.AskUsageIndex < len(m.AskUsage) { + usage = m.AskUsage[m.AskUsageIndex] + m.AskUsageIndex++ + } + + return response, usage, nil } -func (m *MockOpenAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, error) { +func (m *MockOpenAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) { if m.CreateChatCompletionError != nil { - return LLMReply{}, m.CreateChatCompletionError + return LLMReply{}, LLMUsage{}, m.CreateChatCompletionError } if m.CreateChatCompletionIndex >= len(m.CreateChatCompletionResponses) { - return LLMReply{}, fmt.Errorf("no more CreateChatCompletion responses configured") + return LLMReply{}, LLMUsage{}, fmt.Errorf("no more CreateChatCompletion responses configured") } response := m.CreateChatCompletionResponses[m.CreateChatCompletionIndex] m.CreateChatCompletionIndex++ xlog.Info("CreateChatCompletion response", "response", response) + + // Get usage if available + var usage LLMUsage + if m.CreateChatCompletionUsageIndex < len(m.CreateChatCompletionUsage) { + usage = m.CreateChatCompletionUsage[m.CreateChatCompletionUsageIndex] + m.CreateChatCompletionUsageIndex++ + } + return LLMReply{ 
ChatCompletionResponse: response, ReasoningContent: response.Choices[0].Message.ReasoningContent, - }, nil + }, usage, nil } // Helper methods for setting up mock responses @@ -109,3 +132,14 @@ func (m *MockOpenAIClient) AddCreateChatCompletionFunction(name, args string) { func (m *MockOpenAIClient) SetCreateChatCompletionError(err error) { m.CreateChatCompletionError = err } + +// SetUsage sets token usage for the next responses +func (m *MockOpenAIClient) SetUsage(promptTokens, completionTokens, totalTokens int) { + usage := LLMUsage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: totalTokens, + } + m.AskUsage = append(m.AskUsage, usage) + m.CreateChatCompletionUsage = append(m.CreateChatCompletionUsage, usage) +} diff --git a/tools.go b/tools.go index 1fa6f89..03a0e69 100644 --- a/tools.go +++ b/tools.go @@ -203,7 +203,7 @@ func decision(ctx context.Context, llm LLM, conversation []openai.ChatCompletion var lastErr error for attempts := 0; attempts < maxRetries; attempts++ { - resp, err := llm.CreateChatCompletion(ctx, decision) + resp, _, err := llm.CreateChatCompletion(ctx, decision) if err != nil { lastErr = err xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", err) @@ -602,7 +602,7 @@ func decideToPlan(llm LLM, f Fragment, tools Tools, opts ...Option) (bool, error return false, fmt.Errorf("failed to render content improver prompt: %w", err) } - planDecision, err := llm.Ask(o.context, NewEmptyFragment().AddMessage("user", prompt)) + planDecision, _, err := llm.Ask(o.context, NewEmptyFragment().AddMessage("user", prompt)) if err != nil { return false, fmt.Errorf("failed to ask LLM for plan decision: %w", err) } @@ -886,12 +886,28 @@ TOOL_LOOP: // Preserve the status before calling Ask status := f.Status - f, err := llm.Ask(o.context, f) + f, usage, err := llm.Ask(o.context, f) if err != nil { return f, fmt.Errorf("failed to ask LLM: %w", err) } - // Restore the status + // Store usage tokens + if 
f.Status != nil { + f.Status.LastUsage = usage + } + // Restore the status (preserving LastUsage) + status.LastUsage = usage f.Status = status + + // Check and compact if threshold exceeded + if o.compactionThreshold > 0 { + f, compacted, err := checkAndCompact(o.context, llm, f, o.compactionThreshold, o.compactionKeepMessages, o.prompts) + if err != nil { + return f, fmt.Errorf("failed to compact: %w", err) + } + if compacted { + xlog.Debug("Fragment compacted successfully after max iterations") + } + } return f, nil } @@ -1284,14 +1300,29 @@ Please provide revised tool call based on this feedback.`, } - var err error // If sink state was found, stop execution after processing all tools if hasSinkState { xlog.Debug("Sink state was found, stopping execution after processing tools") - f, err = llm.Ask(o.context, f) + f, usage, err := llm.Ask(o.context, f) if err != nil { return f, fmt.Errorf("failed to ask LLM: %w", err) } + + // Store usage tokens for compaction check + if f.Status != nil { + f.Status.LastUsage = usage + } + + // Check and compact if threshold exceeded + if o.compactionThreshold > 0 { + f, compacted, err := checkAndCompact(o.context, llm, f, o.compactionThreshold, o.compactionKeepMessages, o.prompts) + if err != nil { + return f, fmt.Errorf("failed to compact: %w", err) + } + if compacted { + xlog.Debug("Fragment compacted successfully after sink state") + } + } } if len(f.Status.ToolsCalled) == 0 { @@ -1313,3 +1344,157 @@ Please provide revised tool call based on this feedback.`, return f, nil } + +// compactFragment compacts the conversation by generating a summary of the history +// and keeping only the most recent messages. +// Returns a new fragment with the summary prepended and recent messages appended. 
+func compactFragment(ctx context.Context, llm LLM, f Fragment, keepMessages int, prompts prompt.PromptMap) (Fragment, error) {
+	xlog.Debug("[compactFragment] Starting conversation compaction", "currentMessages", len(f.Messages), "keepMessages", keepMessages)
+
+	// Get the conversation context (everything except the most recent messages)
+	var contextMessages []openai.ChatCompletionMessage
+	var toolResults []string
+
+	if len(f.Messages) > keepMessages {
+		contextMessages = f.Messages[:len(f.Messages)-keepMessages]
+	} else {
+		contextMessages = f.Messages
+	}
+
+	// Extract tool results from context
+	for _, msg := range contextMessages {
+		if msg.Role == "tool" {
+			toolResults = append(toolResults, msg.Content)
+		}
+	}
+
+	// Build context string
+	contextStr := ""
+	for _, msg := range contextMessages {
+		if msg.Role == "system" {
+			continue // Skip system messages in summary
+		}
+		contextStr += fmt.Sprintf("%s: %s\n", msg.Role, msg.Content)
+	}
+
+	// Build tool results string
+	toolResultsStr := ""
+	for i, result := range toolResults {
+		toolResultsStr += fmt.Sprintf("Tool result %d: %s\n", i+1, result)
+	}
+
+	// Render the compaction prompt
+	prompter := prompts.GetPrompt(prompt.PromptConversationCompactionType)
+	compactionData := struct {
+		Context     string
+		ToolResults string
+	}{
+		Context:     contextStr,
+		ToolResults: toolResultsStr,
+	}
+
+	compactionPrompt, err := prompter.Render(compactionData)
+	if err != nil {
+		return f, fmt.Errorf("failed to render compaction prompt: %w", err)
+	}
+
+	// Ask the LLM to generate a summary
+	summaryFragment := NewEmptyFragment().AddMessage("user", compactionPrompt)
+	summaryFragment, _, err = llm.Ask(ctx, summaryFragment)
+	if err != nil {
+		return f, fmt.Errorf("failed to generate compaction summary: %w", err)
+	}
+
+	// Get the summary from the LLM response
+	var summary string
+	if len(summaryFragment.Messages) > 0 {
+		summary = summaryFragment.Messages[len(summaryFragment.Messages)-1].Content
+	}
+
+	xlog.Debug("[compactFragment] Generated summary", "summaryLength", len(summary))
+
+	// Build new fragment with summary + recent messages
+	newFragment := NewEmptyFragment()
+
+	// Add system message indicating compaction
+	newFragment = newFragment.AddMessage("system", "[This conversation has been compacted to reduce token count. The following is a summary of previous context:]")
+
+	// Add the summary
+	newFragment = newFragment.AddMessage("assistant", summary)
+
+	// Add the recent messages we want to keep
+	if len(f.Messages) > keepMessages {
+		recentMessages := f.Messages[len(f.Messages)-keepMessages:]
+		for _, msg := range recentMessages {
+			newFragment = newFragment.AddMessage(MessageRole(msg.Role), msg.Content) // NOTE(review): ToolCallID/Name are not carried over — confirm tool-role messages stay valid after compaction
+			// Preserve tool calls if any
+			if len(msg.ToolCalls) > 0 {
+				lastMsg := newFragment.Messages[len(newFragment.Messages)-1]
+				lastMsg.ToolCalls = msg.ToolCalls
+				newFragment.Messages[len(newFragment.Messages)-1] = lastMsg
+			}
+		}
+	} else {
+		// If we don't have more than keepMessages, just use what we have
+		for _, msg := range f.Messages {
+			newFragment = newFragment.AddMessage(MessageRole(msg.Role), msg.Content)
+		}
+	}
+
+	// Preserve parent fragment and status
+	newFragment.ParentFragment = f.ParentFragment
+	if f.Status != nil {
+		newFragment.Status = &Status{
+			ReasoningLog:     f.Status.ReasoningLog,
+			ToolsCalled:      f.Status.ToolsCalled,
+			ToolResults:      f.Status.ToolResults,
+			PastActions:      f.Status.PastActions,
+			InjectedMessages: f.Status.InjectedMessages,
+			Iterations: f.Status.Iterations, LastUsage: f.Status.LastUsage, // keep LastUsage so the next checkAndCompact uses real token counts, not the estimate fallback
+		}
+	}
+
+	xlog.Debug("[compactFragment] Compaction complete", "newMessages", len(newFragment.Messages))
+
+	return newFragment, nil
+}
+
+// checkAndCompact checks if actual token count from LLM response exceeds threshold and performs compaction if needed
+// Returns the (potentially compacted) fragment and whether compaction was performed
+func checkAndCompact(ctx context.Context, llm LLM, f Fragment, threshold int, keepMessages int, prompts prompt.PromptMap) (Fragment, bool, error) {
+	if threshold <= 0 {
+		return f, false, nil // Compaction disabled
+	}
+
+	// Use the actual usage tokens from the last LLM call stored in Status
+	totalUsedTokens := 0
+	if f.Status != nil && f.Status.LastUsage.TotalTokens > 0 {
+		totalUsedTokens = f.Status.LastUsage.TotalTokens
+		xlog.Debug("[checkAndCompact] Using actual usage tokens from LLM response", "totalUsedTokens", totalUsedTokens, "threshold", threshold)
+	} else {
+		// Fallback to rough estimate if no usage data available (first iteration)
+		for _, msg := range f.Messages {
+			if msg.Role == "assistant" || msg.Role == "tool" {
+				totalUsedTokens += len(msg.Content) / 4 // Rough estimate
+			}
+		}
+		// Also count tool call arguments
+		for _, msg := range f.Messages {
+			for _, tc := range msg.ToolCalls {
+				totalUsedTokens += len(tc.Function.Name) + len(tc.Function.Arguments)
+			}
+		}
+		xlog.Debug("[checkAndCompact] Using rough estimate (no usage data)", "totalUsedTokens", totalUsedTokens, "threshold", threshold)
+	}
+
+	if totalUsedTokens >= threshold {
+		xlog.Debug("[checkAndCompact] Token threshold exceeded", "totalUsedTokens", totalUsedTokens, "threshold", threshold)
+		compacted, err := compactFragment(ctx, llm, f, keepMessages, prompts)
+		if err != nil {
+			return f, false, err
+		}
+		return compacted, true, nil
+	}
+
+	return f, false, nil
+}