diff --git a/sdk/ai/azure-ai-agents/CHANGELOG.md b/sdk/ai/azure-ai-agents/CHANGELOG.md index eab78e4f5981..92ba51b1c61b 100644 --- a/sdk/ai/azure-ai-agents/CHANGELOG.md +++ b/sdk/ai/azure-ai-agents/CHANGELOG.md @@ -32,10 +32,13 @@ ### Bugs Fixed +- Fixed Memory Stores long-running operations (e.g. `beginUpdateMemories`) failing because the required `Foundry-Features` header was not included in poll requests, and custom LRO terminal states (`"completed"`, `"superseded"`) were not mapped to standard `LongRunningOperationStatus` values, causing pollers to hang indefinitely. - Fixed request parameter name from `"agent"` to `"agent_reference"` in `ResponsesClient` and `ResponsesAsyncClient` methods `createWithAgent` and `createWithAgentConversation` ### Other Changes +- Enabled and stabilised `MemoryStoresTests` and `MemoryStoresAsyncTests` (previously `@Disabled`), with timeout guards to prevent hanging. + ## 2.0.0-beta.1 (2026-02-25) ### Features Added diff --git a/sdk/ai/azure-ai-agents/assets.json b/sdk/ai/azure-ai-agents/assets.json index a6061845a097..08d8cb4c7c20 100644 --- a/sdk/ai/azure-ai-agents/assets.json +++ b/sdk/ai/azure-ai-agents/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "java", "TagPrefix": "java/ai/azure-ai-agents", - "Tag": "java/ai/azure-ai-agents_e4777fbd74" + "Tag": "java/ai/azure-ai-agents_34d0d1c5d4" } \ No newline at end of file diff --git a/sdk/ai/azure-ai-agents/customizations/src/main/java/AgentsCustomizations.java b/sdk/ai/azure-ai-agents/customizations/src/main/java/AgentsCustomizations.java index d72807757c79..592020d63a72 100644 --- a/sdk/ai/azure-ai-agents/customizations/src/main/java/AgentsCustomizations.java +++ b/sdk/ai/azure-ai-agents/customizations/src/main/java/AgentsCustomizations.java @@ -1,5 +1,7 @@ import com.azure.autorest.customization.Customization; import com.azure.autorest.customization.LibraryCustomization; +import com.github.javaparser.StaticJavaParser; +import com.github.javaparser.ast.body.MethodDeclaration; import org.slf4j.Logger; @@ -12,6 +14,7 @@ public class AgentsCustomizations extends Customization { @Override public void customize(LibraryCustomization libraryCustomization, Logger logger) { renameImageGenToolSize(libraryCustomization, logger); + modifyPollingStrategies(libraryCustomization, logger); } private void renameImageGenToolSize(LibraryCustomization customization, Logger logger) { @@ -30,4 +33,24 @@ private void renameImageGenToolSize(LibraryCustomization customization, Logger l .filter(entry -> "ONE_FIVE_THREE_SIXX_ONE_ZERO_TWO_FOUR".equals(entry.getName().getIdentifier())) .forEach(entry -> entry.setName("RESOLUTION_1536_X_1024")))); } + + private void modifyPollingStrategies(LibraryCustomization customization, Logger logger) { + customization.getClass("com.azure.ai.agents.implementation", "OperationLocationPollingStrategy") + .customizeAst(ast -> ast.getClassByName("OperationLocationPollingStrategy") + .ifPresent(clazz -> { + clazz.getConstructors().get(1).getBody().getStatements() + .set(0, StaticJavaParser.parseStatement("super(PollingUtils.OPERATION_LOCATION_HEADER, AgentsServicePollUtils.withFoundryFeatures(pollingStrategyOptions));")); + + clazz.addMember(StaticJavaParser.parseMethodDeclaration("@Override public Mono> poll(PollingContext pollingContext, TypeReference pollResponseType) { return super.poll(pollingContext, pollResponseType).map(AgentsServicePollUtils::remapStatus); }")); + })); + + customization.getClass("com.azure.ai.agents.implementation", "SyncOperationLocationPollingStrategy") + .customizeAst(ast -> ast.getClassByName("SyncOperationLocationPollingStrategy") + .ifPresent(clazz -> { + clazz.getConstructors().get(1).getBody().getStatements() + .set(0, StaticJavaParser.parseStatement("super(PollingUtils.OPERATION_LOCATION_HEADER, AgentsServicePollUtils.withFoundryFeatures(pollingStrategyOptions));")); + + clazz.addMember(StaticJavaParser.parseMethodDeclaration("@Override public PollResponse poll(PollingContext pollingContext, TypeReference pollResponseType) { return AgentsServicePollUtils.remapStatus(super.poll(pollingContext, pollResponseType)); }")); + })); + } } diff --git a/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/AgentsServicePollUtils.java b/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/AgentsServicePollUtils.java new file mode 100644 index 000000000000..b6dc84725be1 --- /dev/null +++ b/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/AgentsServicePollUtils.java @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.ai.agents.implementation; + +import com.azure.ai.agents.models.FoundryFeaturesOptInKeys; +import com.azure.ai.agents.models.MemoryStoreUpdateStatus; +import com.azure.core.http.HttpHeaderName; +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.policy.AddHeadersFromContextPolicy; +import com.azure.core.util.Context; +import com.azure.core.util.polling.LongRunningOperationStatus; +import com.azure.core.util.polling.PollResponse; +import com.azure.core.util.polling.PollingStrategyOptions; + +/** + * Shared polling helpers for the Agents SDK. + * + *

