diff --git a/clients/localai_client.go b/clients/localai_client.go index 4963d06..63f923c 100644 --- a/clients/localai_client.go +++ b/clients/localai_client.go @@ -1,6 +1,7 @@ package clients import ( + "bufio" "bytes" "context" "encoding/json" @@ -13,8 +14,9 @@ import ( "github.com/sashabaranov/go-openai" ) -// Ensure LocalAIClient implements cogito.LLM at compile time. +// Ensure LocalAIClient implements cogito.LLM and cogito.StreamingLLM at compile time. var _ cogito.LLM = (*LocalAIClient)(nil) +var _ cogito.StreamingLLM = (*LocalAIClient)(nil) // LocalAIClient is an LLM client for LocalAI-compatible APIs. It uses the same // request format as OpenAI but parses an additional "reasoning" field in the @@ -192,6 +194,144 @@ func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request open }, usage, nil } +// localAIStreamToolCallFunction represents the function part of a streaming tool call delta. +type localAIStreamToolCallFunction struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` +} + +// localAIStreamToolCall represents a single tool call delta in a streaming chunk. +type localAIStreamToolCall struct { + Index *int `json:"index,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Function localAIStreamToolCallFunction `json:"function,omitempty"` +} + +// localAIStreamDelta represents the delta object in a streaming chunk. +type localAIStreamDelta struct { + Content string `json:"content,omitempty"` + Reasoning string `json:"reasoning,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + ToolCalls []localAIStreamToolCall `json:"tool_calls,omitempty"` +} + +// localAIStreamChoice represents a single choice in a streaming chunk. +type localAIStreamChoice struct { + Delta localAIStreamDelta `json:"delta"` + FinishReason string `json:"finish_reason,omitempty"` +} + +// localAIStreamChunk represents a single SSE chunk from LocalAI streaming. +type localAIStreamChunk struct { + Choices []localAIStreamChoice `json:"choices"` +} + +// CreateChatCompletionStream streams chat completion events via a channel using SSE. +func (llm *LocalAIClient) CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest) (<-chan cogito.StreamEvent, error) { + request.Model = llm.model + request.Stream = true + + body, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("localai stream: marshal request: %w", err) + } + + url := llm.baseURL + "/chat/completions" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("localai stream: new request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + if llm.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+llm.apiKey) + } + + resp, err := llm.client.Do(req) + if err != nil { + return nil, fmt.Errorf("localai stream: request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("localai stream: status %d: %s", resp.StatusCode, string(respBody)) + } + + ch := make(chan cogito.StreamEvent, 64) + go func() { + defer close(ch) + defer resp.Body.Close() + + var lastFinishReason string + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimPrefix(line, "data: ") + + if data == "[DONE]" { + ch <- cogito.StreamEvent{Type: cogito.StreamEventDone, FinishReason: lastFinishReason} + return + } + + var chunk localAIStreamChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + if len(chunk.Choices) == 0 { + continue + } + + delta := chunk.Choices[0].Delta + reasoning := delta.Reasoning + if reasoning == "" { + reasoning = delta.ReasoningContent + } + if reasoning != "" { + ch <- cogito.StreamEvent{Type: cogito.StreamEventReasoning, Content: reasoning} + } + if delta.Content != "" { + ch <- cogito.StreamEvent{Type: cogito.StreamEventContent, Content: delta.Content} + } + + // Tool call deltas + for _, tc := range delta.ToolCalls { + idx := 0 + if tc.Index != nil { + idx = *tc.Index + } + ch <- cogito.StreamEvent{ + Type: cogito.StreamEventToolCall, + ToolName: tc.Function.Name, + ToolArgs: tc.Function.Arguments, + ToolCallID: tc.ID, + ToolCallIndex: idx, + } + } + + // Capture finish_reason + if chunk.Choices[0].FinishReason != "" { + lastFinishReason = chunk.Choices[0].FinishReason + } + } + + if err := scanner.Err(); err != nil { + ch <- cogito.StreamEvent{Type: cogito.StreamEventError, Error: err} + return + } + // If we reach here without [DONE], still emit done + ch <- cogito.StreamEvent{Type: cogito.StreamEventDone, FinishReason: lastFinishReason} + }() + + return ch, nil +} + // Ask prompts the LLM with the provided messages and returns a Fragment // containing the response. Uses CreateChatCompletion so reasoning is preserved. // The Fragment's Status.LastUsage is updated with the token usage. diff --git a/clients/openai_client.go b/clients/openai_client.go index fcdc504..519251b 100644 --- a/clients/openai_client.go +++ b/clients/openai_client.go @@ -2,12 +2,15 @@ package clients import ( "context" + "errors" + "io" "github.com/mudler/cogito" "github.com/sashabaranov/go-openai" ) var _ cogito.LLM = (*OpenAIClient)(nil) +var _ cogito.StreamingLLM = (*OpenAIClient)(nil) type OpenAIClient struct { model string @@ -84,6 +87,69 @@ func (llm *OpenAIClient) CreateChatCompletion(ctx context.Context, request opena }, usage, nil } +// CreateChatCompletionStream streams chat completion events via a channel. +func (llm *OpenAIClient) CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest) (<-chan cogito.StreamEvent, error) { + request.Model = llm.model + request.Stream = true + + stream, err := llm.client.CreateChatCompletionStream(ctx, request) + if err != nil { + return nil, err + } + + ch := make(chan cogito.StreamEvent, 64) + go func() { + defer close(ch) + defer stream.Close() + + var lastFinishReason string + + for { + resp, err := stream.Recv() + if errors.Is(err, io.EOF) { + ch <- cogito.StreamEvent{Type: cogito.StreamEventDone, FinishReason: lastFinishReason} + return + } + if err != nil { + ch <- cogito.StreamEvent{Type: cogito.StreamEventError, Error: err} + return + } + if len(resp.Choices) == 0 { + continue + } + delta := resp.Choices[0].Delta + if delta.ReasoningContent != "" { + ch <- cogito.StreamEvent{Type: cogito.StreamEventReasoning, Content: delta.ReasoningContent} + } + if delta.Content != "" { + ch <- cogito.StreamEvent{Type: cogito.StreamEventContent, Content: delta.Content} + } + + // Tool call deltas + for _, tc := range delta.ToolCalls { + idx := 0 + if tc.Index != nil { + idx = *tc.Index + } + ch <- cogito.StreamEvent{ + Type: cogito.StreamEventToolCall, + ToolName: tc.Function.Name, + ToolArgs: tc.Function.Arguments, + ToolCallID: tc.ID, + ToolCallIndex: idx, + } + } + + // Capture finish_reason (arrives on last chunk) + if resp.Choices[0].FinishReason != "" { + lastFinishReason = string(resp.Choices[0].FinishReason) + } + } + }() + + return ch, nil +} + // NewOpenAIService creates a new OpenAI service instance func openaiClient(apiKey string, baseURL string) *openai.Client { config := openai.DefaultConfig(apiKey) diff --git a/llm.go b/llm.go index 21af0ad..b634b9a 100644 --- a/llm.go +++ b/llm.go @@ -18,6 +18,13 @@ type LLM interface { CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) } +// StreamingLLM extends LLM with streaming support. +// Consumers should type-assert: if sllm, ok := llm.(StreamingLLM); ok { ... } +type StreamingLLM interface { + LLM + CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest) (<-chan StreamEvent, error) +} + type LLMReply struct { ChatCompletionResponse openai.ChatCompletionResponse ReasoningContent string diff --git a/options.go b/options.go index 76907b5..793f4fe 100644 --- a/options.go +++ b/options.go @@ -64,6 +64,9 @@ type Options struct { messagesManipulator func([]openai.ChatCompletionMessage) []openai.ChatCompletionMessage + // Streaming callback for live token delivery + streamCallback StreamCallback + // Compaction options - automatic conversation compaction based on token count compactionThreshold int // Token count threshold that triggers compaction (0 = disabled) compactionKeepMessages int // Number of recent messages to keep after compaction @@ -391,6 +394,14 @@ func WithCompactionKeepMessages(count int) func(o *Options) { } } +// WithStreamCallback sets a callback to receive streaming events during execution. +// When set alongside a StreamingLLM, final answer generation will stream token-by-token. +func WithStreamCallback(fn StreamCallback) func(o *Options) { + return func(o *Options) { + o.streamCallback = fn + } +} + type defaultSinkStateTool struct{} func (d *defaultSinkStateTool) Execute(args map[string]any) (string, any, error) { diff --git a/stream.go b/stream.go new file mode 100644 index 0000000..813f03b --- /dev/null +++ b/stream.go @@ -0,0 +1,31 @@ +package cogito + +// StreamEventType identifies the kind of streaming event. +type StreamEventType string + +const ( + StreamEventReasoning StreamEventType = "reasoning" // LLM thinking delta + StreamEventContent StreamEventType = "content" // answer text delta + StreamEventToolCall StreamEventType = "tool_call" // tool selected + args + StreamEventToolResult StreamEventType = "tool_result" // tool execution result + StreamEventStatus StreamEventType = "status" // status message + StreamEventDone StreamEventType = "done" // stream complete + StreamEventError StreamEventType = "error" // error +) + +// StreamEvent represents a single streaming event from the LLM or tool pipeline. +type StreamEvent struct { + Type StreamEventType + Content string // text delta (reasoning/content) + ToolName string // for tool_call/tool_result — name (first chunk only) + ToolArgs string // for tool_call: argument delta string + ToolCallID string // OpenAI tool_call ID (first chunk only) + ToolCallIndex int // which tool call (for parallel tool calls) + ToolResult string // tool result text + FinishReason string // "stop", "tool_calls", etc. (populated on done) + Error error // populated on error + Usage LLMUsage // populated on done +} + +// StreamCallback is a function that receives streaming events. +type StreamCallback func(StreamEvent) diff --git a/streaming_toolcall_test.go b/streaming_toolcall_test.go new file mode 100644 index 0000000..d1a9d25 --- /dev/null +++ b/streaming_toolcall_test.go @@ -0,0 +1,204 @@ +package cogito + +import ( + "context" + "testing" + + "github.com/sashabaranov/go-openai" +) + +// mockStreamingLLM implements both LLM and StreamingLLM for testing. +type mockStreamingLLM struct { + events []StreamEvent +} + +func (m *mockStreamingLLM) Ask(ctx context.Context, f Fragment) (Fragment, error) { + return f, nil +} + +func (m *mockStreamingLLM) CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, LLMUsage, error) { + return LLMReply{}, LLMUsage{}, nil +} + +func (m *mockStreamingLLM) CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest) (<-chan StreamEvent, error) { + ch := make(chan StreamEvent, len(m.events)) + for _, ev := range m.events { + ch <- ev + } + close(ch) + return ch, nil +} + +func TestAskWithStreamingSingleToolCall(t *testing.T) { + llm := &mockStreamingLLM{ + events: []StreamEvent{ + {Type: StreamEventReasoning, Content: "Let me search for that."}, + {Type: StreamEventToolCall, ToolCallID: "call_abc123", ToolName: "search", ToolCallIndex: 0}, + {Type: StreamEventToolCall, ToolArgs: `{"que`, ToolCallIndex: 0}, + {Type: StreamEventToolCall, ToolArgs: `ry": "`, ToolCallIndex: 0}, + {Type: StreamEventToolCall, ToolArgs: `photosynthesis"}`, ToolCallIndex: 0}, + {Type: StreamEventDone, FinishReason: "tool_calls"}, + }, + } + + f := NewEmptyFragment().AddMessage(UserMessageRole, "What is photosynthesis?") + + var received []StreamEvent + cb := func(ev StreamEvent) { + received = append(received, ev) + } + + result, err := askWithStreaming(context.Background(), llm, f, cb) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should have received all events via callback + if len(received) != 6 { + t.Fatalf("expected 6 callback events, got %d", len(received)) + } + + // Check the resulting fragment's last message + lastMsg := result.Messages[len(result.Messages)-1] + if lastMsg.Role != "assistant" { + t.Fatalf("expected assistant role, got %s", lastMsg.Role) + } + if lastMsg.ReasoningContent != "Let me search for that." { + t.Fatalf("expected reasoning content, got %q", lastMsg.ReasoningContent) + } + if len(lastMsg.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(lastMsg.ToolCalls)) + } + + tc := lastMsg.ToolCalls[0] + if tc.ID != "call_abc123" { + t.Errorf("expected tool call ID 'call_abc123', got %q", tc.ID) + } + if tc.Function.Name != "search" { + t.Errorf("expected tool name 'search', got %q", tc.Function.Name) + } + if tc.Function.Arguments != `{"query": "photosynthesis"}` { + t.Errorf("expected accumulated args, got %q", tc.Function.Arguments) + } + if tc.Type != openai.ToolTypeFunction { + t.Errorf("expected tool type function, got %q", tc.Type) + } +} + +func TestAskWithStreamingParallelToolCalls(t *testing.T) { + llm := &mockStreamingLLM{ + events: []StreamEvent{ + // Interleaved deltas for two parallel tool calls + {Type: StreamEventToolCall, ToolCallID: "call_1", ToolName: "search", ToolCallIndex: 0}, + {Type: StreamEventToolCall, ToolCallID: "call_2", ToolName: "weather", ToolCallIndex: 1}, + {Type: StreamEventToolCall, ToolArgs: `{"query":`, ToolCallIndex: 0}, + {Type: StreamEventToolCall, ToolArgs: `{"city":`, ToolCallIndex: 1}, + {Type: StreamEventToolCall, ToolArgs: `"test"}`, ToolCallIndex: 0}, + {Type: StreamEventToolCall, ToolArgs: `"NYC"}`, ToolCallIndex: 1}, + {Type: StreamEventDone, FinishReason: "tool_calls"}, + }, + } + + f := NewEmptyFragment().AddMessage(UserMessageRole, "Search and weather") + result, err := askWithStreaming(context.Background(), llm, f, func(ev StreamEvent) {}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + lastMsg := result.Messages[len(result.Messages)-1] + if len(lastMsg.ToolCalls) != 2 { + t.Fatalf("expected 2 tool calls, got %d", len(lastMsg.ToolCalls)) + } + + // Verify order matches index appearance order + if lastMsg.ToolCalls[0].ID != "call_1" || lastMsg.ToolCalls[0].Function.Name != "search" { + t.Errorf("first tool call mismatch: %+v", lastMsg.ToolCalls[0]) + } + if lastMsg.ToolCalls[0].Function.Arguments != `{"query":"test"}` { + t.Errorf("first tool call args mismatch: %q", lastMsg.ToolCalls[0].Function.Arguments) + } + if lastMsg.ToolCalls[1].ID != "call_2" || lastMsg.ToolCalls[1].Function.Name != "weather" { + t.Errorf("second tool call mismatch: %+v", lastMsg.ToolCalls[1]) + } + if lastMsg.ToolCalls[1].Function.Arguments != `{"city":"NYC"}` { + t.Errorf("second tool call args mismatch: %q", lastMsg.ToolCalls[1].Function.Arguments) + } +} + +func TestAskWithStreamingMixedContentAndToolCalls(t *testing.T) { + llm := &mockStreamingLLM{ + events: []StreamEvent{ + {Type: StreamEventContent, Content: "I'll help "}, + {Type: StreamEventContent, Content: "you with that."}, + {Type: StreamEventToolCall, ToolCallID: "call_x", ToolName: "lookup", ToolCallIndex: 0}, + {Type: StreamEventToolCall, ToolArgs: `{"id": 42}`, ToolCallIndex: 0}, + {Type: StreamEventDone, FinishReason: "tool_calls"}, + }, + } + + f := NewEmptyFragment().AddMessage(UserMessageRole, "Look up item 42") + result, err := askWithStreaming(context.Background(), llm, f, func(ev StreamEvent) {}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + lastMsg := result.Messages[len(result.Messages)-1] + if lastMsg.Content != "I'll help you with that." { + t.Errorf("content mismatch: %q", lastMsg.Content) + } + if len(lastMsg.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(lastMsg.ToolCalls)) + } + if lastMsg.ToolCalls[0].Function.Name != "lookup" { + t.Errorf("tool name mismatch: %q", lastMsg.ToolCalls[0].Function.Name) + } +} + +func TestAskWithStreamingNoToolCalls(t *testing.T) { + llm := &mockStreamingLLM{ + events: []StreamEvent{ + {Type: StreamEventContent, Content: "Hello world"}, + {Type: StreamEventDone, FinishReason: "stop"}, + }, + } + + f := NewEmptyFragment().AddMessage(UserMessageRole, "Hi") + result, err := askWithStreaming(context.Background(), llm, f, func(ev StreamEvent) {}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + lastMsg := result.Messages[len(result.Messages)-1] + if lastMsg.Content != "Hello world" { + t.Errorf("content mismatch: %q", lastMsg.Content) + } + if len(lastMsg.ToolCalls) != 0 { + t.Errorf("expected no tool calls, got %d", len(lastMsg.ToolCalls)) + } +} + +func TestAskWithStreamingFinishReason(t *testing.T) { + var doneEvent StreamEvent + + llm := &mockStreamingLLM{ + events: []StreamEvent{ + {Type: StreamEventToolCall, ToolCallID: "call_1", ToolName: "fn", ToolCallIndex: 0}, + {Type: StreamEventToolCall, ToolArgs: `{}`, ToolCallIndex: 0}, + {Type: StreamEventDone, FinishReason: "tool_calls"}, + }, + } + + f := NewEmptyFragment().AddMessage(UserMessageRole, "test") + _, err := askWithStreaming(context.Background(), llm, f, func(ev StreamEvent) { + if ev.Type == StreamEventDone { + doneEvent = ev + } + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if doneEvent.FinishReason != "tool_calls" { + t.Errorf("expected finish reason 'tool_calls', got %q", doneEvent.FinishReason) + } +} diff --git a/tools.go b/tools.go index 389397b..6812157 100644 --- a/tools.go +++ b/tools.go @@ -232,6 +232,135 @@ func normalizeSystemMessages(messages []openai.ChatCompletionMessage) []openai.C return result } +// decisionWithStreaming is like decision but uses streaming when a StreamingLLM and +// callback are available, forwarding reasoning/content/tool_call deltas live. +// Falls back to decision() when streaming is not possible. +func decisionWithStreaming(ctx context.Context, llm LLM, conversation []openai.ChatCompletionMessage, + tools Tools, forceTool string, maxRetries int, streamCB StreamCallback) (*decisionResult, error) { + + sllm, isStreaming := llm.(StreamingLLM) + if !isStreaming || streamCB == nil { + return decision(ctx, llm, conversation, tools, forceTool, maxRetries) + } + + req := openai.ChatCompletionRequest{ + Messages: normalizeSystemMessages(conversation), + Tools: tools.ToOpenAI(), + } + + if forceTool != "" { + req.ToolChoice = openai.ToolChoice{ + Type: openai.ToolTypeFunction, + Function: openai.ToolFunction{Name: forceTool}, + } + } + + xlog.Debug("[decisionWithStreaming] available tools for selection", "tools", tools.Names()) + + var lastErr error + for attempts := 0; attempts < maxRetries; attempts++ { + ch, err := sllm.CreateChatCompletionStream(ctx, req) + if err != nil { + lastErr = err + xlog.Warn("Streaming attempt to make a decision failed", "attempt", attempts+1, "error", err) + time.Sleep(time.Duration(attempts+1) * time.Second) + continue + } + + var contentBuf strings.Builder + var reasoningBuf strings.Builder + toolCallMap := make(map[int]*openai.ToolCall) + var toolCallOrder []int + var streamErr error + var usage LLMUsage + + for ev := range ch { + streamCB(ev) + switch ev.Type { + case StreamEventContent: + contentBuf.WriteString(ev.Content) + case StreamEventReasoning: + reasoningBuf.WriteString(ev.Content) + case StreamEventToolCall: + idx := ev.ToolCallIndex + tc, exists := toolCallMap[idx] + if !exists { + tc = &openai.ToolCall{ + Type: openai.ToolTypeFunction, + } + toolCallMap[idx] = tc + toolCallOrder = append(toolCallOrder, idx) + } + if ev.ToolCallID != "" { + tc.ID = ev.ToolCallID + } + if ev.ToolName != "" { + tc.Function.Name = ev.ToolName + } + tc.Function.Arguments += ev.ToolArgs + case StreamEventDone: + usage = ev.Usage + case StreamEventError: + streamErr = ev.Error + } + } + + if streamErr != nil { + lastErr = streamErr + xlog.Warn("Streaming decision encountered error", "attempt", attempts+1, "error", streamErr) + time.Sleep(time.Duration(attempts+1) * time.Second) + continue + } + + // Build tool calls slice in index order + var toolCalls []openai.ToolCall + for _, idx := range toolCallOrder { + toolCalls = append(toolCalls, *toolCallMap[idx]) + } + + reasoning := reasoningBuf.String() + content := contentBuf.String() + + xlog.Debug("[decisionWithStreaming] processed", "message", content, "reasoning", reasoning) + + if len(toolCalls) == 0 { + return &decisionResult{message: content, reasoning: reasoning, usage: usage}, nil + } + + // Process all tool calls + toolChoices := make([]*ToolChoice, 0, len(toolCalls)) + allParsed := true + for _, toolCall := range toolCalls { + arguments := make(map[string]any) + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &arguments); err != nil { + lastErr = err + xlog.Warn("Attempt to parse streamed tool arguments failed", "attempt", attempts+1, "error", err) + allParsed = false + break + } + toolChoices = append(toolChoices, &ToolChoice{ + Name: toolCall.Function.Name, + Arguments: arguments, + }) + } + + if !allParsed { + time.Sleep(time.Duration(attempts+1) * time.Second) + continue + } + + xlog.Debug("[decisionWithStreaming] tools selected", "message", content, "toolChoices", len(toolChoices)) + return &decisionResult{ + toolChoices: toolChoices, + message: content, + reasoning: reasoning, + usage: usage, + }, nil + } + + return nil, fmt.Errorf("failed to make a streaming decision after %d attempts: %w", maxRetries, lastErr) +} + // decision forces the LLM to make a tool choice with retry logic // Similar to agent.go's decision function but adapted for cogito's architecture func decision(ctx context.Context, llm LLM, conversation []openai.ChatCompletionMessage, @@ -362,12 +491,12 @@ func generateToolParameters(o *Options, llm LLM, tool ToolDefinitionInterface, c } // Use decision with reasoning tool to force structured output - paramReasoningResult, err := decision(o.context, llm, + paramReasoningResult, err := decisionWithStreaming(o.context, llm, append(conversation, openai.ChatCompletionMessage{ Role: "system", Content: paramPrompt, }), - Tools{reasoningTool()}, "reasoning", o.maxRetries) + Tools{reasoningTool()}, "reasoning", o.maxRetries, o.streamCallback) if err != nil { xlog.Warn("Failed to get parameter reasoning, using original reasoning", "error", err) // Fall back to original single-step approach @@ -403,7 +532,7 @@ func generateToolParameters(o *Options, llm LLM, tool ToolDefinitionInterface, c } // Use decision to force parameter generation - result, err := decision(o.context, llm, conv, Tools{tool}, toolFunc.Name, o.maxRetries) + result, err := decisionWithStreaming(o.context, llm, conv, Tools{tool}, toolFunc.Name, o.maxRetries, o.streamCallback) if err != nil { return nil, fmt.Errorf("failed to generate parameters for tool %s: %w", toolFunc.Name, err) } @@ -434,7 +563,7 @@ func pickTool(ctx context.Context, llm LLM, fragment Fragment, tools Tools, opts // If not forcing reasoning, try direct tool selection if !o.forceReasoning { xlog.Debug("[pickTool] Using direct tool selection") - result, err := decision(ctx, llm, messages, tools, "", o.maxRetries) + result, err := decisionWithStreaming(ctx, llm, messages, tools, "", o.maxRetries, o.streamCallback) if err != nil { return nil, fmt.Errorf("tool selection failed: %w", err) } @@ -465,12 +594,12 @@ func pickTool(ctx context.Context, llm LLM, fragment Fragment, tools Tools, opts } } - reasoningResult, err := decision(ctx, llm, + reasoningResult, err := decisionWithStreaming(ctx, llm, append(messages, openai.ChatCompletionMessage{ Role: "user", Content: reasoningPrompt, }), - Tools{reasoningTool()}, "reasoning", o.maxRetries) + Tools{reasoningTool()}, "reasoning", o.maxRetries, o.streamCallback) if err != nil { return nil, fmt.Errorf("failed to get reasoning: %w", err) } @@ -530,9 +659,9 @@ func pickTool(ctx context.Context, llm LLM, fragment Fragment, tools Tools, opts }) } - intentionResult, err := decision(ctx, llm, + intentionResult, err := decisionWithStreaming(ctx, llm, intentionMessages, - intentionTools, intentionToolName, o.maxRetries) + intentionTools, intentionToolName, o.maxRetries, o.streamCallback) if err != nil { return nil, fmt.Errorf("failed to pick tool via intention: %w", err) } @@ -849,6 +978,89 @@ func (s *SessionState) Resume(llm LLM, opts ...Option) (Fragment, error) { return ExecuteTools(llm, s.Fragment, append(opts, WithStartWithAction(s.ToolChoice))...) } +// askWithStreaming calls llm.Ask() but uses streaming when available and a stream callback is set. +// It type-asserts the LLM to StreamingLLM, streams events via the callback, and accumulates +// the full response into a Fragment identical to what Ask() would return. +func askWithStreaming(ctx context.Context, llm LLM, f Fragment, streamCB StreamCallback) (Fragment, error) { + sllm, isStreaming := llm.(StreamingLLM) + if !isStreaming || streamCB == nil { + return llm.Ask(ctx, f) + } + + messages := f.GetMessages() + ch, err := sllm.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{ + Messages: messages, + }) + if err != nil { + // Fall back to non-streaming on error + xlog.Warn("Streaming failed, falling back to non-streaming", "error", err) + return llm.Ask(ctx, f) + } + + var contentBuf strings.Builder + var reasoningBuf strings.Builder + var lastErr error + + // Tool call accumulator + toolCallMap := make(map[int]*openai.ToolCall) + var toolCallOrder []int + + for ev := range ch { + streamCB(ev) + switch ev.Type { + case StreamEventContent: + contentBuf.WriteString(ev.Content) + case StreamEventReasoning: + reasoningBuf.WriteString(ev.Content) + case StreamEventToolCall: + idx := ev.ToolCallIndex + tc, exists := toolCallMap[idx] + if !exists { + tc = &openai.ToolCall{ + Type: openai.ToolTypeFunction, + } + toolCallMap[idx] = tc + toolCallOrder = append(toolCallOrder, idx) + } + if ev.ToolCallID != "" { + tc.ID = ev.ToolCallID + } + if ev.ToolName != "" { + tc.Function.Name = ev.ToolName + } + tc.Function.Arguments += ev.ToolArgs + case StreamEventError: + lastErr = ev.Error + } + } + + if lastErr != nil { + return f, fmt.Errorf("streaming error: %w", lastErr) + } + + // Build tool calls slice in index order + var toolCalls []openai.ToolCall + for _, idx := range toolCallOrder { + toolCalls = append(toolCalls, *toolCallMap[idx]) + } + + msg := openai.ChatCompletionMessage{ + Role: "assistant", + Content: contentBuf.String(), + ReasoningContent: reasoningBuf.String(), + ToolCalls: toolCalls, + } + result := Fragment{ + Messages: append(f.Messages, msg), + ParentFragment: &f, + Status: f.Status, + } + if result.Status == nil { + result.Status = &Status{} + } + return result, nil +} + // ExecuteTools runs a fragment through an LLM, and executes Tools. It returns a new fragment with the tool result at the end // The result is guaranteed that can be called afterwards with llm.Ask() to explain the result to the user. func ExecuteTools(llm LLM, f Fragment, opts ...Option) (Fragment, error) { @@ -965,7 +1177,7 @@ TOOL_LOOP: status := f.Status parentBeforeAsk := f.ParentFragment - f, err := llm.Ask(o.context, f) + f, err := askWithStreaming(o.context, llm, f, o.streamCallback) if err != nil { return f, fmt.Errorf("failed to ask LLM: %w", err) } @@ -1397,7 +1609,7 @@ Please provide revised tool call based on this feedback.`, xlog.Debug("Sink state was found, stopping execution after processing tools") status := f.Status var err error - f, err = llm.Ask(o.context, f) + f, err = askWithStreaming(o.context, llm, f, o.streamCallback) if err != nil { return f, fmt.Errorf("failed to ask LLM: %w", err) }