From 90b78cb6f92687250a8cb3826da23d27421c0093 Mon Sep 17 00:00:00 2001 From: wuji1428 <2246065079@qq.com> Date: Thu, 5 Feb 2026 14:45:45 +0800 Subject: [PATCH] feat(AgentBase): Added concurrency conflict testing and optimized resource management Two new test methods were added to verify the agent's behavior under concurrent invocation, and resource management and exception handling mechanisms in the `AgentBase` class were improved. Related dependencies and import statements were also updated. --- .../io/agentscope/core/agent/AgentBase.java | 127 ++++++++++-------- .../agentscope/core/agent/AgentBaseTest.java | 25 ++++ .../agent/ReActAgentStructuredOutputTest.java | 92 +++++++++++++ 3 files changed, 190 insertions(+), 54 deletions(-) diff --git a/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java b/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java index 0d266fe08..34a3f099c 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java +++ b/agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java @@ -160,24 +160,29 @@ public final String getDescription() { */ @Override public final Mono call(List msgs) { - if (!running.compareAndSet(false, true) && checkRunning) { - return Mono.error( - new IllegalStateException( - "Agent is still running, please wait for it to finish")); - } - resetInterruptFlag(); - - return TracerRegistry.get() - .callAgent( - this, - msgs, - () -> - notifyPreCall(msgs) - .flatMap(this::doCall) - .flatMap(this::notifyPostCall) - .onErrorResume( - createErrorHandler(msgs.toArray(new Msg[0])))) - .doFinally(signalType -> running.set(false)); + return Mono.using( + () -> { + if (checkRunning && !running.compareAndSet(false, true)) { + throw new IllegalStateException( + "Agent is still running, please wait for it to finish"); + } + resetInterruptFlag(); + return this; + }, + resource -> + TracerRegistry.get() + .callAgent( + this, + msgs, + () -> + notifyPreCall(msgs) + .flatMap(this::doCall) + .flatMap(this::notifyPostCall) + .onErrorResume( + createErrorHandler( + msgs.toArray(new Msg[0])))), + resource -> running.set(false), + true); } /** @@ -191,24 +196,33 @@ public final Mono call(List msgs) { */ @Override public final Mono call(List msgs, Class structuredOutputClass) { - if (!running.compareAndSet(false, true) && checkRunning) { - return Mono.error( - new IllegalStateException( - "Agent is still running, please wait for it to finish")); - } - resetInterruptFlag(); - - return TracerRegistry.get() - .callAgent( - this, - msgs, - () -> - notifyPreCall(msgs) - .flatMap(m -> doCall(m, structuredOutputClass)) - .flatMap(this::notifyPostCall) - .onErrorResume( - createErrorHandler(msgs.toArray(new Msg[0])))) - .doFinally(signalType -> running.set(false)); + return Mono.using( + () -> { + if (checkRunning && !running.compareAndSet(false, true)) { + throw new IllegalStateException( + "Agent is still running, please wait for it to finish"); + } + resetInterruptFlag(); + return this; + }, + resource -> + TracerRegistry.get() + .callAgent( + this, + msgs, + () -> + notifyPreCall(msgs) + .flatMap( + m -> + doCall( + m, + structuredOutputClass)) + .flatMap(this::notifyPostCall) + .onErrorResume( + createErrorHandler( + msgs.toArray(new Msg[0])))), + resource -> running.set(false), + true); } /** @@ -222,24 +236,29 @@ public final Mono call(List msgs, Class structuredOutputClass) { */ @Override public final Mono call(List msgs, JsonNode schema) { - if (!running.compareAndSet(false, true) && checkRunning) { - return Mono.error( - new IllegalStateException( - "Agent is still running, please wait for it to finish")); - } - resetInterruptFlag(); - - return TracerRegistry.get() - .callAgent( - this, - msgs, - () -> - notifyPreCall(msgs) - .flatMap(m -> doCall(m, schema)) - .flatMap(this::notifyPostCall) - .onErrorResume( - createErrorHandler(msgs.toArray(new Msg[0])))) - .doFinally(signalType -> running.set(false)); + return Mono.using( + () -> { + if (checkRunning && !running.compareAndSet(false, true)) { + throw new IllegalStateException( + "Agent is still running, please wait for it to finish"); + } + resetInterruptFlag(); + return this; + }, + resource -> + TracerRegistry.get() + .callAgent( + this, + msgs, + () -> + notifyPreCall(msgs) + .flatMap(m -> doCall(m, schema)) + .flatMap(this::notifyPostCall) + .onErrorResume( + createErrorHandler( + msgs.toArray(new Msg[0])))), + resource -> running.set(false), + true); } /** diff --git a/agentscope-core/src/test/java/io/agentscope/core/agent/AgentBaseTest.java b/agentscope-core/src/test/java/io/agentscope/core/agent/AgentBaseTest.java index 25ef7b67f..b375c9d2e 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/agent/AgentBaseTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/agent/AgentBaseTest.java @@ -37,6 +37,7 @@ import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; /** * Unit tests for AgentBase class. @@ -177,6 +178,30 @@ void testMultipleMessageInput() { "Response should be from the agent"); } + @Test + @DisplayName("Should not trigger concurrency conflict") + void testConcurrencyConflict() { + Msg message = TestUtils.createUserMessage("User", "First message"); + // Get response + Msg responseMsg = + agent.call(message) + .subscribeOn(Schedulers.boundedElastic()) // mock chat model + .block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS)); + + // no IllegalStateException throw + Msg response2 = + agent.call(message).block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS)); + + // Verify response + assertNotNull(responseMsg, "Response should not be null"); + assertEquals( + TestConstants.TEST_AGENT_NAME, + responseMsg.getName(), + "Response should be from the agent"); + + assertNotNull(response2, "Response should not be null"); + } + @Test @DisplayName("Should handle observe without generating reply") void testObserve() { diff --git a/agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentStructuredOutputTest.java b/agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentStructuredOutputTest.java index 7de520e2d..be9df745e 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentStructuredOutputTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/agent/ReActAgentStructuredOutputTest.java @@ -21,6 +21,7 @@ import io.agentscope.core.ReActAgent; import io.agentscope.core.agent.test.MockModel; +import io.agentscope.core.agent.test.TestConstants; import io.agentscope.core.memory.InMemoryMemory; import io.agentscope.core.memory.Memory; import io.agentscope.core.message.Msg; @@ -32,10 +33,12 @@ import io.agentscope.core.model.ChatUsage; import io.agentscope.core.tool.Toolkit; import io.agentscope.core.util.JsonUtils; +import java.time.Duration; import java.util.List; import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import reactor.core.scheduler.Schedulers; class ReActAgentStructuredOutputTest { @@ -530,4 +533,93 @@ void testStructuredOutputPreservesThinkingBlock() { thinking.getThinking(), "Thinking content should be preserved"); } + + @Test + void testConcurrencyConflictStructuredOutput() { + Memory memory = new InMemoryMemory(); + Map toolInput = + Map.of( + "response", + Map.of( + "location", "San Francisco", + "temperature", "72°F", + "condition", "Sunny")); + + MockModel mockModel = + new MockModel( + msgs -> { + boolean hasToolResults = + msgs.stream().anyMatch(m -> m.getRole() == MsgRole.TOOL); + if (!hasToolResults) { + return List.of( + ChatResponse.builder() + .id("msg_1") + .content( + List.of( + ToolUseBlock.builder() + .id("call_123") + .name("generate_response") + .input(toolInput) + .content( + JsonUtils + .getJsonCodec() + .toJson( + toolInput)) + .build())) + .usage(new ChatUsage(10, 20, 30)) + .build()); + } else { + return List.of( + ChatResponse.builder() + .id("msg_2") + .content( + List.of( + TextBlock.builder() + .text("Done") + .build())) + .usage(new ChatUsage(5, 10, 15)) + .build()); + } + }); + + ReActAgent agent = + ReActAgent.builder() + .name("weather-agent") + .sysPrompt("You are a weather assistant") + .model(mockModel) + .toolkit(toolkit) + .memory(memory) + .build(); + + Msg inputMsg = + Msg.builder() + .name("user") + .role(MsgRole.USER) + .content( + TextBlock.builder() + .text("What's the weather in San Francisco?") + .build()) + .build(); + + Msg responseMsg = + agent.call(inputMsg, WeatherResponse.class) + .subscribeOn(Schedulers.boundedElastic()) + .block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS)); + + Msg response2 = + agent.call(inputMsg, WeatherResponse.class) + .block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS)); + + assertNotNull(responseMsg); + WeatherResponse result = responseMsg.getStructuredData(WeatherResponse.class); + assertNotNull(result); + assertEquals("San Francisco", result.location); + assertEquals("72°F", result.temperature); + assertEquals("Sunny", result.condition); + + assertNotNull(response2); + // no IllegalStateException throw + WeatherResponse result2 = response2.getStructuredData(WeatherResponse.class); + assertNotNull(result2); + } }