Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 42 additions & 41 deletions Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1050,59 +1050,60 @@ private enum Responses {
outputs.append(object)

case .tool(let id):
let toolMessage = msg
// Wrap user content into a single top-level message as required by Responses API
var contentBlocks: [JSONValue]
switch toolMessage.content {
let outputValue: JSONValue
switch msg.content {
case .text(let text):
contentBlocks = [
.object(["type": .string("input_text"), "text": .string(text)])
]
outputValue = .string(text)
case .blocks(let blocks):
contentBlocks = blocks.map { block in
switch block {
case .text(let text):
return .object(["type": .string("input_text"), "text": .string(text)])
case .imageURL(let url):
return .object([
"type": .string("input_image"),
"image_url": .object(["url": .string(url)]),
])
outputValue = .array(
blocks.map { block in
switch block {
case .text(let text):
return .object(["type": .string("input_text"), "text": .string(text)])
case .imageURL(let url):
return .object([
"type": .string("input_image"),
"image_url": .string(url),
])
}
}
}
}
let outputString: String
if contentBlocks.count > 1 {
let encoder = JSONEncoder()
if let data = try? encoder.encode(JSONValue.array(contentBlocks)),
let str = String(data: data, encoding: .utf8)
{
outputString = str
} else {
outputString = "[]"
}
} else if let block = contentBlocks.first {
let encoder = JSONEncoder()
if let data = try? encoder.encode(block),
let str = String(data: data, encoding: .utf8)
{
outputString = str
} else {
outputString = "{}"
}
} else {
outputString = "{}"
)
}
outputs.append(
.object([
"type": .string("function_call_output"),
"call_id": .string(id),
"output": .string(outputString),
"output": outputValue,
])
)

case .raw(rawContent: let rawContent):
outputs.append(rawContent)
// Convert Chat Completions assistant+tool_calls to Responses API function_call items
if case .object(let assistantMessageObject) = rawContent,
case .string(let messageRole) = assistantMessageObject["role"],
messageRole == "assistant",
case .array(let assistantToolCalls) = assistantMessageObject["tool_calls"]
{
for assistantToolCall in assistantToolCalls {
if case .object(let toolCallObject) = assistantToolCall,
case .string(let toolCallID) = toolCallObject["id"],
case .object(let functionCallObject) = toolCallObject["function"],
case .string(let functionName) = functionCallObject["name"],
case .string(let functionArguments) = functionCallObject["arguments"]
{
outputs.append(
.object([
"type": .string("function_call"),
"call_id": .string(toolCallID),
"name": .string(functionName),
"arguments": .string(functionArguments),
])
)
}
}
} else {
outputs.append(rawContent)
}

case .system:
let systemMessage = msg
Expand Down
141 changes: 133 additions & 8 deletions Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Foundation
import JSONSchema
import Testing

@testable import AnyLanguageModel
Expand All @@ -14,7 +15,6 @@ struct OpenAILanguageModelTests {
}

@Test func apiVariantParameterization() throws {
// Test that both API variants can be created and have correct properties
for apiVariant in [OpenAILanguageModel.APIVariant.chatCompletions, .responses] {
let model = OpenAILanguageModel(apiKey: "test-key", model: "test-model", apiVariant: apiVariant)
#expect(model.apiVariant == apiVariant)
Expand Down Expand Up @@ -97,7 +97,6 @@ struct OpenAILanguageModelTests {
maximumResponseTokens: 50
)

// Set custom options (extraBody will be merged into the request)
options[custom: OpenAILanguageModel.self] = .init(
extraBody: ["user": .string("test-user-id")]
)
Expand Down Expand Up @@ -138,19 +137,77 @@ struct OpenAILanguageModelTests {
}

@Test func withTools() async throws {
let weatherTool = WeatherTool()
let weatherTool = spy(on: WeatherTool())
let session = LanguageModelSession(model: model, tools: [weatherTool])

let response = try await session.respond(to: "How's the weather in San Francisco?")
var options = GenerationOptions()
options[custom: OpenAILanguageModel.self] = .init(
maxToolCalls: 1
)

let response = try await withOpenAIRateLimitRetry {
try await session.respond(
to: "Call getWeather for San Francisco exactly once, then summarize in one sentence.",
options: options
)
}

#expect(!response.content.isEmpty)
let calls = await weatherTool.calls
#expect(!calls.isEmpty)
if let firstCall = calls.first {
#expect(firstCall.arguments.city.localizedCaseInsensitiveContains("san"))
}

var foundToolCall = false
var foundToolOutput = false
for case let .toolOutput(toolOutput) in response.transcriptEntries {
#expect(toolOutput.toolName == "getWeather")
foundToolOutput = true
for entry in response.transcriptEntries {
switch entry {
case .toolCalls(let toolCalls):
#expect(!toolCalls.isEmpty)
if let firstToolCall = toolCalls.first {
#expect(firstToolCall.toolName == "getWeather")
}
foundToolCall = true
case .toolOutput(let toolOutput):
#expect(toolOutput.toolName == "getWeather")
foundToolOutput = true
default:
break
}
}
#expect(foundToolCall)
#expect(foundToolOutput)
}

/// Verifies that a session which performed a tool call in its first turn can
/// still answer a follow-up question in a second turn.
///
/// The first turn asks for exactly one `getWeather` call (capped through
/// `maxToolCalls: 1`); the second turn is sent without custom options and asks
/// which city the tool call used.
@Test func withToolsConversationContinuesAcrossTurns() async throws {
    let weatherTool = spy(on: WeatherTool())
    let session = LanguageModelSession(model: model, tools: [weatherTool])

    // Only the first turn carries the tool-call cap.
    var firstTurnOptions = GenerationOptions()
    firstTurnOptions[custom: OpenAILanguageModel.self] = .init(maxToolCalls: 1)

    // Turn 1: the reply itself is irrelevant, only the tool invocation matters.
    _ = try await withOpenAIRateLimitRetry {
        try await session.respond(
            to: "Call getWeather for San Francisco exactly once, then reply with only: done",
            options: firstTurnOptions
        )
    }

    // Turn 2: the answer must reflect the arguments of the earlier tool call.
    let followUp = try await withOpenAIRateLimitRetry {
        try await session.respond(
            to: "Which city did the tool call use? Reply with city only."
        )
    }
    let followUpText = followUp.content
    #expect(!followUpText.isEmpty)
    #expect(followUpText.localizedCaseInsensitiveContains("san"))

    // The spy must have recorded at least one invocation of the tool.
    let recordedCalls = await weatherTool.calls
    #expect(!recordedCalls.isEmpty)
}

@Suite("Structured Output")
struct StructuredOutputTests {
@Generable
Expand Down Expand Up @@ -316,7 +373,6 @@ struct OpenAILanguageModelTests {
maximumResponseTokens: 50
)

// Set custom options (extraBody will be merged into the request)
options[custom: OpenAILanguageModel.self] = .init(
extraBody: ["user": "test-user-id"]
)
Expand Down Expand Up @@ -459,4 +515,73 @@ struct OpenAILanguageModelTests {
}
}
}

/// Helpers for inspecting request bodies expressed as `JSONValue` and for
/// building small transcripts that contain tool calls.
@Suite("OpenAILanguageModel Responses Request Body")
struct ResponsesRequestBodyTests {
    private let model = "test-model"

    /// Returns the array stored under the `"input"` key of `body`, or `nil`
    /// when `body` is not an object or `"input"` is missing / not an array.
    private func inputArray(from body: JSONValue) -> [JSONValue]? {
        if case .object(let fields) = body,
            case .array(let items)? = fields["input"]
        {
            return items
        }
        return nil
    }

    /// Unwraps `value` when it is a `.string`; returns `nil` for anything else,
    /// including a missing value.
    private func stringValue(_ value: JSONValue?) -> String? {
        if case .string(let text)? = value {
            return text
        }
        return nil
    }

    /// Returns the first object in `input` whose `"type"` entry is a string
    /// equal to `type`. Non-object elements are ignored.
    private func firstObject(withType type: String, in input: [JSONValue]) -> [String: JSONValue]? {
        for case .object(let fields) in input {
            if case .string(let foundType)? = fields["type"], foundType == type {
                return fields
            }
        }
        return nil
    }

    /// Reports whether `value` is an object that has an entry for `key`.
    private func containsKey(_ value: JSONValue, key: String) -> Bool {
        if case .object(let fields) = value {
            return fields[key] != nil
        }
        return false
    }

    /// Builds a prompt consisting of a single text segment.
    private func makePrompt(_ text: String = "Continue.") -> Transcript.Prompt {
        return Transcript.Prompt(segments: [.text(.init(content: text))])
    }

    /// Builds a transcript holding one `getWeather` tool call (id `call-1`,
    /// city `Paris`) followed by the default prompt.
    private func makeTranscriptWithToolCalls() throws -> Transcript {
        let parisArguments = try GeneratedContent(json: #"{"city":"Paris"}"#)
        let weatherCall = Transcript.ToolCall(
            id: "call-1",
            toolName: "getWeather",
            arguments: parisArguments
        )
        return Transcript(entries: [
            .toolCalls(Transcript.ToolCalls([weatherCall])),
            .prompt(makePrompt()),
        ])
    }
}
}

/// Runs `operation`, retrying it when it fails with a rate-limit HTTP error.
///
/// A failure is retried only when it is a `URLSessionError.httpError` whose
/// detail contains `"rate_limit_exceeded"` and fewer than `maxAttempts`
/// attempts have been made. Before retry number `n` the function sleeps for
/// `n` seconds (1s, 2s, 3s, ...). Every other error is rethrown unchanged.
///
/// - Parameters:
///   - maxAttempts: Upper bound on how many times `operation` is invoked.
///     The operation always runs at least once.
///   - operation: The async work to perform.
/// - Returns: The value produced by the first successful invocation.
/// - Throws: The last error thrown by `operation`, or any error thrown by
///   `Task.sleep` while waiting between attempts.
private func withOpenAIRateLimitRetry<T>(
    maxAttempts: Int = 4,
    operation: @escaping () async throws -> T
) async throws -> T {
    var attempt = 1
    while true {
        do {
            return try await operation()
        } catch let error as URLSessionError {
            // Anything that is not a retryable rate-limit failure ends the loop.
            guard
                attempt < maxAttempts,
                case .httpError(_, let detail) = error,
                detail.contains("rate_limit_exceeded")
            else {
                throw error
            }
        }

        // Only reached after a retryable failure: back off, then try again.
        let backoffNanoseconds = UInt64(attempt) * 1_000_000_000
        try await Task.sleep(nanoseconds: backoffNanoseconds)
        attempt += 1
    }
}