Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 141 additions & 1 deletion clients/localai_client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clients

import (
"bufio"
"bytes"
"context"
"encoding/json"
Expand All @@ -13,8 +14,9 @@ import (
"github.com/sashabaranov/go-openai"
)

// Ensure LocalAIClient implements cogito.LLM at compile time.
// Ensure LocalAIClient implements cogito.LLM and cogito.StreamingLLM at compile time.
var _ cogito.LLM = (*LocalAIClient)(nil)
var _ cogito.StreamingLLM = (*LocalAIClient)(nil)

// LocalAIClient is an LLM client for LocalAI-compatible APIs. It uses the same
// request format as OpenAI but parses an additional "reasoning" field in the
Expand Down Expand Up @@ -192,6 +194,144 @@ func (llm *LocalAIClient) CreateChatCompletion(ctx context.Context, request open
}, usage, nil
}

// localAIStreamToolCallFunction represents the function part of a streaming tool call delta.
// In a stream, Name typically arrives on an early chunk while Arguments accumulate
// as partial JSON fragments across later chunks (see StreamEvent field docs).
type localAIStreamToolCallFunction struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}

// localAIStreamToolCall represents a single tool call delta in a streaming chunk.
// Index is a pointer so that an absent index can be distinguished from index 0;
// it identifies which parallel tool call this delta belongs to.
type localAIStreamToolCall struct {
	Index    *int                          `json:"index,omitempty"`
	ID       string                        `json:"id,omitempty"`
	Type     string                        `json:"type,omitempty"`
	Function localAIStreamToolCallFunction `json:"function,omitempty"`
}

// localAIStreamDelta represents the delta object in a streaming chunk.
// LocalAI-compatible servers emit reasoning under either "reasoning" or
// "reasoning_content" depending on version, so both fields are decoded.
type localAIStreamDelta struct {
	Content          string                  `json:"content,omitempty"`
	Reasoning        string                  `json:"reasoning,omitempty"`
	ReasoningContent string                  `json:"reasoning_content,omitempty"`
	ToolCalls        []localAIStreamToolCall `json:"tool_calls,omitempty"`
}

// localAIStreamChoice represents a single choice in a streaming chunk.
// FinishReason is empty on intermediate chunks and set (e.g. "stop",
// "tool_calls") on the final one.
type localAIStreamChoice struct {
	Delta        localAIStreamDelta `json:"delta"`
	FinishReason string             `json:"finish_reason,omitempty"`
}

// localAIStreamChunk represents a single SSE chunk from LocalAI streaming.
// Only the fields this client consumes are decoded; everything else in the
// payload is ignored.
type localAIStreamChunk struct {
	Choices []localAIStreamChoice `json:"choices"`
}

