diff --git a/src/main/java/dev/ai4j/openai4j/StreamingRequestExecutor.java b/src/main/java/dev/ai4j/openai4j/StreamingRequestExecutor.java index 3ef3e62..0a4cde4 100644 --- a/src/main/java/dev/ai4j/openai4j/StreamingRequestExecutor.java +++ b/src/main/java/dev/ai4j/openai4j/StreamingRequestExecutor.java @@ -167,7 +167,6 @@ public void onEvent(EventSource eventSource, String id, String type, String data } if ("[DONE]".equals(data)) { - streamingCompletionCallback.run(); return; } @@ -192,6 +191,8 @@ public void onClosed(EventSource eventSource) { if (logStreamingResponses) { log.debug("onClosed()"); } + + streamingCompletionCallback.run(); } @Override @@ -208,7 +209,9 @@ public void onFailure(EventSource eventSource, Throwable t, okhttp3.Response res if (logStreamingResponses) { log.debug("onFailure()", t); - responseLogger.log(response); + if (response != null) { + responseLogger.log(response); + } } if (t != null) { diff --git a/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionAsyncTest.java b/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionAsyncTest.java index 10e5f06..2591600 100644 --- a/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionAsyncTest.java +++ b/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionAsyncTest.java @@ -6,6 +6,8 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; @@ -96,7 +98,7 @@ void testCustomizableApi(ChatCompletionModel model) throws Exception { void testTools(ChatCompletionModel model) throws Exception { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) @@ -118,7 +120,7 @@ void testTools(ChatCompletionModel model) throws Exception { AssistantMessage assistantMessage = response.choices().get(0).message(); assertThat(assistantMessage.content()).isNull(); assertThat(assistantMessage.functionCall()).isNull(); - assertThat(assistantMessage.toolCalls()).isNotNull().hasSize(1); + assertThat(assistantMessage.toolCalls()).isNotNull().hasSizeBetween(1, 2); ToolCall toolCall = assistantMessage.toolCalls().get(0); assertThat(toolCall.id()).isNotBlank(); @@ -139,9 +141,24 @@ void testTools(ChatCompletionModel model) throws Exception { String currentWeather = currentWeather(location, unit); ToolMessage toolMessage = ToolMessage.from(toolCall.id(), currentWeather); + List messages = new ArrayList<>(); + messages.add(userMessage); + messages.add(assistantMessage); + + for (ToolCall toolCall2 : assistantMessage.toolCalls()) { + FunctionCall functionCall2 = toolCall2.function(); + Map arguments2 = argumentsAsMap(functionCall2.arguments()); + + String location2 = argument("location", functionCall2); + String unit2 = argument("unit", functionCall2); + String currentWeather2 = currentWeather(location2, unit2); + ToolMessage toolMessage2 = ToolMessage.from(toolCall2.id(), currentWeather2); + messages.add(toolMessage2); + } + ChatCompletionRequest secondRequest = ChatCompletionRequest.builder() .model(model) - .messages(userMessage, assistantMessage, toolMessage) + .messages(messages) .build(); // when @@ -158,7 +175,7 @@ void testTools(ChatCompletionModel model) throws Exception { void testFunctions(ChatCompletionModel model) throws Exception { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) @@ -214,7 +231,7 @@ void testFunctions(ChatCompletionModel model) throws Exception { void testToolChoice(ChatCompletionModel model) throws Exception { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) @@ -276,7 +293,7 @@ void testToolChoice(ChatCompletionModel model) throws Exception { void testFunctionChoice(ChatCompletionModel model) throws Exception { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) diff --git a/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionStreamingTest.java b/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionStreamingTest.java index 335ba82..638e52e 100644 --- a/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionStreamingTest.java +++ b/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionStreamingTest.java @@ -7,7 +7,11 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; @@ -24,10 +28,13 @@ import static java.util.concurrent.Executors.newSingleThreadExecutor; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.params.provider.EnumSource.Mode.EXCLUDE; class ChatCompletionStreamingTest extends RateLimitAwareTest { + private static final Logger log = LoggerFactory.getLogger(ChatCompletionStreamingTest.class); + private final OpenAiClient client = OpenAiClient.builder() .baseUrl(System.getenv("OPENAI_BASE_URL")) .openAiApiKey(System.getenv("OPENAI_API_KEY")) @@ -121,7 +128,7 @@ void testCustomizableApi(ChatCompletionModel model) throws Exception { void testTools(ChatCompletionModel model) throws Exception { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) @@ -139,12 +146,17 @@ void testTools(ChatCompletionModel model) throws Exception { .onPartialResponse(partialResponse -> { Delta delta = partialResponse.choices().get(0).delta(); assertThat(delta.content()).isNull(); - assertThat(delta.functionCall()).isNull(); + assertThat(delta.functionCall()).isIn(null, ""); if (delta.toolCalls() != null) { assertThat(delta.toolCalls()).hasSize(1); ToolCall toolCall = delta.toolCalls().get(0); + if (toolCall.index() > 0) { + // skip other function candidates + return; + } + assertThat(toolCall.type()).isIn(null, FUNCTION); assertThat(toolCall.function()).isNotNull(); @@ -177,7 +189,7 @@ void testTools(ChatCompletionModel model) throws Exception { .onError(future::completeExceptionally) .execute(); - AssistantMessage assistantMessage = future.get(30, SECONDS); + AssistantMessage assistantMessage = future.get(120, SECONDS); // then assertThat(assistantMessage.content()).isNull(); @@ -203,9 +215,24 @@ void testTools(ChatCompletionModel model) throws Exception { String currentWeather = currentWeather(location, unit); ToolMessage toolMessage = ToolMessage.from(toolCall.id(), currentWeather); + List messages = new ArrayList<>(); + messages.add(userMessage); + messages.add(assistantMessage); + + for (ToolCall toolCall2 : assistantMessage.toolCalls()) { + FunctionCall functionCall2 = toolCall2.function(); + Map arguments2 = argumentsAsMap(functionCall2.arguments()); + + String location2 = argument("location", functionCall2); + String unit2 = argument("unit", functionCall2); + String currentWeather2 = currentWeather(location2, unit2); + ToolMessage toolMessage2 = ToolMessage.from(toolCall2.id(), currentWeather2); + messages.add(toolMessage2); + } + ChatCompletionRequest secondRequest = ChatCompletionRequest.builder() .model(model) - .messages(userMessage, assistantMessage, toolMessage) + .messages(messages) .build(); // when @@ -236,7 +263,7 @@ void testTools(ChatCompletionModel model) throws Exception { void testFunctions(ChatCompletionModel model) throws Exception { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) @@ -330,7 +357,7 @@ void testFunctions(ChatCompletionModel model) throws Exception { void testToolChoice(ChatCompletionModel model) throws Exception { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) @@ -445,7 +472,7 @@ void testToolChoice(ChatCompletionModel model) throws Exception { void testFunctionChoice(ChatCompletionModel model) throws Exception { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) @@ -774,45 +801,31 @@ void testCancelStreamingAfterStreamingStarted() throws Exception { .logStreamingResponses() .build(); - AtomicBoolean streamingStarted = new AtomicBoolean(false); - AtomicBoolean streamingCancelled = new AtomicBoolean(false); - AtomicBoolean cancellationSucceeded = new AtomicBoolean(true); + final AtomicBoolean streamingCancelled = new AtomicBoolean(false); + final AtomicReference atomicReference = new AtomicReference<>(); + final CompletableFuture completableFuture = new CompletableFuture<>(); ResponseHandle responseHandle = client.chatCompletion("Write a poem about AI in 10 words") .onPartialResponse(partialResponse -> { - streamingStarted.set(true); - System.out.println("[[streaming started]]"); - if (streamingCancelled.get()) { - cancellationSucceeded.set(false); - System.out.println("[[cancellation failed]]"); + if (! streamingCancelled.getAndSet(true)) { + log.info("onPartialResponse thread {}", Thread.currentThread()); + + CompletableFuture.runAsync(() -> { + log.info("cancelling thread {}", Thread.currentThread()); + atomicReference.get().cancel(); + completableFuture.complete(null); + }); } }) - .onComplete(() -> { - cancellationSucceeded.set(false); - System.out.println("[[cancellation failed]]"); - }) - .onError(e -> { - cancellationSucceeded.set(false); - System.out.println("[[cancellation failed]]"); - }) + .onComplete(() -> fail("Response completed")) + .onError(e -> fail("Response errored")) .execute(); - while (!streamingStarted.get()) { - Thread.sleep(10); - } + log.info("Test thread {}", Thread.currentThread()); + atomicReference.set(responseHandle); + completableFuture.get(); - newSingleThreadExecutor().execute(() -> { - responseHandle.cancel(); - streamingCancelled.set(true); - System.out.println("[[streaming cancelled]]"); - }); - - while (!streamingCancelled.get()) { - Thread.sleep(10); - } - Thread.sleep(2000); - - assertThat(cancellationSucceeded).isTrue(); + assertThat(streamingCancelled).isTrue(); } @Test diff --git a/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionTest.java b/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionTest.java index 251cc0c..07facd2 100644 --- a/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionTest.java +++ b/src/test/java/dev/ai4j/openai4j/chat/ChatCompletionTest.java @@ -6,7 +6,9 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; +import java.util.ArrayList; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import static dev.ai4j.openai4j.chat.ChatCompletionModel.GPT_4O_MINI; @@ -24,6 +26,7 @@ class ChatCompletionTest extends RateLimitAwareTest { static final String SYSTEM_MESSAGE = "Be concise"; static final String USER_MESSAGE = "Write exactly the following 2 words: 'hello world'"; + static final String WEATHER_PROMPT = "What is the weather in Boston in Celsius?"; static final String WEATHER_TOOL_NAME = "get_current_weather"; static final Function WEATHER_FUNCTION = Function.builder() @@ -102,7 +105,7 @@ void testCustomizableApi(ChatCompletionModel model) { void testTools(ChatCompletionModel model) { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) @@ -117,7 +120,7 @@ void testTools(ChatCompletionModel model) { AssistantMessage assistantMessage = response.choices().get(0).message(); assertThat(assistantMessage.content()).isNull(); assertThat(assistantMessage.functionCall()).isNull(); - assertThat(assistantMessage.toolCalls()).isNotNull().hasSize(1); + assertThat(assistantMessage.toolCalls()).isNotNull().hasSizeBetween(1, 2); ToolCall toolCall = assistantMessage.toolCalls().get(0); assertThat(toolCall.id()).isNotBlank(); @@ -138,9 +141,24 @@ void testTools(ChatCompletionModel model) { String currentWeather = currentWeather(location, unit); ToolMessage toolMessage = ToolMessage.from(toolCall.id(), currentWeather); + List messages = new ArrayList<>(); + messages.add(userMessage); + messages.add(assistantMessage); + + for (ToolCall toolCall2 : assistantMessage.toolCalls()) { + FunctionCall functionCall2 = toolCall2.function(); + Map arguments2 = argumentsAsMap(functionCall2.arguments()); + + String location2 = argument("location", functionCall2); + String unit2 = argument("unit", functionCall2); + String currentWeather2 = currentWeather(location2, unit2); + ToolMessage toolMessage2 = ToolMessage.from(toolCall2.id(), currentWeather2); + messages.add(toolMessage2); + } + ChatCompletionRequest secondRequest = ChatCompletionRequest.builder() .model(model) - .messages(userMessage, assistantMessage, toolMessage) + .messages(messages) .build(); // when @@ -282,7 +300,7 @@ void testToolWithoutParameters() { void testFunctions(ChatCompletionModel model) { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) @@ -331,7 +349,7 @@ void testFunctions(ChatCompletionModel model) { void testToolChoice(ChatCompletionModel model) { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) @@ -386,7 +404,7 @@ void testToolChoice(ChatCompletionModel model) { void testFunctionChoice(ChatCompletionModel model) { // given - UserMessage userMessage = UserMessage.from("What is the weather in Boston?"); + UserMessage userMessage = UserMessage.from(WEATHER_PROMPT); ChatCompletionRequest request = ChatCompletionRequest.builder() .model(model) diff --git a/src/test/java/dev/ai4j/openai4j/completion/CompletionStreamingTest.java b/src/test/java/dev/ai4j/openai4j/completion/CompletionStreamingTest.java index 1a962ee..1e85157 100644 --- a/src/test/java/dev/ai4j/openai4j/completion/CompletionStreamingTest.java +++ b/src/test/java/dev/ai4j/openai4j/completion/CompletionStreamingTest.java @@ -7,10 +7,12 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import static java.util.concurrent.Executors.newSingleThreadExecutor; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; class CompletionStreamingTest extends RateLimitAwareTest { @@ -65,7 +67,7 @@ void testCustomizableApi() throws Exception { } @Test - void testCancelStreamingAfterStreamingStarted() throws InterruptedException { + void testCancelStreamingAfterStreamingStarted() throws Exception { OpenAiClient client = OpenAiClient.builder() // without caching @@ -75,45 +77,27 @@ void testCancelStreamingAfterStreamingStarted() throws InterruptedException { .logStreamingResponses() .build(); - AtomicBoolean streamingStarted = new AtomicBoolean(false); - AtomicBoolean streamingCancelled = new AtomicBoolean(false); - AtomicBoolean cancellationSucceeded = new AtomicBoolean(true); + final AtomicBoolean streamingCancelled = new AtomicBoolean(false); + final AtomicReference atomicReference = new AtomicReference<>(); + final CompletableFuture completableFuture = new CompletableFuture<>(); ResponseHandle responseHandle = client.completion("Write a poem about AI in 10 words") .onPartialResponse(partialResponse -> { - streamingStarted.set(true); - System.out.println("[[streaming started]]"); - if (streamingCancelled.get()) { - cancellationSucceeded.set(false); - System.out.println("[[cancellation failed]]"); + if (! streamingCancelled.getAndSet(true)) { + CompletableFuture.runAsync(() -> { + atomicReference.get().cancel(); + completableFuture.complete(null); + }); } }) - .onComplete(() -> { - cancellationSucceeded.set(false); - System.out.println("[[cancellation failed]]"); - }) - .onError(e -> { - cancellationSucceeded.set(false); - System.out.println("[[cancellation failed]]"); - }) + .onComplete(() -> fail("Response completed")) + .onError(e -> fail("Response errored")) .execute(); - while (!streamingStarted.get()) { - Thread.sleep(10); - } - - newSingleThreadExecutor().execute(() -> { - responseHandle.cancel(); - streamingCancelled.set(true); - System.out.println("[[streaming cancelled]]"); - }); - - while (!streamingCancelled.get()) { - Thread.sleep(10); - } - Thread.sleep(2000); + atomicReference.set(responseHandle); + completableFuture.get(); - assertThat(cancellationSucceeded).isTrue(); + assertThat(streamingCancelled).isTrue(); } @Test