Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 73 additions & 54 deletions agentscope-core/src/main/java/io/agentscope/core/agent/AgentBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,29 @@ public final String getDescription() {
*/
@Override
public final Mono<Msg> call(List<Msg> msgs) {
if (!running.compareAndSet(false, true) && checkRunning) {
return Mono.error(
new IllegalStateException(
"Agent is still running, please wait for it to finish"));
}
resetInterruptFlag();

return TracerRegistry.get()
.callAgent(
this,
msgs,
() ->
notifyPreCall(msgs)
.flatMap(this::doCall)
.flatMap(this::notifyPostCall)
.onErrorResume(
createErrorHandler(msgs.toArray(new Msg[0]))))
.doFinally(signalType -> running.set(false));
return Mono.using(
() -> {
if (checkRunning && !running.compareAndSet(false, true)) {
throw new IllegalStateException(
"Agent is still running, please wait for it to finish");
}
resetInterruptFlag();
return this;
},
resource ->
TracerRegistry.get()
.callAgent(
this,
msgs,
() ->
notifyPreCall(msgs)
.flatMap(this::doCall)
.flatMap(this::notifyPostCall)
.onErrorResume(
createErrorHandler(
msgs.toArray(new Msg[0])))),
resource -> running.set(false),
true);
Comment on lines +163 to +185

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The three overloaded call methods (call(List<Msg>), call(List<Msg>, Class<?>), and call(List<Msg>, JsonNode)) share nearly identical logic for handling the agent's running state using Mono.using. This code duplication can be eliminated by extracting the common logic into a private helper method. This would improve maintainability, as future changes to the lock management would only need to be made in one place.

Here is a suggested refactoring that would replace all three call methods and add a new private executeCall method:

    @Override
    public final Mono<Msg> call(List<Msg> msgs) {
        return executeCall(msgs, this::doCall);
    }

    @Override
    public final Mono<Msg> call(List<Msg> msgs, Class<?> structuredOutputClass) {
        return executeCall(msgs, m -> doCall(m, structuredOutputClass));
    }

    @Override
    public final Mono<Msg> call(List<Msg> msgs, JsonNode schema) {
        return executeCall(msgs, m -> doCall(m, schema));
    }

    private Mono<Msg> executeCall(List<Msg> msgs, Function<List<Msg>, Mono<Msg>> doCallLogic) {
        return Mono.using(
                () -> {
                    if (checkRunning && !running.compareAndSet(false, true)) {
                        throw new IllegalStateException(
                                "Agent is still running, please wait for it to finish");
                    }
                    resetInterruptFlag();
                    return this;
                },
                resource ->
                        TracerRegistry.get()
                                .callAgent(
                                        this,
                                        msgs,
                                        () ->
                                                notifyPreCall(msgs)
                                                        .flatMap(doCallLogic)
                                                        .flatMap(this::notifyPostCall)
                                                        .onErrorResume(
                                                                createErrorHandler(
                                                                        msgs.toArray(new Msg[0])))),
                resource -> running.set(false),
                true);
    }

Since a single code suggestion cannot span all three methods across different parts of the file, I'm providing the full refactoring here for your consideration.

}

/**
Expand All @@ -191,24 +196,33 @@ public final Mono<Msg> call(List<Msg> msgs) {
*/
@Override
public final Mono<Msg> call(List<Msg> msgs, Class<?> structuredOutputClass) {
if (!running.compareAndSet(false, true) && checkRunning) {
return Mono.error(
new IllegalStateException(
"Agent is still running, please wait for it to finish"));
}
resetInterruptFlag();

return TracerRegistry.get()
.callAgent(
this,
msgs,
() ->
notifyPreCall(msgs)
.flatMap(m -> doCall(m, structuredOutputClass))
.flatMap(this::notifyPostCall)
.onErrorResume(
createErrorHandler(msgs.toArray(new Msg[0]))))
.doFinally(signalType -> running.set(false));
return Mono.using(
() -> {
if (checkRunning && !running.compareAndSet(false, true)) {
throw new IllegalStateException(
"Agent is still running, please wait for it to finish");
}
resetInterruptFlag();
return this;
},
resource ->
TracerRegistry.get()
.callAgent(
this,
msgs,
() ->
notifyPreCall(msgs)
.flatMap(
m ->
doCall(
m,
structuredOutputClass))
.flatMap(this::notifyPostCall)
.onErrorResume(
createErrorHandler(
msgs.toArray(new Msg[0])))),
resource -> running.set(false),
true);
}

/**
Expand All @@ -222,24 +236,29 @@ public final Mono<Msg> call(List<Msg> msgs, Class<?> structuredOutputClass) {
*/
@Override
public final Mono<Msg> call(List<Msg> msgs, JsonNode schema) {
if (!running.compareAndSet(false, true) && checkRunning) {
return Mono.error(
new IllegalStateException(
"Agent is still running, please wait for it to finish"));
}
resetInterruptFlag();

return TracerRegistry.get()
.callAgent(
this,
msgs,
() ->
notifyPreCall(msgs)
.flatMap(m -> doCall(m, schema))
.flatMap(this::notifyPostCall)
.onErrorResume(
createErrorHandler(msgs.toArray(new Msg[0]))))
.doFinally(signalType -> running.set(false));
return Mono.using(
() -> {
if (checkRunning && !running.compareAndSet(false, true)) {
throw new IllegalStateException(
"Agent is still running, please wait for it to finish");
}
resetInterruptFlag();
return this;
},
resource ->
TracerRegistry.get()
.callAgent(
this,
msgs,
() ->
notifyPreCall(msgs)
.flatMap(m -> doCall(m, schema))
.flatMap(this::notifyPostCall)
.onErrorResume(
createErrorHandler(
msgs.toArray(new Msg[0])))),
resource -> running.set(false),
true);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/**
* Unit tests for AgentBase class.
Expand Down Expand Up @@ -177,6 +178,30 @@ void testMultipleMessageInput() {
"Response should be from the agent");
}

@Test
@DisplayName("Should not trigger concurrency conflict")
void testConcurrencyConflict() {
Msg message = TestUtils.createUserMessage("User", "First message");
// Get response
Msg responseMsg =
agent.call(message)
.subscribeOn(Schedulers.boundedElastic()) // mock chat model
.block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS));

// no IllegalStateException throw
Msg response2 =
agent.call(message).block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS));

// Verify response
assertNotNull(responseMsg, "Response should not be null");
assertEquals(
TestConstants.TEST_AGENT_NAME,
responseMsg.getName(),
"Response should be from the agent");

assertNotNull(response2, "Response should not be null");
}

@Test
@DisplayName("Should handle observe without generating reply")
void testObserve() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import io.agentscope.core.ReActAgent;
import io.agentscope.core.agent.test.MockModel;
import io.agentscope.core.agent.test.TestConstants;
import io.agentscope.core.memory.InMemoryMemory;
import io.agentscope.core.memory.Memory;
import io.agentscope.core.message.Msg;
Expand All @@ -32,10 +33,12 @@
import io.agentscope.core.model.ChatUsage;
import io.agentscope.core.tool.Toolkit;
import io.agentscope.core.util.JsonUtils;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import reactor.core.scheduler.Schedulers;

class ReActAgentStructuredOutputTest {

Expand Down Expand Up @@ -530,4 +533,93 @@ void testStructuredOutputPreservesThinkingBlock() {
thinking.getThinking(),
"Thinking content should be preserved");
}

@Test
void testConcurrencyConflictStructuredOutput() {
Memory memory = new InMemoryMemory();
Map<String, Object> toolInput =
Map.of(
"response",
Map.of(
"location", "San Francisco",
"temperature", "72°F",
"condition", "Sunny"));

MockModel mockModel =
new MockModel(
msgs -> {
boolean hasToolResults =
msgs.stream().anyMatch(m -> m.getRole() == MsgRole.TOOL);
if (!hasToolResults) {
return List.of(
ChatResponse.builder()
.id("msg_1")
.content(
List.of(
ToolUseBlock.builder()
.id("call_123")
.name("generate_response")
.input(toolInput)
.content(
JsonUtils
.getJsonCodec()
.toJson(
toolInput))
.build()))
.usage(new ChatUsage(10, 20, 30))
.build());
} else {
return List.of(
ChatResponse.builder()
.id("msg_2")
.content(
List.of(
TextBlock.builder()
.text("Done")
.build()))
.usage(new ChatUsage(5, 10, 15))
.build());
}
});

ReActAgent agent =
ReActAgent.builder()
.name("weather-agent")
.sysPrompt("You are a weather assistant")
.model(mockModel)
.toolkit(toolkit)
.memory(memory)
.build();

Msg inputMsg =
Msg.builder()
.name("user")
.role(MsgRole.USER)
.content(
TextBlock.builder()
.text("What's the weather in San Francisco?")
.build())
.build();

Msg responseMsg =
agent.call(inputMsg, WeatherResponse.class)
.subscribeOn(Schedulers.boundedElastic())
.block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS));

Msg response2 =
agent.call(inputMsg, WeatherResponse.class)
.block(Duration.ofMillis(TestConstants.DEFAULT_TEST_TIMEOUT_MS));

assertNotNull(responseMsg);
WeatherResponse result = responseMsg.getStructuredData(WeatherResponse.class);
assertNotNull(result);
assertEquals("San Francisco", result.location);
assertEquals("72°F", result.temperature);
assertEquals("Sunny", result.condition);

assertNotNull(response2);
// no IllegalStateException throw
WeatherResponse result2 = response2.getStructuredData(WeatherResponse.class);
assertNotNull(result2);
}
}
Loading