// CreateChatCompletionStream streams chat completion events via a channel using SSE.
//
// The returned channel is closed after a done or error event. All sends into the
// channel respect ctx cancellation, so the reader goroutine cannot leak if the
// consumer stops receiving before the stream ends.
func (llm *LocalAIClient) CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest) (<-chan cogito.StreamEvent, error) {
	request.Model = llm.model
	request.Stream = true

	body, err := json.Marshal(request)
	if err != nil {
		return nil, fmt.Errorf("localai stream: marshal request: %w", err)
	}

	url := llm.baseURL + "/chat/completions"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("localai stream: new request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")
	if llm.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+llm.apiKey)
	}

	resp, err := llm.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("localai stream: request: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		respBody, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("localai stream: status %d: %s", resp.StatusCode, string(respBody))
	}

	ch := make(chan cogito.StreamEvent, 64)
	go func() {
		defer close(ch)
		defer resp.Body.Close()

		// send delivers an event unless ctx is cancelled. Returning false tells
		// the loop to stop: without this, a consumer that abandons the channel
		// would leave this goroutine blocked on the send forever.
		send := func(ev cogito.StreamEvent) bool {
			select {
			case ch <- ev:
				return true
			case <-ctx.Done():
				return false
			}
		}

		var lastFinishReason string

		scanner := bufio.NewScanner(resp.Body)
		// Raise the Scanner's default 64KiB token limit: a single SSE data line
		// can carry a large tool-call argument delta and would otherwise make
		// the scan fail with bufio.ErrTooLong.
		scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
		for scanner.Scan() {
			line := scanner.Text()

			// Per the SSE spec the field separator is "data:" with an OPTIONAL
			// single space, so accept both "data: x" and "data:x".
			if !strings.HasPrefix(line, "data:") {
				continue
			}
			data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")

			if data == "[DONE]" {
				send(cogito.StreamEvent{Type: cogito.StreamEventDone, FinishReason: lastFinishReason})
				return
			}

			var chunk localAIStreamChunk
			if err := json.Unmarshal([]byte(data), &chunk); err != nil {
				// Skip malformed lines rather than aborting the whole stream.
				continue
			}
			if len(chunk.Choices) == 0 {
				continue
			}

			choice := chunk.Choices[0]
			delta := choice.Delta

			// Reasoning may arrive under either field name depending on the server.
			reasoning := delta.Reasoning
			if reasoning == "" {
				reasoning = delta.ReasoningContent
			}
			if reasoning != "" && !send(cogito.StreamEvent{Type: cogito.StreamEventReasoning, Content: reasoning}) {
				return
			}
			if delta.Content != "" && !send(cogito.StreamEvent{Type: cogito.StreamEventContent, Content: delta.Content}) {
				return
			}

			// Tool call deltas: name/ID arrive on the first chunk of a call,
			// argument fragments stream on subsequent chunks.
			for _, tc := range delta.ToolCalls {
				idx := 0
				if tc.Index != nil {
					idx = *tc.Index
				}
				ok := send(cogito.StreamEvent{
					Type:          cogito.StreamEventToolCall,
					ToolName:      tc.Function.Name,
					ToolArgs:      tc.Function.Arguments,
					ToolCallID:    tc.ID,
					ToolCallIndex: idx,
				})
				if !ok {
					return
				}
			}

			// finish_reason arrives on the final content chunk, before [DONE].
			if choice.FinishReason != "" {
				lastFinishReason = choice.FinishReason
			}
		}

		if err := scanner.Err(); err != nil {
			send(cogito.StreamEvent{Type: cogito.StreamEventError, Error: err})
			return
		}
		// Server closed the stream without [DONE]; still signal completion.
		send(cogito.StreamEvent{Type: cogito.StreamEventDone, FinishReason: lastFinishReason})
	}()

	return ch, nil
}

// Ask prompts the LLM with the provided messages and returns a Fragment
// containing the response. Uses CreateChatCompletion so reasoning is preserved.
// The Fragment's Status.LastUsage is updated with the token usage.
Expand Down
66 changes: 66 additions & 0 deletions clients/openai_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package clients

import (
"context"
"errors"
"io"

"github.com/mudler/cogito"
"github.com/sashabaranov/go-openai"
)

var _ cogito.LLM = (*OpenAIClient)(nil)
var _ cogito.StreamingLLM = (*OpenAIClient)(nil)

type OpenAIClient struct {
model string
Expand Down Expand Up @@ -84,6 +87,69 @@ func (llm *OpenAIClient) CreateChatCompletion(ctx context.Context, request opena
}, usage, nil
}

// CreateChatCompletionStream streams chat completion events via a channel.
//
// The returned channel is closed after a done or error event. All sends into the
// channel respect ctx cancellation, so the reader goroutine cannot leak if the
// consumer stops receiving before the stream ends.
func (llm *OpenAIClient) CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest) (<-chan cogito.StreamEvent, error) {
	request.Model = llm.model
	request.Stream = true

	stream, err := llm.client.CreateChatCompletionStream(ctx, request)
	if err != nil {
		return nil, err
	}

	ch := make(chan cogito.StreamEvent, 64)
	go func() {
		defer close(ch)
		defer stream.Close()

		// send delivers an event unless ctx is cancelled. Returning false tells
		// the loop to stop: without this, a consumer that abandons the channel
		// would leave this goroutine blocked on the send forever.
		send := func(ev cogito.StreamEvent) bool {
			select {
			case ch <- ev:
				return true
			case <-ctx.Done():
				return false
			}
		}

		var lastFinishReason string

		for {
			resp, err := stream.Recv()
			if errors.Is(err, io.EOF) {
				// Normal end of stream.
				send(cogito.StreamEvent{Type: cogito.StreamEventDone, FinishReason: lastFinishReason})
				return
			}
			if err != nil {
				send(cogito.StreamEvent{Type: cogito.StreamEventError, Error: err})
				return
			}
			if len(resp.Choices) == 0 {
				continue
			}

			choice := resp.Choices[0]
			delta := choice.Delta
			if delta.ReasoningContent != "" && !send(cogito.StreamEvent{Type: cogito.StreamEventReasoning, Content: delta.ReasoningContent}) {
				return
			}
			if delta.Content != "" && !send(cogito.StreamEvent{Type: cogito.StreamEventContent, Content: delta.Content}) {
				return
			}

			// Tool call deltas: name/ID arrive on the first chunk of a call,
			// argument fragments stream on subsequent chunks.
			for _, tc := range delta.ToolCalls {
				idx := 0
				if tc.Index != nil {
					idx = *tc.Index
				}
				ok := send(cogito.StreamEvent{
					Type:          cogito.StreamEventToolCall,
					ToolName:      tc.Function.Name,
					ToolArgs:      tc.Function.Arguments,
					ToolCallID:    tc.ID,
					ToolCallIndex: idx,
				})
				if !ok {
					return
				}
			}

			// finish_reason arrives on the last chunk.
			if choice.FinishReason != "" {
				lastFinishReason = string(choice.FinishReason)
			}
		}
	}()

	return ch, nil
}

