diff --git a/README.md b/README.md index dae98e0..0ad7409 100644 --- a/README.md +++ b/README.md @@ -1109,6 +1109,45 @@ result, err := cogito.ExecuteTools(llm, fragment, cogito.EnableStrictGuidelines) ``` + +### Automatic Conversation Compaction + +Cogito can automatically compact conversations to prevent context overflow when token usage reaches a threshold. This is useful for long-running conversations with LLMs that have context limits. + +**How it works:** + +1. Before each tool-loop iteration (and before the final response), Cogito checks whether the token count reported by the last LLM call has reached the threshold +2. If it has, Cogito generates a summary of the older conversation history using an LLM +3. The older messages are replaced with the condensed summary, while the most recent messages are kept as-is + +**Basic Usage:** + +```go +// Enable automatic compaction with a token threshold of 4000 +// This will trigger compaction when the conversation reaches 4000 tokens +result, err := cogito.ExecuteTools(llm, fragment, + cogito.WithTools(searchTool), + cogito.WithCompactionThreshold(4000)) +``` + +**Customizing Compaction:** + +```go +// Set custom compaction options +result, err := cogito.ExecuteTools(llm, fragment, + cogito.WithTools(searchTool), + cogito.WithCompactionThreshold(4000), // Trigger at 4000 tokens + cogito.WithCompactionKeepMessages(5), // Keep last 5 messages (default: 10) +) +``` + +**Notes:** + +- Compaction works best with token usage data from the LLM (supported by OpenAI, LocalAI with token usage enabled) +- If `LastUsage` is not available, Cogito falls back to a rough estimate based on message content length (about 4 characters per token) plus tool call arguments +- The summary prompt uses the conversation compaction prompt type +- Compaction preserves `Status` fields such as `ToolsCalled`, `ToolResults`, and `ReasoningLog`; `LastUsage` is not carried over to the compacted fragment 
+ ### Custom Prompts ```go diff --git a/clients/localai_client.go b/clients/localai_client.go index cab9ef7..4d05714 100644 --- a/clients/localai_client.go +++ b/clients/localai_client.go @@ -81,17 +81,17 @@ func (m *localAICompletionMessage) UnmarshalJSON(data []byte) error { // CreateChatCompletion sends the chat completion request and parses the response, // including LocalAI's optional "reasoning" field, into LLMReply.ReasoningContent. -func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (cogito.LLMReply, error) { +func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (cogito.LLMReply, cogito.LLMUsage, error) { request.Model = llm.model body, err := json.Marshal(request) if err != nil { - return cogito.LLMReply{}, fmt.Errorf("localai: marshal request: %w", err) + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: marshal request: %w", err) } url := llm.baseURL + "/chat/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { - return cogito.LLMReply{}, fmt.Errorf("localai: new request: %w", err) + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: new request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") @@ -101,21 +101,21 @@ func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request open resp, err := llm.client.Do(req) if err != nil { - return cogito.LLMReply{}, fmt.Errorf("localai: request: %w", err) + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: request: %w", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { - return cogito.LLMReply{}, fmt.Errorf("localai: read response: %w", err) + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: read response: %w", err) } if resp.StatusCode != http.StatusOK { var errRes 
openai.ErrorResponse if json.Unmarshal(respBody, &errRes) == nil && errRes.Error != nil { - return cogito.LLMReply{}, errRes.Error + return cogito.LLMReply{}, cogito.LLMUsage{}, errRes.Error } - return cogito.LLMReply{}, &openai.RequestError{ + return cogito.LLMReply{}, cogito.LLMUsage{}, &openai.RequestError{ HTTPStatus: resp.Status, HTTPStatusCode: resp.StatusCode, Err: fmt.Errorf("localai: %s", string(respBody)), @@ -125,11 +125,11 @@ func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request open var localResp localAIChatCompletionResponse if err := json.Unmarshal(respBody, &localResp); err != nil { - return cogito.LLMReply{}, fmt.Errorf("localai: unmarshal response: %w", err) + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: unmarshal response: %w", err) } if len(localResp.Choices) == 0 { - return cogito.LLMReply{}, fmt.Errorf("localai: no choices in response") + return cogito.LLMReply{}, cogito.LLMUsage{}, fmt.Errorf("localai: no choices in response") } choice := localResp.Choices[0] @@ -157,30 +157,42 @@ func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request open // Ensure ReasoningContent is set for downstream (e.g. tools.go). response.Choices[0].Message.ReasoningContent = reasoning + usage := cogito.LLMUsage{ + PromptTokens: localResp.Usage.PromptTokens, + CompletionTokens: localResp.Usage.CompletionTokens, + TotalTokens: localResp.Usage.TotalTokens, + } + return cogito.LLMReply{ ChatCompletionResponse: response, ReasoningContent: reasoning, - }, nil + }, usage, nil } // Ask prompts the LLM with the provided messages and returns a Fragment // containing the response. Uses CreateChatCompletion so reasoning is preserved. +// The Fragment's Status.LastUsage is updated with the token usage. 
func (llm *LocalAIClient) Ask(ctx context.Context, f cogito.Fragment) (cogito.Fragment, error) { messages := f.GetMessages() request := openai.ChatCompletionRequest{ Model: llm.model, Messages: messages, } - reply, err := llm.CreateChatCompletion(ctx, request) + reply, usage, err := llm.CreateChatCompletion(ctx, request) if err != nil { return cogito.Fragment{}, err } if len(reply.ChatCompletionResponse.Choices) == 0 { return cogito.Fragment{}, fmt.Errorf("localai: no choices in response") } - return cogito.Fragment{ + result := cogito.Fragment{ Messages: append(f.Messages, reply.ChatCompletionResponse.Choices[0].Message), ParentFragment: &f, - Status: &cogito.Status{}, - }, nil + Status: f.Status, + } + if result.Status == nil { + result.Status = &cogito.Status{} + } + result.Status.LastUsage = usage + return result, nil } diff --git a/clients/openai_client.go b/clients/openai_client.go index 4dbc69e..fcdc504 100644 --- a/clients/openai_client.go +++ b/clients/openai_client.go @@ -27,6 +27,7 @@ func NewOpenAILLM(model, apiKey, baseURL string) *OpenAIClient { // and returns a Fragment containing the response. // The Fragment.GetMessages() method automatically handles force-text-reply // when tool calls are present in the conversation history. +// The Fragment's Status.LastUsage is updated with the token usage. 
func (llm *OpenAIClient) Ask(ctx context.Context, f cogito.Fragment) (cogito.Fragment, error) { // Use Fragment.GetMessages() which automatically adds force-text-reply // system message when tool calls are detected in the conversation @@ -40,27 +41,47 @@ func (llm *OpenAIClient) Ask(ctx context.Context, f cogito.Fragment) (cogito.Fra }, ) - if err == nil && len(resp.Choices) > 0 { - return cogito.Fragment{ + if err != nil { + return cogito.Fragment{}, err + } + + if len(resp.Choices) > 0 { + usage := cogito.LLMUsage{ + PromptTokens: resp.Usage.PromptTokens, + CompletionTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + } + result := cogito.Fragment{ Messages: append(f.Messages, resp.Choices[0].Message), ParentFragment: &f, - Status: &cogito.Status{}, - }, nil + Status: f.Status, + } + if result.Status == nil { + result.Status = &cogito.Status{} + } + result.Status.LastUsage = usage + return result, nil } - return cogito.Fragment{}, err + return cogito.Fragment{}, nil } - -func (llm *OpenAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (cogito.LLMReply, error) { +func (llm *OpenAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (cogito.LLMReply, cogito.LLMUsage, error) { request.Model = llm.model response, err := llm.client.CreateChatCompletion(ctx, request) if err != nil { - return cogito.LLMReply{}, err + return cogito.LLMReply{}, cogito.LLMUsage{}, err } + + usage := cogito.LLMUsage{ + PromptTokens: response.Usage.PromptTokens, + CompletionTokens: response.Usage.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + } + return cogito.LLMReply{ ChatCompletionResponse: response, ReasoningContent: response.Choices[0].Message.ReasoningContent, - }, nil + }, usage, nil } // NewOpenAIService creates a new OpenAI service instance diff --git a/fragment.go b/fragment.go index d136d6a..8e29890 100644 --- a/fragment.go +++ b/fragment.go @@ -32,6 +32,7 @@ type 
InjectedMessage struct { } type Status struct { + LastUsage LLMUsage // Track token usage from the last LLM call Iterations int ToolsCalled Tools ToolResults []ToolStatus @@ -97,6 +98,7 @@ func NewEmptyFragment() Fragment { ReasoningLog: []string{}, ToolsCalled: Tools{}, ToolResults: []ToolStatus{}, + LastUsage: LLMUsage{}, }, } } @@ -109,6 +111,7 @@ func NewFragment(messages ...openai.ChatCompletionMessage) Fragment { ReasoningLog: []string{}, ToolsCalled: Tools{}, ToolResults: []ToolStatus{}, + LastUsage: LLMUsage{}, }, } } @@ -210,11 +213,13 @@ func (r Fragment) ExtractStructure(ctx context.Context, llm LLM, s structures.St }, } - resp, err := llm.CreateChatCompletion(ctx, decision) + resp, usage, err := llm.CreateChatCompletion(ctx, decision) if err != nil { return err } + r.Status.LastUsage = usage + if len(resp.ChatCompletionResponse.Choices) != 1 { return fmt.Errorf("no choices: %d", len(resp.ChatCompletionResponse.Choices)) } @@ -271,11 +276,13 @@ func (f Fragment) SelectTool(ctx context.Context, llm LLM, availableTools Tools, } } - resp, err := llm.CreateChatCompletion(ctx, decision) + resp, usage, err := llm.CreateChatCompletion(ctx, decision) if err != nil { return Fragment{}, nil, err } + f.Status.LastUsage = usage + if len(resp.ChatCompletionResponse.Choices) != 1 { return Fragment{}, nil, fmt.Errorf("no choices: %d", len(resp.ChatCompletionResponse.Choices)) } diff --git a/fragment_e2e_test.go b/fragment_e2e_test.go index c810241..c862d16 100644 --- a/fragment_e2e_test.go +++ b/fragment_e2e_test.go @@ -156,7 +156,7 @@ var _ = Describe("Result test", Label("e2e"), func() { Content: "What's the weather today in San Francisco?", }) - newFragment, result, err := fragment.SelectTool(context.TODO(), *defaultLLM, Tools{ + newFragment, result, err := fragment.SelectTool(context.TODO(), defaultLLM, Tools{ NewToolDefinition( (&GetWeatherTool{}), WeatherArgs{}, diff --git a/llm.go b/llm.go index d2b4193..21af0ad 100644 --- a/llm.go +++ b/llm.go @@ -6,9 +6,16 @@ 
import ( "github.com/sashabaranov/go-openai" ) +// LLMUsage represents token usage information from an LLM response +type LLMUsage struct { + PromptTokens int + CompletionTokens int + TotalTokens int +} + type LLM interface { Ask(ctx context.Context, f Fragment) (Fragment, error) - CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, error) + CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) } type LLMReply struct { diff --git a/options.go b/options.go index d9c5157..76907b5 100644 --- a/options.go +++ b/options.go @@ -63,23 +63,29 @@ type Options struct { todos *structures.TODOList messagesManipulator func([]openai.ChatCompletionMessage) []openai.ChatCompletionMessage + + // Compaction options - automatic conversation compaction based on token count + compactionThreshold int // Token count threshold that triggers compaction (0 = disabled) + compactionKeepMessages int // Number of recent messages to keep after compaction } type Option func(*Options) func defaultOptions() *Options { return &Options{ - maxIterations: 1, - maxAttempts: 1, - maxRetries: 5, - loopDetectionSteps: 0, - forceReasoning: false, - maxAdjustmentAttempts: 5, - sinkStateTool: &defaultSinkStateTool{}, - sinkState: true, - context: context.Background(), - statusCallback: func(s string) {}, - reasoningCallback: func(s string) {}, + maxIterations: 1, + maxAttempts: 1, + maxRetries: 5, + loopDetectionSteps: 0, + forceReasoning: false, + maxAdjustmentAttempts: 5, + sinkStateTool: &defaultSinkStateTool{}, + sinkState: true, + context: context.Background(), + statusCallback: func(s string) {}, + reasoningCallback: func(s string) {}, + compactionThreshold: 0, // Disabled by default + compactionKeepMessages: 10, // Keep 10 recent messages by default } } @@ -367,6 +373,24 @@ func WithMessageInjectionResultChan(ch chan MessageInjectionResult) func(o *Opti } } +// WithCompactionThreshold sets the token count 
threshold that triggers automatic +// conversation compaction. When total tokens in the response >= threshold, +// the conversation will be compacted to stay within the limit. +// Set to 0 (default) to disable automatic compaction. +func WithCompactionThreshold(threshold int) func(o *Options) { + return func(o *Options) { + o.compactionThreshold = threshold + } +} + +// WithCompactionKeepMessages sets the number of recent messages to keep after +// compaction. Default is 10. This only applies when WithCompactionThreshold is set. +func WithCompactionKeepMessages(count int) func(o *Options) { + return func(o *Options) { + o.compactionKeepMessages = count + } +} + type defaultSinkStateTool struct{} func (d *defaultSinkStateTool) Execute(args map[string]any) (string, any, error) { diff --git a/prompt/prompt.go b/prompt/prompt.go index 6006065..aa7e712 100644 --- a/prompt/prompt.go +++ b/prompt/prompt.go @@ -3,44 +3,46 @@ package prompt type PromptType uint const ( - GapAnalysisType PromptType = iota - ContentImproverType PromptType = iota - PromptBooleanType PromptType = iota - PromptIdentifyGoalType PromptType = iota - PromptGoalAchievedType PromptType = iota - PromptPlanType PromptType = iota - PromptReEvaluatePlanType PromptType = iota - PromptSubtaskExtractionType PromptType = iota - PromptPlanExecutionType PromptType = iota - PromptGuidelinesType PromptType = iota - PromptGuidelinesExtractionType PromptType = iota - PromptPlanDecisionType PromptType = iota - PromptParameterReasoningType PromptType = iota - PromptTODOGenerationType PromptType = iota - PromptTODOWorkType PromptType = iota - PromptTODOReviewType PromptType = iota - PromptTODOTrackingType PromptType = iota + GapAnalysisType PromptType = iota + ContentImproverType PromptType = iota + PromptBooleanType PromptType = iota + PromptIdentifyGoalType PromptType = iota + PromptGoalAchievedType PromptType = iota + PromptPlanType PromptType = iota + PromptReEvaluatePlanType PromptType = iota + 
PromptSubtaskExtractionType PromptType = iota + PromptPlanExecutionType PromptType = iota + PromptGuidelinesType PromptType = iota + PromptGuidelinesExtractionType PromptType = iota + PromptPlanDecisionType PromptType = iota + PromptParameterReasoningType PromptType = iota + PromptTODOGenerationType PromptType = iota + PromptTODOWorkType PromptType = iota + PromptTODOReviewType PromptType = iota + PromptTODOTrackingType PromptType = iota + PromptConversationCompactionType PromptType = iota ) var ( defaultPromptMap PromptMap = map[PromptType]Prompt{ - GapAnalysisType: PromptGapsAnalysis, - ContentImproverType: PromptContentImprover, - PromptBooleanType: PromptExtractBoolean, - PromptIdentifyGoalType: PromptIdentifyGoal, - PromptGoalAchievedType: PromptGoalAchieved, - PromptPlanType: PromptPlan, - PromptReEvaluatePlanType: PromptReEvaluatePlan, - PromptSubtaskExtractionType: PromptSubtaskExtraction, - PromptPlanExecutionType: PromptPlanExecution, - PromptGuidelinesType: PromptGuidelines, - PromptGuidelinesExtractionType: PromptGuidelinesExtraction, - PromptPlanDecisionType: DecideIfPlanningIsNeeded, - PromptParameterReasoningType: PromptParameterReasoning, - PromptTODOGenerationType: PromptTODOGeneration, - PromptTODOWorkType: PromptTODOWork, - PromptTODOReviewType: PromptTODOReview, - PromptTODOTrackingType: PromptTODOTracking, + GapAnalysisType: PromptGapsAnalysis, + ContentImproverType: PromptContentImprover, + PromptBooleanType: PromptExtractBoolean, + PromptIdentifyGoalType: PromptIdentifyGoal, + PromptGoalAchievedType: PromptGoalAchieved, + PromptPlanType: PromptPlan, + PromptReEvaluatePlanType: PromptReEvaluatePlan, + PromptSubtaskExtractionType: PromptSubtaskExtraction, + PromptPlanExecutionType: PromptPlanExecution, + PromptGuidelinesType: PromptGuidelines, + PromptGuidelinesExtractionType: PromptGuidelinesExtraction, + PromptPlanDecisionType: DecideIfPlanningIsNeeded, + PromptParameterReasoningType: PromptParameterReasoning, + PromptTODOGenerationType: 
PromptTODOGeneration, + PromptTODOWorkType: PromptTODOWork, + PromptTODOReviewType: PromptTODOReview, + PromptTODOTrackingType: PromptTODOTracking, + PromptConversationCompactionType: PromptConversationCompaction, } PromptGuidelinesExtraction = NewPrompt("What guidelines should be applied? return only the numbers of the guidelines by using the json tool with a list of integers corresponding to the guidelines.") @@ -328,4 +330,20 @@ Use the "json" tool to return an updated TODO list with: - Completed TODOs marked as completed - Any new TODOs that were identified - Updated feedback for TODOs if provided`) + + PromptConversationCompaction = NewPrompt(`You are an AI assistant that summarizes a conversation history to preserve important context while reducing token count. + +Analyze the conversation history and create a concise summary that preserves: +1. The original user request/goal +2. Key decisions and reasoning +3. Important tool results +4. Current state of the task + +Conversation History: +{{.Context}} + +Tool Results: +{{.ToolResults}} + +Provide a summary that allows continuing the task without losing critical context. 
Be concise but comprehensive.`) ) diff --git a/prompt/type.go b/prompt/type.go index 98f5665..0a6bd78 100644 --- a/prompt/type.go +++ b/prompt/type.go @@ -43,3 +43,8 @@ func (p PromptMap) GetPrompt(t PromptType) Prompt { return prompter } + +// DefaultPrompts returns the default prompt map +func DefaultPrompts() PromptMap { + return defaultPromptMap +} diff --git a/tests/mock/client.go b/tests/mock/client.go index 13183d6..607ee06 100644 --- a/tests/mock/client.go +++ b/tests/mock/client.go @@ -19,12 +19,20 @@ type MockOpenAIClient struct { AskError error CreateChatCompletionError error FragmentHistory []Fragment + + // Token usage for responses + AskUsage []LLMUsage + AskUsageIndex int + CreateChatCompletionUsage []LLMUsage + CreateChatCompletionUsageIndex int } func NewMockOpenAIClient() *MockOpenAIClient { return &MockOpenAIClient{ AskResponses: []Fragment{}, CreateChatCompletionResponses: []openai.ChatCompletionResponse{}, + AskUsage: []LLMUsage{}, + CreateChatCompletionUsage: []LLMUsage{}, } } @@ -48,26 +56,45 @@ func (m *MockOpenAIClient) Ask(ctx context.Context, f Fragment) (Fragment, error response.Messages = append(f.Messages, response.Messages...) 
response.ParentFragment = &f + // Get usage if available and set it in the Status + var usage LLMUsage + if m.AskUsageIndex < len(m.AskUsage) { + usage = m.AskUsage[m.AskUsageIndex] + m.AskUsageIndex++ + } + if response.Status == nil { + response.Status = f.Status + } + response.Status.LastUsage = usage + return response, nil } -func (m *MockOpenAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, error) { +func (m *MockOpenAIClient) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) { if m.CreateChatCompletionError != nil { - return LLMReply{}, m.CreateChatCompletionError + return LLMReply{}, LLMUsage{}, m.CreateChatCompletionError } if m.CreateChatCompletionIndex >= len(m.CreateChatCompletionResponses) { - return LLMReply{}, fmt.Errorf("no more CreateChatCompletion responses configured") + return LLMReply{}, LLMUsage{}, fmt.Errorf("no more CreateChatCompletion responses configured") } response := m.CreateChatCompletionResponses[m.CreateChatCompletionIndex] m.CreateChatCompletionIndex++ xlog.Info("CreateChatCompletion response", "response", response) + + // Get usage if available + var usage LLMUsage + if m.CreateChatCompletionUsageIndex < len(m.CreateChatCompletionUsage) { + usage = m.CreateChatCompletionUsage[m.CreateChatCompletionUsageIndex] + m.CreateChatCompletionUsageIndex++ + } + return LLMReply{ ChatCompletionResponse: response, ReasoningContent: response.Choices[0].Message.ReasoningContent, - }, nil + }, usage, nil } // Helper methods for setting up mock responses @@ -87,6 +114,7 @@ func (m *MockOpenAIClient) SetCreateChatCompletionResponse(response openai.ChatC func (m *MockOpenAIClient) AddCreateChatCompletionFunction(name, args string) { m.SetCreateChatCompletionResponse( openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{ { Message: openai.ChatCompletionMessage{ @@ -109,3 +137,14 @@ func (m *MockOpenAIClient) 
AddCreateChatCompletionFunction(name, args string) { func (m *MockOpenAIClient) SetCreateChatCompletionError(err error) { m.CreateChatCompletionError = err } + +// SetUsage sets token usage for the next responses +func (m *MockOpenAIClient) SetUsage(promptTokens, completionTokens, totalTokens int) { + usage := LLMUsage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: totalTokens, + } + m.AskUsage = append(m.AskUsage, usage) + m.CreateChatCompletionUsage = append(m.CreateChatCompletionUsage, usage) +} diff --git a/tools.go b/tools.go index 1fa6f89..435281b 100644 --- a/tools.go +++ b/tools.go @@ -39,6 +39,7 @@ type decisionResult struct { toolChoices []*ToolChoice message string reasoning string + usage LLMUsage } type ToolDefinitionInterface interface { @@ -203,7 +204,7 @@ func decision(ctx context.Context, llm LLM, conversation []openai.ChatCompletion var lastErr error for attempts := 0; attempts < maxRetries; attempts++ { - resp, err := llm.CreateChatCompletion(ctx, decision) + resp, usage, err := llm.CreateChatCompletion(ctx, decision) if err != nil { lastErr = err xlog.Warn("Attempt to make a decision failed", "attempt", attempts+1, "error", err) @@ -225,7 +226,7 @@ func decision(ctx context.Context, llm LLM, conversation []openai.ChatCompletion if len(msg.ToolCalls) == 0 { // No tool call - the LLM just responded with text - return &decisionResult{message: msg.Content, reasoning: reasoning}, nil + return &decisionResult{message: msg.Content, reasoning: reasoning, usage: usage}, nil } // Process all tool calls @@ -254,6 +255,7 @@ func decision(ctx context.Context, llm LLM, conversation []openai.ChatCompletion toolChoices: toolChoices, message: msg.Content, reasoning: reasoning, + usage: usage, } return result, nil } @@ -568,7 +570,7 @@ func pickTool(ctx context.Context, llm LLM, fragment Fragment, tools Tools, opts } // Return the tool choices without parameters - they'll be generated separately - return 
&decisionResult{toolChoices: toolChoices, reasoning: reasoning}, nil + return &decisionResult{toolChoices: toolChoices, reasoning: reasoning, usage: intentionResult.usage}, nil } func decideToPlan(llm LLM, f Fragment, tools Tools, opts ...Option) (bool, error) { @@ -702,6 +704,8 @@ func toolSelection(llm LLM, f Fragment, tools Tools, guidelines Guidelines, tool selectedTools, reasoning := results.toolChoices, results.reasoning if len(selectedTools) == 0 { + f.Status.LastUsage = results.usage + // No tool was selected, reasoning contains the response xlog.Debug("[toolSelection] No tool selected", "reasoning", reasoning) o.statusCallback(reasoning) @@ -770,7 +774,7 @@ func toolSelection(llm LLM, f Fragment, tools Tools, guidelines Guidelines, tool Role: AssistantMessageRole.String(), ToolCalls: toolCalls, }) - + resultFragment.Status.LastUsage = results.usage return resultFragment, selectedTools, false, "", nil } @@ -884,19 +888,54 @@ TOOL_LOOP: o.statusCallback("Max total iterations reached, stopping execution") } - // Preserve the status before calling Ask + // Compact before final Ask if threshold exceeded (we would not reach compaction check in next iteration) + if o.compactionThreshold > 0 { + var compacted bool + var compactErr error + f, compacted, compactErr = checkAndCompact(o.context, llm, f, o.compactionThreshold, o.compactionKeepMessages, o.prompts) + if compactErr != nil { + return f, fmt.Errorf("failed to compact: %w", compactErr) + } + if compacted { + xlog.Debug("Fragment compacted before final response") + } + } + status := f.Status + parentBeforeAsk := f.ParentFragment f, err := llm.Ask(o.context, f) if err != nil { return f, fmt.Errorf("failed to ask LLM: %w", err) } - // Restore the status - f.Status = status + f.Status.ToolResults = status.ToolResults + f.Status.ToolsCalled = status.ToolsCalled + f.Status.LastUsage = status.LastUsage + f.Status.Iterations = status.Iterations + f.Status.ReasoningLog = status.ReasoningLog + f.Status.TODOs = 
status.TODOs + f.Status.TODOIteration = status.TODOIteration + f.Status.TODOPhase = status.TODOPhase + // Preserve original parent (LLM.Ask often sets response.ParentFragment to the request fragment) + if parentBeforeAsk != nil { + f.ParentFragment = parentBeforeAsk + } + return f, nil } totalIterations++ + // Check and compact if token threshold exceeded (before running next tool loop iteration) + if o.compactionThreshold > 0 { + f, compacted, err := checkAndCompact(o.context, llm, f, o.compactionThreshold, o.compactionKeepMessages, o.prompts) + if err != nil { + return f, fmt.Errorf("failed to compact: %w", err) + } + if compacted { + xlog.Debug("Fragment compacted successfully before next tool loop iteration") + } + } + // get guidelines and tools for the current fragment tools, guidelines, toolPrompts, err := usableTools(llm, f, opts...) if err != nil { @@ -1136,14 +1175,15 @@ Please provide revised tool call based on this feedback.`, finalToolsToExecute = toolsToExecute } - // Update fragment with the message (ID should already be set in ToolCall) - f = f.AddLastMessage(selectedToolFragment) - // Add skipped tools to fragment for _, skippedTool := range toolsToSkip { f = f.AddToolMessage("Tool call skipped by user", skippedTool.ID) } + // Update fragment with the message (ID should already be set in ToolCall) + f = f.AddLastMessage(selectedToolFragment) + f.Status.LastUsage = selectedToolFragment.Status.LastUsage + // Check context before executing tools select { case <-o.context.Done(): @@ -1284,14 +1324,23 @@ Please provide revised tool call based on this feedback.`, } - var err error // If sink state was found, stop execution after processing all tools if hasSinkState { xlog.Debug("Sink state was found, stopping execution after processing tools") - f, err = llm.Ask(o.context, f) + status := f.Status + f, err := llm.Ask(o.context, f) if err != nil { return f, fmt.Errorf("failed to ask LLM: %w", err) } + + f.Status.ToolResults = status.ToolResults + 
f.Status.ToolsCalled = status.ToolsCalled + f.Status.LastUsage = status.LastUsage + f.Status.Iterations = status.Iterations + f.Status.ReasoningLog = status.ReasoningLog + f.Status.TODOs = status.TODOs + f.Status.TODOIteration = status.TODOIteration + f.Status.TODOPhase = status.TODOPhase } if len(f.Status.ToolsCalled) == 0 { @@ -1313,3 +1362,157 @@ Please provide revised tool call based on this feedback.`, return f, nil } + +// compactFragment compacts the conversation by generating a summary of the history +// and keeping only the most recent messages. +// Returns a new fragment with the summary prepended and recent messages appended. +func compactFragment(ctx context.Context, llm LLM, f Fragment, keepMessages int, prompts prompt.PromptMap) (Fragment, error) { + xlog.Debug("[compactFragment] Starting conversation compaction", "currentMessages", len(f.Messages), "keepMessages", keepMessages) + + // Get the conversation context (everything except the most recent messages) + var contextMessages []openai.ChatCompletionMessage + var toolResults []string + + if len(f.Messages) > keepMessages { + contextMessages = f.Messages[:len(f.Messages)-keepMessages] + } else { + contextMessages = f.Messages + } + + // Extract tool results from context + for _, msg := range contextMessages { + if msg.Role == "tool" { + toolResults = append(toolResults, msg.Content) + } + } + + // Build context string + contextStr := "" + for _, msg := range contextMessages { + if msg.Role == "system" { + continue // Skip system messages in summary + } + contextStr += fmt.Sprintf("%s: %s\n", msg.Role, msg.Content) + } + + // Build tool results string + toolResultsStr := "" + for i, result := range toolResults { + toolResultsStr += fmt.Sprintf("Tool result %d: %s\n", i+1, result) + } + + // Render the compaction prompt + prompter := prompts.GetPrompt(prompt.PromptConversationCompactionType) + compactionData := struct { + Context string + ToolResults string + }{ + Context: contextStr, + ToolResults: 
toolResultsStr, + } + + compactionPrompt, err := prompter.Render(compactionData) + if err != nil { + return f, fmt.Errorf("failed to render compaction prompt: %w", err) + } + + // Ask the LLM to generate a summary + summaryFragment := NewEmptyFragment().AddMessage("user", compactionPrompt) + summaryFragment, err = llm.Ask(ctx, summaryFragment) + if err != nil { + return f, fmt.Errorf("failed to generate compaction summary: %w", err) + } + + // Get the summary from the LLM response + var summary string + if len(summaryFragment.Messages) > 0 { + summary = summaryFragment.Messages[len(summaryFragment.Messages)-1].Content + } + + xlog.Debug("[compactFragment] Generated summary", "summaryLength", len(summary)) + + // Build new fragment with summary + recent messages + newFragment := NewEmptyFragment() + + // Add system message indicating compaction + newFragment = newFragment.AddMessage("system", "[This conversation has been compacted to reduce token count. The following is a summary of previous context:]") + + // Add the summary + newFragment = newFragment.AddMessage("assistant", summary) + + // Add the recent messages we want to keep + if len(f.Messages) > keepMessages { + recentMessages := f.Messages[len(f.Messages)-keepMessages:] + for _, msg := range recentMessages { + newFragment = newFragment.AddMessage(MessageRole(msg.Role), msg.Content) + // Preserve tool calls if any + if len(msg.ToolCalls) > 0 { + lastMsg := newFragment.Messages[len(newFragment.Messages)-1] + lastMsg.ToolCalls = msg.ToolCalls + newFragment.Messages[len(newFragment.Messages)-1] = lastMsg + } + } + } else { + // If we don't have more than keepMessages, just use what we have + for _, msg := range f.Messages { + newFragment = newFragment.AddMessage(MessageRole(msg.Role), msg.Content) + } + } + + // Preserve parent fragment and status + newFragment.ParentFragment = f.ParentFragment + if f.Status != nil { + newFragment.Status = &Status{ + ReasoningLog: f.Status.ReasoningLog, + ToolsCalled: 
f.Status.ToolsCalled, + ToolResults: f.Status.ToolResults, + PastActions: f.Status.PastActions, + InjectedMessages: f.Status.InjectedMessages, + Iterations: f.Status.Iterations, + } + } + + xlog.Debug("[compactFragment] Compaction complete", "newMessages", len(newFragment.Messages)) + + return newFragment, nil +} + +// checkAndCompact checks if actual token count from LLM response exceeds threshold and performs compaction if needed +// Returns the (potentially compacted) fragment and whether compaction was performed +func checkAndCompact(ctx context.Context, llm LLM, f Fragment, threshold int, keepMessages int, prompts prompt.PromptMap) (Fragment, bool, error) { + if threshold <= 0 { + return f, false, nil // Compaction disabled + } + + // Use the actual usage tokens from the last LLM call stored in Status + totalUsedTokens := 0 + if f.Status != nil && f.Status.LastUsage.TotalTokens > 0 { + totalUsedTokens = f.Status.LastUsage.TotalTokens + xlog.Debug("[checkAndCompact] Using actual usage tokens from LLM response", "totalUsedTokens", totalUsedTokens, "threshold", threshold) + } else { + // Fallback to rough estimate if no usage data available (first iteration) + for _, msg := range f.Messages { + if msg.Role == "assistant" || msg.Role == "tool" { + totalUsedTokens += len(msg.Content) / 4 // Rough estimate + } + } + // Also count tool call arguments + for _, msg := range f.Messages { + for _, tc := range msg.ToolCalls { + totalUsedTokens += len(tc.Function.Name) + len(tc.Function.Arguments) + } + } + xlog.Debug("[checkAndCompact] Using rough estimate (no usage data)", "totalUsedTokens", totalUsedTokens, "threshold", threshold) + } + + if totalUsedTokens >= threshold { + xlog.Debug("[checkAndCompact] Token threshold exceeded", "totalUsedTokens", totalUsedTokens, "threshold", threshold) + compacted, err := compactFragment(ctx, llm, f, keepMessages, prompts) + if err != nil { + return f, false, err + } + return compacted, true, nil + } + + return f, false, nil +} diff 
--git a/tools_test.go b/tools_test.go index 59575f5..2eb6303 100644 --- a/tools_test.go +++ b/tools_test.go @@ -2,6 +2,7 @@ package cogito_test import ( "fmt" + "strings" . "github.com/mudler/cogito" "github.com/mudler/cogito/tests/mock" @@ -975,3 +976,261 @@ var _ = Describe("ExecuteTools", func() { }) }) }) + +var _ = Describe("ExecuteTools with Compaction", func() { + var mockLLM *mock.MockOpenAIClient + var originalFragment Fragment + + BeforeEach(func() { + mockLLM = mock.NewMockOpenAIClient() + originalFragment = NewEmptyFragment(). + AddMessage(UserMessageRole, "Task 1"). + AddMessage(AssistantMessageRole, "Done 1") + }) + + Context("WithCompactionThreshold", func() { + It("should not compact when threshold is disabled (0)", func() { + mockLLM.AddCreateChatCompletionFunction("search", `{"query": "test"}`) + + mockTool := mock.NewMockTool("search", "Search for information") + mock.SetRunResult(mockTool, "Result") + mockLLM.SetAskResponse("LLM result") + mockLLM.SetUsage(100, 100, 1000) + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Usage: openai.Usage{ + TotalTokens: 1000, + PromptTokens: 100, + CompletionTokens: 100, + }, + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "No more tools needed.", + }, + }, + }, + }) + + result, err := ExecuteTools(mockLLM, originalFragment, WithTools(mockTool), + WithCompactionThreshold(0), + ) + + Expect(err).ToNot(HaveOccurred()) + Expect(len(result.Messages)).ToNot(Equal(len(originalFragment.Messages)), fmt.Sprintf("result: %+v", result)) + Expect(result.Status.LastUsage.TotalTokens).To(BeNumerically(">", 0)) + Expect(len(result.Messages)).To(Equal(5)) + }) + + It("should not compact when tokens below threshold", func() { + mockLLM.AddCreateChatCompletionFunction("search", `{"query": "test"}`) + mockTool := mock.NewMockTool("search", "Search for information") + mock.SetRunResult(mockTool, "Result") + 
mockLLM.SetAskResponse("LLM result") + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "No more tools needed.", + }, + }, + }, + }) + + // Create fragment with low token count + smallFragment := NewEmptyFragment(). + AddMessage(UserMessageRole, "Hi"). + AddMessage(AssistantMessageRole, "Hello") + + result, err := ExecuteTools(mockLLM, smallFragment, WithTools(mockTool), + WithCompactionThreshold(100000), + WithCompactionKeepMessages(2)) + + Expect(err).ToNot(HaveOccurred()) + // Should not be compacted - still has original messages + Expect(len(result.Messages)).To(BeNumerically(">", 2)) + }) + + It("should compact when token threshold is exceeded", func() { + mockTool := mock.NewMockTool("search", "Search for information") + + // First tool selection + mockLLM.AddCreateChatCompletionFunction("search", `{"query": "test"}`) + mock.SetRunResult(mockTool, "Result") + mockLLM.SetAskResponse("LLM result") + + // After tool execution, no more tools needed + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "No more tools needed.", + }, + }, + }, + }) + + // Create a large fragment with high token count + largeFragment := NewEmptyFragment(). + AddMessage(UserMessageRole, "Task 1"). + AddMessage(AssistantMessageRole, "Answer to task 1"). + AddMessage(ToolMessageRole, "Result 1"). + AddMessage(UserMessageRole, "Task 2"). + AddMessage(AssistantMessageRole, "Answer to task 2"). + AddMessage(ToolMessageRole, "Result 2") + + // Set the usage to exceed threshold + mockLLM.SetUsage(100, 100, 5000) + + // Mock the compaction summary response + summaryFragment := NewEmptyFragment(). 
+ AddMessage(AssistantMessageRole, "Summary of conversation history.") + mockLLM.AskResponses = append([]Fragment{summaryFragment}, mockLLM.AskResponses...) + + result, err := ExecuteTools(mockLLM, largeFragment, WithTools(mockTool), + WithCompactionThreshold(1000), + WithCompactionKeepMessages(1)) + + Expect(err).ToNot(HaveOccurred()) + + Expect(len(result.Messages)).To(BeNumerically(">", 0)) + + Expect(result.Messages[0].Role).To(Equal("system"), fmt.Sprintf("result: %+v", result)) + Expect(result.Messages[0].Content).To(ContainSubstring("compacted"), fmt.Sprintf("result: %+v", result)) + Expect(len(result.Messages)).To(BeNumerically("<", len(largeFragment.Messages))) + }) + + It("should preserve parent fragment after compaction", func() { + mockTool := mock.NewMockTool("search", "Search for information") + + mockLLM.AddCreateChatCompletionFunction("search", `{"query": "test"}`) + mock.SetRunResult(mockTool, "Result") + mockLLM.SetAskResponse("LLM result") + + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "No more tools needed.", + }, + }, + }, + }) + + // Create a fragment with a parent + parentFragment := NewEmptyFragment().AddMessage(UserMessageRole, "Parent task") + largeFragment := NewEmptyFragment(). + AddMessage(UserMessageRole, "Task 1"). + AddMessage(AssistantMessageRole, strings.Repeat("response ", 5000)) + largeFragment.ParentFragment = &parentFragment + + // Set usage to exceed threshold + mockLLM.SetUsage(100, 100, 5000) + + // Mock the compaction summary response (may be used in-loop and again before final Ask) + summaryFragment := NewEmptyFragment(). + AddMessage(AssistantMessageRole, "Summary of conversation.") + mockLLM.AskResponses = append([]Fragment{summaryFragment}, mockLLM.AskResponses...) + mockLLM.AskResponses = append([]Fragment{summaryFragment}, mockLLM.AskResponses...) 
+ + result, err := ExecuteTools(mockLLM, largeFragment, WithTools(mockTool), + WithCompactionThreshold(1000), + WithCompactionKeepMessages(1)) + + Expect(err).ToNot(HaveOccurred()) + Expect(result.ParentFragment).ToNot(BeNil()) + Expect(result.ParentFragment.Messages[0].Role).To(Equal(UserMessageRole.String())) + }) + + It("should preserve status after compaction", func() { + mockTool := mock.NewMockTool("search", "Search for information") + + mockLLM.AddCreateChatCompletionFunction("search", `{"query": "test"}`) + mock.SetRunResult(mockTool, "Result") + mockLLM.SetAskResponse("LLM result") + + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "No more tools needed.", + }, + }, + }, + }) + + // Create fragment with status + largeFragment := NewEmptyFragment(). + AddMessage(UserMessageRole, "Task 1"). + AddMessage(AssistantMessageRole, strings.Repeat("response ", 5000)) + largeFragment.Status = &Status{ + Iterations: 5, + ReasoningLog: []string{"reasoning1", "reasoning2"}, + } + + // Set usage to exceed threshold + mockLLM.SetUsage(100, 100, 5000) + + // Mock the compaction summary response (may be used in-loop and again before final Ask) + summaryFragment := NewEmptyFragment(). + AddMessage(AssistantMessageRole, "Summary of conversation.") + mockLLM.AskResponses = append([]Fragment{summaryFragment}, mockLLM.AskResponses...) + mockLLM.AskResponses = append([]Fragment{summaryFragment}, mockLLM.AskResponses...) 
+ + result, err := ExecuteTools(mockLLM, largeFragment, WithTools(mockTool), + WithCompactionThreshold(1000), + WithCompactionKeepMessages(1)) + + Expect(err).ToNot(HaveOccurred()) + Expect(result.Status).ToNot(BeNil()) + // Original had Iterations: 5; one tool loop iteration was run, so 6 + Expect(result.Status.Iterations).To(Equal(6)) + Expect(result.Status.ReasoningLog).To(Equal([]string{"reasoning1", "reasoning2"})) + }) + + It("should use rough token estimate when LastUsage is not set", func() { + mockTool := mock.NewMockTool("search", "Search for information") + + mockLLM.AddCreateChatCompletionFunction("search", `{"query": "test"}`) + mock.SetRunResult(mockTool, "Result") + mockLLM.SetAskResponse("LLM result") + + mockLLM.SetCreateChatCompletionResponse(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Role: AssistantMessageRole.String(), + Content: "No more tools needed.", + }, + }, + }, + }) + + // Large fragment without LastUsage set + largeFragment := NewEmptyFragment(). + AddMessage(UserMessageRole, "Task 1"). + AddMessage(AssistantMessageRole, strings.Repeat("response with lots of content ", 500)). + AddMessage(ToolMessageRole, "Result 1") + + // Mock the compaction summary response (may be used in-loop and again before final Ask) + summaryFragment := NewEmptyFragment(). + AddMessage(AssistantMessageRole, "Summary.") + mockLLM.AskResponses = append([]Fragment{summaryFragment}, mockLLM.AskResponses...) + mockLLM.AskResponses = append([]Fragment{summaryFragment}, mockLLM.AskResponses...) + + result, err := ExecuteTools(mockLLM, largeFragment, WithTools(mockTool), + WithCompactionThreshold(1000), + WithCompactionKeepMessages(1)) + + Expect(err).ToNot(HaveOccurred()) + // Should be compacted based on rough estimate + if len(result.Messages) > 0 { + Expect(result.Messages[0].Role).To(Equal("system")) + } + }) + }) +})