diff --git a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift index 8828751..db4eab3 100644 --- a/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/OpenAILanguageModel.swift @@ -1050,59 +1050,60 @@ private enum Responses { outputs.append(object) case .tool(let id): - let toolMessage = msg - // Wrap user content into a single top-level message as required by Responses API - var contentBlocks: [JSONValue] - switch toolMessage.content { + let outputValue: JSONValue + switch msg.content { case .text(let text): - contentBlocks = [ - .object(["type": .string("input_text"), "text": .string(text)]) - ] + outputValue = .string(text) case .blocks(let blocks): - contentBlocks = blocks.map { block in - switch block { - case .text(let text): - return .object(["type": .string("input_text"), "text": .string(text)]) - case .imageURL(let url): - return .object([ - "type": .string("input_image"), - "image_url": .object(["url": .string(url)]), - ]) + outputValue = .array( + blocks.map { block in + switch block { + case .text(let text): + return .object(["type": .string("input_text"), "text": .string(text)]) + case .imageURL(let url): + return .object([ + "type": .string("input_image"), + "image_url": .string(url), + ]) + } } - } - } - let outputString: String - if contentBlocks.count > 1 { - let encoder = JSONEncoder() - if let data = try? encoder.encode(JSONValue.array(contentBlocks)), - let str = String(data: data, encoding: .utf8) - { - outputString = str - } else { - outputString = "[]" - } - } else if let block = contentBlocks.first { - let encoder = JSONEncoder() - if let data = try? encoder.encode(block), - let str = String(data: data, encoding: .utf8) - { - outputString = str - } else { - outputString = "{}" - } - } else { - outputString = "{}" + ) } outputs.append( .object([ "type": .string("function_call_output"), "call_id": .string(id), - "output": .string(outputString), + "output": outputValue, ]) ) case .raw(rawContent: let rawContent): - outputs.append(rawContent) + // Convert Chat Completions assistant+tool_calls to Responses API function_call items + if case .object(let assistantMessageObject) = rawContent, + case .string(let messageRole) = assistantMessageObject["role"], + messageRole == "assistant", + case .array(let assistantToolCalls) = assistantMessageObject["tool_calls"] + { + for assistantToolCall in assistantToolCalls { + if case .object(let toolCallObject) = assistantToolCall, + case .string(let toolCallID) = toolCallObject["id"], + case .object(let functionCallObject) = toolCallObject["function"], + case .string(let functionName) = functionCallObject["name"], + case .string(let functionArguments) = functionCallObject["arguments"] + { + outputs.append( + .object([ + "type": .string("function_call"), + "call_id": .string(toolCallID), + "name": .string(functionName), + "arguments": .string(functionArguments), + ]) + ) + } + } + } else { + outputs.append(rawContent) + } case .system: let systemMessage = msg diff --git a/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift b/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift index 2a7a4b8..4b4d2a5 100644 --- a/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/OpenAILanguageModelTests.swift @@ -1,4 +1,5 @@ import Foundation +import JSONSchema import Testing @testable import AnyLanguageModel @@ -14,7 +15,6 @@ struct OpenAILanguageModelTests { } @Test func apiVariantParameterization() throws { - // Test that both API variants can be created and have correct properties for apiVariant in [OpenAILanguageModel.APIVariant.chatCompletions, .responses] { let model = OpenAILanguageModel(apiKey: "test-key", model: "test-model", apiVariant: apiVariant) #expect(model.apiVariant == apiVariant) @@ -97,7 +97,6 @@ struct OpenAILanguageModelTests { maximumResponseTokens: 50 ) - // Set custom options (extraBody will be merged into the request) options[custom: OpenAILanguageModel.self] = .init( extraBody: ["user": .string("test-user-id")] ) @@ -138,19 +137,77 @@ struct OpenAILanguageModelTests { } @Test func withTools() async throws { - let weatherTool = WeatherTool() + let weatherTool = spy(on: WeatherTool()) let session = LanguageModelSession(model: model, tools: [weatherTool]) - let response = try await session.respond(to: "How's the weather in San Francisco?") + var options = GenerationOptions() + options[custom: OpenAILanguageModel.self] = .init( + maxToolCalls: 1 + ) + + let response = try await withOpenAIRateLimitRetry { + try await session.respond( + to: "Call getWeather for San Francisco exactly once, then summarize in one sentence.", + options: options + ) + } + + #expect(!response.content.isEmpty) + let calls = await weatherTool.calls + #expect(!calls.isEmpty) + if let firstCall = calls.first { + #expect(firstCall.arguments.city.localizedCaseInsensitiveContains("san")) + } + var foundToolCall = false var foundToolOutput = false - for case let .toolOutput(toolOutput) in response.transcriptEntries { - #expect(toolOutput.toolName == "getWeather") - foundToolOutput = true + for entry in response.transcriptEntries { + switch entry { + case .toolCalls(let toolCalls): + #expect(!toolCalls.isEmpty) + if let firstToolCall = toolCalls.first { + #expect(firstToolCall.toolName == "getWeather") + } + foundToolCall = true + case .toolOutput(let toolOutput): + #expect(toolOutput.toolName == "getWeather") + foundToolOutput = true + default: + break + } } + #expect(foundToolCall) #expect(foundToolOutput) } + @Test func withToolsConversationContinuesAcrossTurns() async throws { + let weatherTool = spy(on: WeatherTool()) + let session = LanguageModelSession(model: model, tools: [weatherTool]) + + var options = GenerationOptions() + options[custom: OpenAILanguageModel.self] = .init( + maxToolCalls: 1 + ) + + _ = try await withOpenAIRateLimitRetry { + try await session.respond( + to: "Call getWeather for San Francisco exactly once, then reply with only: done", + options: options + ) + } + + let secondResponse = try await withOpenAIRateLimitRetry { + try await session.respond( + to: "Which city did the tool call use? Reply with city only." + ) + } + #expect(!secondResponse.content.isEmpty) + #expect(secondResponse.content.localizedCaseInsensitiveContains("san")) + + let calls = await weatherTool.calls + #expect(calls.count >= 1) + } + @Suite("Structured Output") struct StructuredOutputTests { @Generable @@ -316,7 +373,6 @@ struct OpenAILanguageModelTests { maximumResponseTokens: 50 ) - // Set custom options (extraBody will be merged into the request) options[custom: OpenAILanguageModel.self] = .init( extraBody: ["user": "test-user-id"] ) @@ -459,4 +515,73 @@ struct OpenAILanguageModelTests { } } } + + @Suite("OpenAILanguageModel Responses Request Body") + struct ResponsesRequestBodyTests { + private let model = "test-model" + + private func inputArray(from body: JSONValue) -> [JSONValue]? { + guard case let .object(obj) = body else { return nil } + guard case let .array(input)? = obj["input"] else { return nil } + return input + } + + private func stringValue(_ value: JSONValue?) -> String? { + guard case let .string(text)? = value else { return nil } + return text + } + + private func firstObject(withType type: String, in input: [JSONValue]) -> [String: JSONValue]? { + for value in input { + guard case let .object(obj) = value else { continue } + guard case let .string(foundType)? = obj["type"], foundType == type else { continue } + return obj + } + return nil + } + + private func containsKey(_ value: JSONValue, key: String) -> Bool { + guard case let .object(obj) = value else { return false } + return obj[key] != nil + } + + private func makePrompt(_ text: String = "Continue.") -> Transcript.Prompt { + Transcript.Prompt(segments: [.text(.init(content: text))]) + } + + private func makeTranscriptWithToolCalls() throws -> Transcript { + let arguments = try GeneratedContent(json: #"{"city":"Paris"}"#) + let call = Transcript.ToolCall(id: "call-1", toolName: "getWeather", arguments: arguments) + let toolCalls = Transcript.ToolCalls([call]) + return Transcript(entries: [ + .toolCalls(toolCalls), + .prompt(makePrompt()), + ]) + } + } +} + +private func withOpenAIRateLimitRetry( + maxAttempts: Int = 4, + operation: @escaping () async throws -> T +) async throws -> T { + var attempt = 1 + while true { + do { + return try await operation() + } catch let error as URLSessionError { + if case .httpError(_, let detail) = error, + detail.contains("rate_limit_exceeded"), + attempt < maxAttempts + { + let delaySeconds = UInt64(attempt) + try await Task.sleep(nanoseconds: delaySeconds * 1_000_000_000) + attempt += 1 + continue + } + throw error + } catch { + throw error + } + } }