// NewOpenAIService creates a new OpenAI service instance
func openaiClient(apiKey string, baseURL string) *openai.Client {
config := openai.DefaultConfig(apiKey)
Expand Down
7 changes: 7 additions & 0 deletions llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ type LLM interface {
CreateChatCompletion(ctx context.Context, request openai.ChatCompletionRequest) (LLMReply, LLMUsage, error)
}

// StreamingLLM extends LLM with streaming support.
//
// CreateChatCompletionStream returns a receive-only channel of StreamEvent
// values; implementations close the channel when the stream ends (after a
// done or error event).
//
// Streaming is optional for LLM implementations, so consumers should
// type-assert: if sllm, ok := llm.(StreamingLLM); ok { ... }
type StreamingLLM interface {
	LLM
	CreateChatCompletionStream(ctx context.Context, request openai.ChatCompletionRequest) (<-chan StreamEvent, error)
}

type LLMReply struct {
ChatCompletionResponse openai.ChatCompletionResponse
ReasoningContent string
Expand Down
11 changes: 11 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ type Options struct {

messagesManipulator func([]openai.ChatCompletionMessage) []openai.ChatCompletionMessage

// Streaming callback for live token delivery
streamCallback StreamCallback

// Compaction options - automatic conversation compaction based on token count
compactionThreshold int // Token count threshold that triggers compaction (0 = disabled)
compactionKeepMessages int // Number of recent messages to keep after compaction
Expand Down Expand Up @@ -391,6 +394,14 @@ func WithCompactionKeepMessages(count int) func(o *Options) {
}
}

// WithStreamCallback sets a callback to receive streaming events during execution.
// When set alongside a StreamingLLM, final answer generation will stream token-by-token.
func WithStreamCallback(cb StreamCallback) func(o *Options) {
	return func(opts *Options) {
		opts.streamCallback = cb
	}
}

type defaultSinkStateTool struct{}

func (d *defaultSinkStateTool) Execute(args map[string]any) (string, any, error) {
Expand Down
31 changes: 31 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package cogito

// StreamEventType identifies the kind of streaming event.
type StreamEventType string

// The set of event types a StreamEvent can carry. A stream normally ends with
// exactly one StreamEventDone or StreamEventError.
const (
	StreamEventReasoning  StreamEventType = "reasoning"   // LLM thinking delta
	StreamEventContent    StreamEventType = "content"     // answer text delta
	StreamEventToolCall   StreamEventType = "tool_call"   // tool selected + args
	StreamEventToolResult StreamEventType = "tool_result" // tool execution result
	StreamEventStatus     StreamEventType = "status"      // status message
	StreamEventDone       StreamEventType = "done"        // stream complete
	StreamEventError      StreamEventType = "error"       // error
)

// StreamEvent represents a single streaming event from the LLM or tool pipeline.
// Only the fields relevant to the event's Type are populated; all others are
// left at their zero value.
type StreamEvent struct {
	Type          StreamEventType
	Content       string          // text delta (reasoning/content)
	ToolName      string          // for tool_call/tool_result — name (first chunk only)
	ToolArgs      string          // for tool_call: argument delta string
	ToolCallID    string          // OpenAI tool_call ID (first chunk only)
	ToolCallIndex int             // which tool call (for parallel tool calls)
	ToolResult    string          // tool result text
	FinishReason  string          // "stop", "tool_calls", etc. (populated on done)
	Error         error           // populated on error
	Usage         LLMUsage        // populated on done
}

// StreamCallback is a function that receives streaming events.
type StreamCallback func(StreamEvent)
Loading
Loading