The generated {@code OperationLocationPollingStrategy} / {@code SyncOperationLocationPollingStrategy} + * delegate here so that the two strategies stay in sync and only minimal edits are needed in the + * generated files.

+ * + *

This class is package-private; it is not part of the public API.

+ */ +final class AgentsServicePollUtils { + + /** Required preview-feature header for Memory Stores operations. */ + private static final HttpHeaderName FOUNDRY_FEATURES = HttpHeaderName.fromString("Foundry-Features"); + private static final String FOUNDRY_FEATURES_VALUE = FoundryFeaturesOptInKeys.MEMORY_STORES_V1_PREVIEW.toString(); + + private AgentsServicePollUtils() { + } + + /** + * Adds the {@code Foundry-Features} header to the given {@link PollingStrategyOptions}'s + * {@link Context}. If the context already carries {@link HttpHeaders} under the + * {@link AddHeadersFromContextPolicy} key they are preserved; the {@code Foundry-Features} + * entry is merged in. Because the pipeline already contains + * {@link AddHeadersFromContextPolicy}, the header is automatically added to every HTTP + * request the parent strategy makes (initial, poll, and final-result GETs). + * + *

Note: this method mutates and returns the same + * {@code PollingStrategyOptions} instance.

+ */ + static PollingStrategyOptions withFoundryFeatures(PollingStrategyOptions options) { + Context context = options.getContext() != null ? options.getContext() : Context.NONE; + Object existing = context.getData(AddHeadersFromContextPolicy.AZURE_REQUEST_HTTP_HEADERS_KEY).orElse(null); + HttpHeaders headers + = (existing instanceof HttpHeaders) ? new HttpHeaders((HttpHeaders) existing) : new HttpHeaders(); + headers.set(FOUNDRY_FEATURES, FOUNDRY_FEATURES_VALUE); + return options.setContext(context.addData(AddHeadersFromContextPolicy.AZURE_REQUEST_HTTP_HEADERS_KEY, headers)); + } + + /** + * Remaps a {@link PollResponse} whose status may contain a custom service terminal state + * ({@code "completed"}, {@code "superseded"}) that the base {@code OperationResourcePollingStrategy} + * cannot recognize. If no remapping is needed the original response is returned as-is. + * + *

The Memory Stores Azure core defines:

+ *
    + *
  • {@code "completed"} {@link LongRunningOperationStatus#SUCCESSFULLY_COMPLETED}
  • + *
  • {@code "superseded"} {@link LongRunningOperationStatus#USER_CANCELLED}
  • + *
+ */ + static PollResponse remapStatus(PollResponse response) { + LongRunningOperationStatus status = response.getStatus(); + LongRunningOperationStatus mapped = mapCustomStatus(status); + if (mapped == status) { + return response; + } + return new PollResponse<>(mapped, response.getValue(), response.getRetryAfter()); + } + + private static LongRunningOperationStatus mapCustomStatus(LongRunningOperationStatus status) { + // Standard statuses (Succeeded, Failed, Canceled, InProgress, NotStarted) are already + // mapped correctly by the parent's PollResult; only remap the custom ones. + String name = status.toString(); + if (MemoryStoreUpdateStatus.COMPLETED.toString().equalsIgnoreCase(name)) { + return LongRunningOperationStatus.SUCCESSFULLY_COMPLETED; + } else if (MemoryStoreUpdateStatus.SUPERSEDED.toString().equalsIgnoreCase(name)) { + return LongRunningOperationStatus.USER_CANCELLED; + } + return status; + } +} diff --git a/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/OperationLocationPollingStrategy.java b/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/OperationLocationPollingStrategy.java index df8351347821..6dc846c0de9b 100644 --- a/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/OperationLocationPollingStrategy.java +++ b/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/OperationLocationPollingStrategy.java @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // Code generated by Microsoft (R) TypeSpec Code Generator. - package com.azure.ai.agents.implementation; import com.azure.core.exception.AzureException; @@ -22,7 +21,6 @@ import reactor.core.publisher.Mono; // DO NOT modify this helper class - /** * Implements an operation location polling strategy, from Operation-Location. * @@ -35,7 +33,9 @@ public final class OperationLocationPollingStrategy extends OperationResou private static final ClientLogger LOGGER = new ClientLogger(OperationLocationPollingStrategy.class); private final ObjectSerializer serializer; + private final String endpoint; + private final String propertyName; /** @@ -56,7 +56,8 @@ public OperationLocationPollingStrategy(PollingStrategyOptions pollingStrategyOp * @throws NullPointerException if {@code pollingStrategyOptions} is null. */ public OperationLocationPollingStrategy(PollingStrategyOptions pollingStrategyOptions, String propertyName) { - super(PollingUtils.OPERATION_LOCATION_HEADER, pollingStrategyOptions); + super(PollingUtils.OPERATION_LOCATION_HEADER, + AgentsServicePollUtils.withFoundryFeatures(pollingStrategyOptions)); this.propertyName = propertyName; this.endpoint = pollingStrategyOptions.getEndpoint(); this.serializer = pollingStrategyOptions.getSerializer() != null @@ -71,7 +72,6 @@ public OperationLocationPollingStrategy(PollingStrategyOptions pollingStrategyOp public Mono> onInitialResponse(Response response, PollingContext pollingContext, TypeReference pollResponseType) { // Response is Response - HttpHeader operationLocationHeader = response.getHeaders().get(PollingUtils.OPERATION_LOCATION_HEADER); if (operationLocationHeader != null) { pollingContext.setData(PollingUtils.OPERATION_LOCATION_HEADER.getCaseSensitiveName(), @@ -80,7 +80,6 @@ public Mono> onInitialResponse(Response response, PollingCont final String httpMethod = response.getRequest().getHttpMethod().name(); pollingContext.setData(PollingUtils.HTTP_METHOD, httpMethod); pollingContext.setData(PollingUtils.REQUEST_URL, response.getRequest().getUrl().toString()); - if (response.getStatusCode() == 200 || response.getStatusCode() == 201 || response.getStatusCode() == 202 @@ -137,4 +136,9 @@ public Mono getResult(PollingContext pollingContext, TypeReference resu return super.getResult(pollingContext, resultType); } } + + @Override + public Mono> poll(PollingContext pollingContext, TypeReference pollResponseType) { + return super.poll(pollingContext, pollResponseType).map(AgentsServicePollUtils::remapStatus); + } } diff --git a/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/SyncOperationLocationPollingStrategy.java b/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/SyncOperationLocationPollingStrategy.java index 7d249b059e57..5d8c3582180f 100644 --- a/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/SyncOperationLocationPollingStrategy.java +++ b/sdk/ai/azure-ai-agents/src/main/java/com/azure/ai/agents/implementation/SyncOperationLocationPollingStrategy.java @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // Code generated by Microsoft (R) TypeSpec Code Generator. - package com.azure.ai.agents.implementation; import com.azure.core.exception.AzureException; @@ -23,7 +22,6 @@ import java.util.Map; // DO NOT modify this helper class - /** * Implements a synchronous operation location polling strategy, from Operation-Location. * @@ -36,7 +34,9 @@ public final class SyncOperationLocationPollingStrategy extends SyncOperat private static final ClientLogger LOGGER = new ClientLogger(SyncOperationLocationPollingStrategy.class); private final ObjectSerializer serializer; + private final String endpoint; + private final String propertyName; /** @@ -57,7 +57,8 @@ public SyncOperationLocationPollingStrategy(PollingStrategyOptions pollingStrate * @throws NullPointerException if {@code pollingStrategyOptions} is null. */ public SyncOperationLocationPollingStrategy(PollingStrategyOptions pollingStrategyOptions, String propertyName) { - super(PollingUtils.OPERATION_LOCATION_HEADER, pollingStrategyOptions); + super(PollingUtils.OPERATION_LOCATION_HEADER, + AgentsServicePollUtils.withFoundryFeatures(pollingStrategyOptions)); this.propertyName = propertyName; this.endpoint = pollingStrategyOptions.getEndpoint(); this.serializer = pollingStrategyOptions.getSerializer() != null @@ -72,7 +73,6 @@ public SyncOperationLocationPollingStrategy(PollingStrategyOptions pollingStrate public PollResponse onInitialResponse(Response response, PollingContext pollingContext, TypeReference pollResponseType) { // Response is Response - HttpHeader operationLocationHeader = response.getHeaders().get(PollingUtils.OPERATION_LOCATION_HEADER); if (operationLocationHeader != null) { pollingContext.setData(PollingUtils.OPERATION_LOCATION_HEADER.getCaseSensitiveName(), @@ -81,7 +81,6 @@ public PollResponse onInitialResponse(Response response, PollingContext final String httpMethod = response.getRequest().getHttpMethod().name(); pollingContext.setData(PollingUtils.HTTP_METHOD, httpMethod); pollingContext.setData(PollingUtils.REQUEST_URL, response.getRequest().getUrl().toString()); - if (response.getStatusCode() == 200 || response.getStatusCode() == 201 || response.getStatusCode() == 202 @@ -97,7 +96,6 @@ public PollResponse onInitialResponse(Response response, PollingContext } return new PollResponse<>(LongRunningOperationStatus.IN_PROGRESS, initialResponseType, retryAfter); } - throw LOGGER.logExceptionAsError(new AzureException( String.format("Operation failed or cancelled with status code %d, '%s' header: %s, and response body: %s", response.getStatusCode(), PollingUtils.OPERATION_LOCATION_HEADER, operationLocationHeader, @@ -130,4 +128,9 @@ public U getResult(PollingContext pollingContext, TypeReference resultType return super.getResult(pollingContext, resultType); } } + + @Override + public PollResponse poll(PollingContext pollingContext, TypeReference pollResponseType) { + return AgentsServicePollUtils.remapStatus(super.poll(pollingContext, pollResponseType)); + } } diff --git a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/MemoryStoresAsyncTests.java b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/MemoryStoresAsyncTests.java index 4b73a6f6eed4..80715460c5ca 100644 --- a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/MemoryStoresAsyncTests.java +++ b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/MemoryStoresAsyncTests.java @@ -12,21 +12,20 @@ import com.azure.ai.agents.models.MemoryStoreDetails; import com.azure.ai.agents.models.MemoryStoreUpdateCompletedResult; import com.azure.ai.agents.models.MemoryStoreUpdateResponse; -import com.azure.ai.agents.models.MemoryStoreUpdateStatus; import com.azure.ai.agents.models.PageOrder; import com.azure.core.exception.ResourceNotFoundException; import com.azure.core.http.HttpClient; import com.azure.core.util.polling.AsyncPollResponse; -import com.azure.core.util.polling.LongRunningOperationStatus; import com.azure.core.util.polling.PollerFlux; import com.openai.models.responses.EasyInputMessage; import com.openai.models.responses.ResponseInputItem; -import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import java.time.Duration; import java.util.Arrays; import java.util.Objects; @@ -35,12 +34,9 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; -@Disabled("Awaiting service versioning consolidation.") +@Timeout(30) public class MemoryStoresAsyncTests extends ClientTestBase { - private static final LongRunningOperationStatus COMPLETED_OPERATION_STATUS - = LongRunningOperationStatus.fromString(MemoryStoreUpdateStatus.COMPLETED.toString(), true); - @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) @MethodSource("com.azure.ai.agents.TestUtils#getTestParameters") public void basicMemoryStoresCrud(HttpClient httpClient, AgentsServiceVersion serviceVersion) { @@ -282,15 +278,9 @@ private static Mono cleanupBeforeTest(MemoryStoresAsyncClient memoryStoreC private static Mono waitForUpdateCompletion(PollerFlux pollerFlux) { Objects.requireNonNull(pollerFlux, "pollerFlux cannot be null"); - return pollerFlux.takeUntil(response -> COMPLETED_OPERATION_STATUS.equals(response.getStatus())) + return pollerFlux.takeUntil(response -> response.getStatus().isComplete()) + .timeout(Duration.ofSeconds(30)) .last() - .map(AsyncPollResponse::getValue) - .map(response -> { - MemoryStoreUpdateCompletedResult result = response == null ? null : response.getResult(); - if (result == null) { - throw new IllegalStateException("Memory store update did not complete successfully."); - } - return result; - }); + .flatMap(AsyncPollResponse::getFinalResult); } } diff --git a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/MemoryStoresTests.java b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/MemoryStoresTests.java index 6e067c5bae03..0258db1cd9e5 100644 --- a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/MemoryStoresTests.java +++ b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/MemoryStoresTests.java @@ -7,14 +7,14 @@ import com.azure.ai.agents.models.DeleteMemoryStoreResult; import com.azure.core.exception.ResourceNotFoundException; import com.azure.core.http.HttpClient; -import com.azure.core.util.polling.LongRunningOperationStatus; import com.azure.core.util.polling.SyncPoller; import com.openai.models.responses.EasyInputMessage; import com.openai.models.responses.ResponseInputItem; -import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import java.time.Duration; import java.util.Arrays; import static com.azure.ai.agents.TestUtils.DISPLAY_NAME_WITH_ARGUMENTS; @@ -23,9 +23,11 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -@Disabled("Awaiting service versioning consolidation.") +@Timeout(30) public class MemoryStoresTests extends ClientTestBase { + private static final Duration POLL_TIMEOUT = Duration.ofSeconds(30); + @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) @MethodSource("com.azure.ai.agents.TestUtils#getTestParameters") public void basicMemoryStoresCrud(HttpClient httpClient, AgentsServiceVersion serviceVersion) { @@ -114,10 +116,6 @@ public void basicMemoryStores(HttpClient httpClient, AgentsServiceVersion servic assertNotNull(memoryStore.getId()); assertEquals(memoryStoreName, memoryStore.getName()); assertEquals(description, memoryStore.getDescription()); - System.out.println("Created memory store: " + memoryStore.getName() + " (" + memoryStore.getId() + "): " - + memoryStore.getDescription()); - System.out.println(" - Chat model: " + definition.getChatModel()); - System.out.println(" - Embedding model: " + definition.getEmbeddingModel()); // Add memories to the memory store ResponseInputItem userMessage = ResponseInputItem.ofEasyInputMessage( @@ -126,23 +124,16 @@ public void basicMemoryStores(HttpClient httpClient, AgentsServiceVersion servic SyncPoller updatePoller = memoryStoreClient.beginUpdateMemories(memoryStoreName, scope, Arrays.asList(userMessage), null, 0); - // Wait for the update operation to complete - LongRunningOperationStatus status = null; - while (status != LongRunningOperationStatus.fromString(MemoryStoreUpdateStatus.COMPLETED.toString(), true)) { - sleep(500); - System.out.println("Polling status: " + status); - status = updatePoller.poll().getStatus(); - } + // Wait for the update operation to complete (with timeout to avoid hanging) + updatePoller.waitForCompletion(POLL_TIMEOUT); MemoryStoreUpdateCompletedResult updateResult = updatePoller.getFinalResult(); assertNotNull(updateResult); assertNotNull(updateResult.getMemoryOperations()); - System.out.println("Updated with " + updateResult.getMemoryOperations().size() + " memory operations"); + for (MemoryOperation operation : updateResult.getMemoryOperations()) { assertNotNull(operation.getKind()); assertNotNull(operation.getMemoryItem().getMemoryId()); assertNotNull(operation.getMemoryItem().getContent()); - System.out.println(" - Operation: " + operation.getKind() + ", Memory ID: " - + operation.getMemoryItem().getMemoryId() + ", Content: " + operation.getMemoryItem().getContent()); } ResponseInputItem queryMessage = ResponseInputItem.ofEasyInputMessage( @@ -153,23 +144,19 @@ public void basicMemoryStores(HttpClient httpClient, AgentsServiceVersion servic Arrays.asList(queryMessage), null, searchOptions); assertNotNull(searchResponse); assertNotNull(searchResponse.getMemories()); - System.out.println("Found " + searchResponse.getMemories().size() + " memories"); + for (MemorySearchItem memory : searchResponse.getMemories()) { assertNotNull(memory.getMemoryItem().getMemoryId()); assertNotNull(memory.getMemoryItem().getContent()); - System.out.println(" - Memory ID: " + memory.getMemoryItem().getMemoryId() + ", Content: " - + memory.getMemoryItem().getContent()); } // Delete memories for a specific scope memoryStoreClient.deleteScope(memoryStoreName, scope); - System.out.println("Deleted memories for scope '" + scope + "'"); // Delete memory store DeleteMemoryStoreResult deleteResponse = memoryStoreClient.deleteMemoryStore(memoryStoreName); assertNotNull(deleteResponse); assertTrue(deleteResponse.isDeleted()); - System.out.println("Deleted memory store `" + memoryStoreName + "`"); } @ParameterizedTest(name = DISPLAY_NAME_WITH_ARGUMENTS) @@ -202,8 +189,6 @@ public void advancedMemoryStores(HttpClient httpClient, AgentsServiceVersion ser = memoryStoreClient.createMemoryStore(memoryStoreName, definition, description, null); assertNotNull(memoryStore); assertEquals(memoryStoreName, memoryStore.getName()); - System.out.println("Created memory store: " + memoryStore.getName() + " (" + memoryStore.getId() + "): " - + memoryStore.getDescription()); ResponseInputItem initialMessage = ResponseInputItem.ofEasyInputMessage( EasyInputMessage.builder().role(EasyInputMessage.Role.USER).content(firstMessageContent).build()); @@ -214,8 +199,6 @@ public void advancedMemoryStores(HttpClient httpClient, AgentsServiceVersion ser assertNotNull(initialResponse); String initialUpdateId = initialResponse.getUpdateId(); assertNotNull(initialUpdateId); - System.out.println("Scheduled memory update operation (Update ID: " + initialUpdateId + ", Status: " - + initialPoller.poll().getStatus() + ")"); // Extend the previous update with another update and more messages ResponseInputItem chainedMessage = ResponseInputItem.ofEasyInputMessage( @@ -227,30 +210,16 @@ public void advancedMemoryStores(HttpClient httpClient, AgentsServiceVersion ser assertNotNull(chainedResponse); String chainedUpdateId = chainedResponse.getUpdateId(); assertNotNull(chainedUpdateId); - System.out.println("Scheduled memory update operation (Update ID: " + chainedUpdateId + ", Status: " - + chainedPoller.poll().getStatus() + ")"); - - // As first update has not started yet, the new update will cancel the first update and cover both sets of messages - System.out.println("Superseded first memory update operation (Update ID: " + initialUpdateId + ", Status: " - + initialPoller.poll().getStatus() + ")"); - - LongRunningOperationStatus chainedStatus = null; - while (chainedStatus - != LongRunningOperationStatus.fromString(MemoryStoreUpdateStatus.COMPLETED.toString(), true)) { - sleep(500); - chainedStatus = chainedPoller.poll().getStatus(); - } + + chainedPoller.waitForCompletion(POLL_TIMEOUT); MemoryStoreUpdateCompletedResult updateResult = chainedPoller.getFinalResult(); assertNotNull(updateResult); assertNotNull(updateResult.getMemoryOperations()); - System.out.println("Second update " + chainedUpdateId + " completed with " - + updateResult.getMemoryOperations().size() + " memory operations"); + for (MemoryOperation operation : updateResult.getMemoryOperations()) { assertNotNull(operation.getKind()); assertNotNull(operation.getMemoryItem().getMemoryId()); assertNotNull(operation.getMemoryItem().getContent()); - System.out.println(" - Operation: " + operation.getKind() + ", Memory ID: " - + operation.getMemoryItem().getMemoryId() + ", Content: " + operation.getMemoryItem().getContent()); } // Retrieve memories from the memory store @@ -263,12 +232,10 @@ public void advancedMemoryStores(HttpClient httpClient, AgentsServiceVersion ser = memoryStoreClient.searchMemories(memoryStoreName, scope, Arrays.asList(searchQuery), null, searchOptions); assertNotNull(searchResponse); assertNotNull(searchResponse.getMemories()); - System.out.println("Found " + searchResponse.getMemories().size() + " memories"); + for (MemorySearchItem memory : searchResponse.getMemories()) { assertNotNull(memory.getMemoryItem().getMemoryId()); assertNotNull(memory.getMemoryItem().getContent()); - System.out.println(" - Memory ID: " + memory.getMemoryItem().getMemoryId() + ", Content: " - + memory.getMemoryItem().getContent()); } String previousSearchId = searchResponse.getSearchId(); assertNotNull(previousSearchId); @@ -283,30 +250,24 @@ public void advancedMemoryStores(HttpClient httpClient, AgentsServiceVersion ser Arrays.asList(agentMessage, followupQuery), previousSearchId, searchOptions); assertNotNull(followupSearch); assertNotNull(followupSearch.getMemories()); - System.out.println("Found " + followupSearch.getMemories().size() + " memories"); + for (MemorySearchItem memory : followupSearch.getMemories()) { assertNotNull(memory.getMemoryItem().getMemoryId()); assertNotNull(memory.getMemoryItem().getContent()); - System.out.println(" - Memory ID: " + memory.getMemoryItem().getMemoryId() + ", Content: " - + memory.getMemoryItem().getContent()); } // Delete memories for the current scope memoryStoreClient.deleteScope(memoryStoreName, scope); - System.out.println("Deleted memories for scope '" + scope + "'"); // Delete memory store DeleteMemoryStoreResult deleteResponse = memoryStoreClient.deleteMemoryStore(memoryStoreName); assertNotNull(deleteResponse); assertTrue(deleteResponse.isDeleted()); - System.out.println("Deleted memory store `" + memoryStoreName + "`"); } - private static void cleanupBeforeTest(MemoryStoresClient memoryStoreClient, String memoryStoreName) { - // Ensure clean state: delete if it already exists + private void cleanupBeforeTest(MemoryStoresClient memoryStoreClient, String memoryStoreName) { try { - DeleteMemoryStoreResult deleteExisting = memoryStoreClient.deleteMemoryStore(memoryStoreName); - assertNotNull(deleteExisting); + memoryStoreClient.deleteMemoryStore(memoryStoreName); } catch (ResourceNotFoundException ex) { // ok if it does not exist } diff --git a/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/implementation/AgentsServicePollUtilsTest.java b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/implementation/AgentsServicePollUtilsTest.java new file mode 100644 index 000000000000..56d744c43499 --- /dev/null +++ b/sdk/ai/azure-ai-agents/src/test/java/com/azure/ai/agents/implementation/AgentsServicePollUtilsTest.java @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.ai.agents.implementation; + +import com.azure.core.http.HttpHeaders; +import com.azure.core.http.HttpPipelineBuilder; +import com.azure.core.http.policy.AddHeadersFromContextPolicy; +import com.azure.core.util.Context; +import com.azure.core.util.polling.LongRunningOperationStatus; +import com.azure.core.util.polling.PollResponse; +import com.azure.core.util.polling.PollingStrategyOptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; + +class AgentsServicePollUtilsTest { + + static Stream remapStatusCases() { + return Stream.of( + // Custom statuses that need remapping + Arguments.of("completed", LongRunningOperationStatus.SUCCESSFULLY_COMPLETED), + Arguments.of("Completed", LongRunningOperationStatus.SUCCESSFULLY_COMPLETED), + Arguments.of("COMPLETED", LongRunningOperationStatus.SUCCESSFULLY_COMPLETED), + Arguments.of("superseded", LongRunningOperationStatus.USER_CANCELLED), + Arguments.of("Superseded", LongRunningOperationStatus.USER_CANCELLED)); + } + + @ParameterizedTest + @MethodSource("remapStatusCases") + void remapStatusMapsCustomStatuses(String statusName, LongRunningOperationStatus expected) { + // The parent's PollResult.setStatus(String) calls fromString(name, false) for unknown statuses + LongRunningOperationStatus customStatus = LongRunningOperationStatus.fromString(statusName, false); + PollResponse original = new PollResponse<>(customStatus, "value"); + + PollResponse remapped = AgentsServicePollUtils.remapStatus(original); + + assertEquals(expected, remapped.getStatus()); + assertEquals("value", remapped.getValue()); + } + + static Stream standardStatusCases() { + return Stream.of(Arguments.of(LongRunningOperationStatus.SUCCESSFULLY_COMPLETED), + Arguments.of(LongRunningOperationStatus.FAILED), Arguments.of(LongRunningOperationStatus.USER_CANCELLED), + Arguments.of(LongRunningOperationStatus.IN_PROGRESS), Arguments.of(LongRunningOperationStatus.NOT_STARTED)); + } + + @ParameterizedTest + @MethodSource("standardStatusCases") + void remapStatusPassesThroughStandardStatuses(LongRunningOperationStatus status) { + PollResponse original = new PollResponse<>(status, "value"); + + PollResponse result = AgentsServicePollUtils.remapStatus(original); + + assertSame(original, result, "Standard status should return the same PollResponse instance"); + } + + @Test + void remapStatusPreservesRetryAfter() { + LongRunningOperationStatus completed = LongRunningOperationStatus.fromString("completed", false); + java.time.Duration retryAfter = java.time.Duration.ofSeconds(5); + PollResponse original = new PollResponse<>(completed, "value", retryAfter); + + PollResponse remapped = AgentsServicePollUtils.remapStatus(original); + + assertEquals(LongRunningOperationStatus.SUCCESSFULLY_COMPLETED, remapped.getStatus()); + assertEquals(retryAfter, remapped.getRetryAfter()); + } + + @Test + void withFoundryFeaturesAddsHeaderToContext() { + PollingStrategyOptions options = new PollingStrategyOptions(new HttpPipelineBuilder().build()); + + PollingStrategyOptions result = AgentsServicePollUtils.withFoundryFeatures(options); + + Context context = result.getContext(); + assertNotNull(context); + Object headerObj = context.getData(AddHeadersFromContextPolicy.AZURE_REQUEST_HTTP_HEADERS_KEY).orElse(null); + assertNotNull(headerObj, "Context should contain HTTP headers under the AddHeadersFromContextPolicy key"); + HttpHeaders headers = (HttpHeaders) headerObj; + assertEquals("MemoryStores=V1Preview", + headers.getValue(com.azure.core.http.HttpHeaderName.fromString("Foundry-Features"))); + } + + @Test + void withFoundryFeaturesPreservesExistingContext() { + PollingStrategyOptions options + = new PollingStrategyOptions(new HttpPipelineBuilder().build()).setContext(new Context("myKey", "myValue")); + + PollingStrategyOptions result = AgentsServicePollUtils.withFoundryFeatures(options); + + Context context = result.getContext(); + assertEquals("myValue", context.getData("myKey").orElse(null), "Existing context data should be preserved"); + assertNotNull(context.getData(AddHeadersFromContextPolicy.AZURE_REQUEST_HTTP_HEADERS_KEY).orElse(null), + "Foundry-Features header should also be present"); + } + + @Test + void withFoundryFeaturesMergesWithExistingHeaders() { + HttpHeaders existingHeaders = new HttpHeaders(); + existingHeaders.set(com.azure.core.http.HttpHeaderName.fromString("X-Custom"), "custom-value"); + Context contextWithHeaders + = new Context(AddHeadersFromContextPolicy.AZURE_REQUEST_HTTP_HEADERS_KEY, existingHeaders); + PollingStrategyOptions options + = new PollingStrategyOptions(new HttpPipelineBuilder().build()).setContext(contextWithHeaders); + + PollingStrategyOptions result = AgentsServicePollUtils.withFoundryFeatures(options); + + HttpHeaders merged = (HttpHeaders) result.getContext() + .getData(AddHeadersFromContextPolicy.AZURE_REQUEST_HTTP_HEADERS_KEY) + .orElse(null); + assertNotNull(merged); + assertEquals("custom-value", merged.getValue(com.azure.core.http.HttpHeaderName.fromString("X-Custom")), + "Pre-existing header should be preserved"); + assertEquals("MemoryStores=V1Preview", + merged.getValue(com.azure.core.http.HttpHeaderName.fromString("Foundry-Features")), + "Foundry-Features header should be added"); + } +}