diff --git a/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java b/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java index 0d266fe08..34a3f099c 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java +++ b/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java @@ -160,24 +160,29 @@ public final String getDescription() { */ @Override public final Mono call(List msgs) { - if (!running.compareAndSet(false, true) && checkRunning) { - return Mono.error( - new IllegalStateException( - "Agent is still running, please wait for it to finish")); - } - resetInterruptFlag(); - - return TracerRegistry.get() - .callAgent( - this, - msgs, - () -> - notifyPreCall(msgs) - .flatMap(this::doCall) - .flatMap(this::notifyPostCall) - .onErrorResume( - createErrorHandler(msgs.toArray(new Msg[0])))) - .doFinally(signalType -> running.set(false)); + return Mono.using( + () -> { + if (checkRunning && !running.compareAndSet(false, true)) { + throw new IllegalStateException( + "Agent is still running, please wait for it to finish"); + } + resetInterruptFlag(); + return this; + }, + resource -> + TracerRegistry.get() + .callAgent( + this, + msgs, + () -> + notifyPreCall(msgs) + .flatMap(this::doCall) + .flatMap(this::notifyPostCall) + .onErrorResume( + createErrorHandler( + msgs.toArray(new Msg[0])))), + resource -> running.set(false), + true); } /** @@ -191,24 +196,33 @@ public final Mono call(List msgs) { */ @Override public final Mono call(List msgs, Class structuredOutputClass) { - if (!running.compareAndSet(false, true) && checkRunning) { - return Mono.error( - new IllegalStateException( - "Agent is still running, please wait for it to finish")); - } - resetInterruptFlag(); - - return TracerRegistry.get() - .callAgent( - this, - msgs, - () -> - notifyPreCall(msgs) - .flatMap(m -> doCall(m, structuredOutputClass)) - .flatMap(this::notifyPostCall) - .onErrorResume( - createErrorHandler(msgs.toArray(new Msg[0])))) - .doFinally(signalType -> running.set(false)); + return Mono.using( + () -> { + if (checkRunning && !running.compareAndSet(false, true)) { + throw new IllegalStateException( + "Agent is still running, please wait for it to finish"); + } + resetInterruptFlag(); + return this; + }, + resource -> + TracerRegistry.get() + .callAgent( + this, + msgs, + () -> + notifyPreCall(msgs) + .flatMap( + m -> + doCall( + m, + structuredOutputClass)) + .flatMap(this::notifyPostCall) + .onErrorResume( + createErrorHandler( + msgs.toArray(new Msg[0])))), + resource -> running.set(false), + true); } /** @@ -222,24 +236,29 @@ public final Mono call(List msgs, Class structuredOutputClass) { */ @Override public final Mono call(List msgs, JsonNode schema) { - if (!running.compareAndSet(false, true) && checkRunning) { - return Mono.error( - new IllegalStateException( - "Agent is still running, please wait for it to finish")); - } - resetInterruptFlag(); - - return TracerRegistry.get() - .callAgent( - this, - msgs, - () -> - notifyPreCall(msgs) - .flatMap(m -> doCall(m, schema)) - .flatMap(this::notifyPostCall) - .onErrorResume( - createErrorHandler(msgs.toArray(new Msg[0])))) - .doFinally(signalType -> running.set(false)); + return Mono.using( + () -> { + if (checkRunning && !running.compareAndSet(false, true)) { + throw new IllegalStateException( + "Agent is still running, please wait for it to finish"); + } + resetInterruptFlag(); + return this; + }, + resource -> + TracerRegistry.get() + .callAgent( + this, + msgs, + () -> + notifyPreCall(msgs) + .flatMap(m -> doCall(m, schema)) + .flatMap(this::notifyPostCall) + .onErrorResume( + createErrorHandler( + msgs.toArray(new Msg[0])))), + resource -> running.set(false), + true); } /** diff --git a/agentscope-core/src/test/java/io/agentscope/core/agent/AgentBaseTest.java b/agentscope-core/src/test/java/io/agentscope/core/agent/AgentBaseTest.java index 25ef7b67f..b375c9d2e 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/agent/AgentBaseTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/agent/AgentBaseTest.java @@ -37,6 +37,7 @@ import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; /** * Unit tests for AgentBase class. @@ -177,6 +178,30 @@ void testMultipleMessageInput() { "Response should be from the agent"); } + @Test + @DisplayName("Should not trigger concurrency conflict") + void testConcurrencyConflict() { + Msg message = TestUtils.createUserMessage("User", "First message"); + // Get response + Msg responseMsg = + agent.call(message) + .subscribeOn(Schedulers.boundedElastic()) // mock chat model + .block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS)); + + // no IllegalStateException throw + Msg response2 = + agent.call(message).block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS)); + + // Verify response + assertNotNull(responseMsg, "Response should not be null"); + assertEquals( + TestConstants.TEST_AGENT_NAME, + responseMsg.getName(), + "Response should be from the agent"); + + assertNotNull(response2, "Response should not be null"); + } + @Test @DisplayName("Should handle observe without generating reply") void testObserve() { diff --git a/agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentStructuredOutputTest.java b/agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentStructuredOutputTest.java index 7de520e2d..be9df745e 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentStructuredOutputTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentStructuredOutputTest.java @@ -21,6 +21,7 @@ import io.agentscope.core.ReActAgent; import io.agentscope.core.agent.test.MockModel; +import io.agentscope.core.agent.test.TestConstants; import io.agentscope.core.memory.InMemoryMemory; import io.agentscope.core.memory.Memory; import io.agentscope.core.message.Msg; @@ -32,10 +33,12 @@ import io.agentscope.core.model.ChatUsage; import io.agentscope.core.tool.Toolkit; import io.agentscope.core.util.JsonUtils; +import java.time.Duration; import java.util.List; import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.scheduler.Schedulers; class ReActAgentStructuredOutputTest { @@ -530,4 +533,93 @@ void testStructuredOutputPreservesThinkingBlock() { thinking.getThinking(), "Thinking content should be preserved"); } + + @Test + void testConcurrencyConflictStructuredOutput() { + Memory memory = new InMemoryMemory(); + Map toolInput = + Map.of( + "response", + Map.of( + "location", "San Francisco", + "temperature", "72°F", + "condition", "Sunny")); + + MockModel mockModel = + new MockModel( + msgs -> { + boolean hasToolResults = + msgs.stream().anyMatch(m -> m.getRole() == MsgRole.TOOL); + if (!hasToolResults) { + return List.of( + ChatResponse.builder() + .id("msg_1") + .content( + List.of( + ToolUseBlock.builder() + .id("call_123") + .name("generate_response") + .input(toolInput) + .content( + JsonUtils + .getJsonCodec() + .toJson( + toolInput)) + .build())) + .usage(new ChatUsage(10, 20, 30)) + .build()); + } else { + return List.of( + ChatResponse.builder() + .id("msg_2") + .content( + List.of( + TextBlock.builder() + .text("Done") + .build())) + .usage(new ChatUsage(5, 10, 15)) + .build()); + } + }); + + ReActAgent agent = + ReActAgent.builder() + .name("weather-agent") + .sysPrompt("You are a weather assistant") + .model(mockModel) + .toolkit(toolkit) + .memory(memory) + .build(); + + Msg inputMsg = + Msg.builder() + .name("user") + .role(MsgRole.USER) + .content( + TextBlock.builder() + .text("What's the weather in San Francisco?") + .build()) + .build(); + + Msg responseMsg = + agent.call(inputMsg, WeatherResponse.class) + .subscribeOn(Schedulers.boundedElastic()) + .block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS)); + + Msg response2 = + agent.call(inputMsg, WeatherResponse.class) + .block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS)); + + assertNotNull(responseMsg); + WeatherResponse result = responseMsg.getStructuredData(WeatherResponse.class); + assertNotNull(result); + assertEquals("San Francisco", result.location); + assertEquals("72°F", result.temperature); + assertEquals("Sunny", result.condition); + + assertNotNull(response2); + // no IllegalStateException throw + WeatherResponse result2 = response2.getStructuredData(WeatherResponse.class); + assertNotNull(result2); + } }