diff --git a/aws/client/aws-client-awsquery/build.gradle.kts b/aws/client/aws-client-awsquery/build.gradle.kts new file mode 100644 index 000000000..208030bee --- /dev/null +++ b/aws/client/aws-client-awsquery/build.gradle.kts @@ -0,0 +1,22 @@ +plugins { + id("smithy-java.module-conventions") + id("smithy-java.protocol-testing-conventions") +} + +description = "This module provides the implementation of AWS Query protocol" + +extra["displayName"] = "Smithy :: Java :: AWS :: Client :: AWS Query" +extra["moduleName"] = "software.amazon.smithy.java.aws.client.awsquery" + +dependencies { + api(project(":client:client-http")) + api(project(":codecs:xml-codec")) + api(project(":io")) + api(libs.smithy.aws.traits) + + // Protocol test dependencies + testImplementation(libs.smithy.aws.protocol.tests) +} + +val generator = "software.amazon.smithy.java.protocoltests.generators.ProtocolTestGenerator" +addGenerateSrcsTask(generator, "awsQuery", "aws.protocoltests.query#AwsQuery") diff --git a/aws/client/aws-client-awsquery/src/it/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryProtocolTests.java b/aws/client/aws-client-awsquery/src/it/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryProtocolTests.java new file mode 100644 index 000000000..63c24bbde --- /dev/null +++ b/aws/client/aws-client-awsquery/src/it/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryProtocolTests.java @@ -0,0 +1,79 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.aws.client.awsquery; + +import static java.net.URLDecoder.decode; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Map; +import java.util.TreeMap; +import java.util.stream.Collectors; +import software.amazon.smithy.java.io.ByteBufferUtils; +import software.amazon.smithy.java.io.datastream.DataStream; +import software.amazon.smithy.java.protocoltests.harness.HttpClientRequestTests; +import software.amazon.smithy.java.protocoltests.harness.HttpClientResponseTests; +import software.amazon.smithy.java.protocoltests.harness.ProtocolTest; +import software.amazon.smithy.java.protocoltests.harness.ProtocolTestFilter; +import software.amazon.smithy.java.protocoltests.harness.TestType; + +@ProtocolTest( + service = "aws.protocoltests.query#AwsQuery", + testType = TestType.CLIENT) +public class AwsQueryProtocolTests { + + @HttpClientRequestTests + @ProtocolTestFilter( + skipTests = { + "SDKAppliedContentEncoding_awsQuery", + "SDKAppendsGzipAndIgnoresHttpProvidedEncoding_awsQuery", + }) + public void requestTest(DataStream expected, DataStream actual) { + String expectedStr = new String( + ByteBufferUtils.getBytes(expected.asByteBuffer()), + StandardCharsets.UTF_8); + String actualStr = new String( + ByteBufferUtils.getBytes(actual.asByteBuffer()), + StandardCharsets.UTF_8); + + Map expectedParams = parseFormUrlEncoded(expectedStr); + Map actualParams = parseFormUrlEncoded(actualStr); + + assertEquals(expectedParams, actualParams); + } + + @HttpClientResponseTests + @ProtocolTestFilter( + skipTests = { + "AwsQueryClientPopulatesDefaultsValuesWhenMissingInResponse", + "QueryCustomizedError", + }) + public void responseTest(Runnable test) { + test.run(); + } + + private Map parseFormUrlEncoded(String body) { + if (body == null || body.isEmpty()) { + return new TreeMap<>(); + } + return Arrays.stream(body.split("&")) + .map(pair -> pair.split("=", 2)) + .collect(Collectors.toMap( + parts -> urlDecode(parts[0]), + parts -> parts.length > 1 ? urlDecode(parts[1]) : "", + (a, b) -> b, + TreeMap::new)); + } + + private String urlDecode(String value) { + try { + return decode(value, StandardCharsets.UTF_8); + } catch (Exception e) { + return value; + } + } +} diff --git a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryClientProtocol.java b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryClientProtocol.java new file mode 100644 index 000000000..d5261f6ed --- /dev/null +++ b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryClientProtocol.java @@ -0,0 +1,178 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.aws.client.awsquery; + +import java.net.URI; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import software.amazon.smithy.aws.traits.protocols.AwsQueryTrait; +import software.amazon.smithy.java.client.core.ClientProtocol; +import software.amazon.smithy.java.client.core.ClientProtocolFactory; +import software.amazon.smithy.java.client.core.ProtocolSettings; +import software.amazon.smithy.java.client.http.HttpClientProtocol; +import software.amazon.smithy.java.client.http.HttpErrorDeserializer; +import software.amazon.smithy.java.context.Context; +import software.amazon.smithy.java.core.error.CallException; +import software.amazon.smithy.java.core.error.ModeledException; +import software.amazon.smithy.java.core.schema.ApiOperation; +import software.amazon.smithy.java.core.schema.SerializableStruct; +import software.amazon.smithy.java.core.schema.ShapeBuilder; +import software.amazon.smithy.java.core.schema.Unit; +import software.amazon.smithy.java.core.serde.Codec; +import software.amazon.smithy.java.core.serde.TypeRegistry; +import software.amazon.smithy.java.core.serde.document.Document; +import software.amazon.smithy.java.http.api.HttpHeaders; +import software.amazon.smithy.java.http.api.HttpRequest; +import software.amazon.smithy.java.http.api.HttpResponse; +import software.amazon.smithy.java.io.datastream.DataStream; +import software.amazon.smithy.java.xml.XmlCodec; +import software.amazon.smithy.java.xml.XmlUtil; +import software.amazon.smithy.model.shapes.ShapeId; + +public final class AwsQueryClientProtocol extends HttpClientProtocol { + + private static final String CONTENT_TYPE = "application/x-www-form-urlencoded"; + private static final List CONTENT_TYPE_LIST = List.of(CONTENT_TYPE); + + private final ShapeId service; + private final String version; + private final HttpErrorDeserializer errorDeserializer; + private final XmlCodec codec = XmlCodec.builder().build(); + + public AwsQueryClientProtocol(ShapeId service, String version) { + super(AwsQueryTrait.ID); + this.service = Objects.requireNonNull(service, "service is required"); + this.version = Objects.requireNonNull(version, "version is required"); + this.errorDeserializer = HttpErrorDeserializer.builder() + .codec(XmlCodec.builder().build()) + .serviceId(service) + .errorPayloadParser(XML_ERROR_PAYLOAD_PARSER) + .knownErrorFactory(new XmlKnownErrorFactory()) + .build(); + } + + @Override + public Codec payloadCodec() { + return codec; + } + + @Override + public HttpRequest createRequest( + ApiOperation operation, + I input, + Context context, + URI endpoint + ) { + String operationName = operation.schema().id().getName(); + AwsQueryFormSerializer serializer = new AwsQueryFormSerializer(operationName, version); + + if (!Unit.ID.equals(operation.inputSchema().id())) { + input.serializeMembers(serializer); + } + + ByteBuffer body = serializer.finish(); + + return HttpRequest.builder() + .method("POST") + .uri(endpoint) + .headers(HttpHeaders.of(Map.of("Content-Type", CONTENT_TYPE_LIST))) + .body(DataStream.ofByteBuffer(body, CONTENT_TYPE)) + .build(); + } + + @Override + public O deserializeResponse( + ApiOperation operation, + Context context, + TypeRegistry typeRegistry, + HttpRequest request, + HttpResponse response + ) { + if (response.statusCode() >= 300) { + throw errorDeserializer.createError(context, operation.schema().id(), typeRegistry, response); + } + + var builder = operation.outputBuilder(); + var content = response.body(); + + if (content.contentLength() == 0) { + return builder.build(); + } + + var operationName = operation.schema().id().getName(); + try (var codec = XmlCodec.builder() + .wrapperElements(List.of(operationName + "Response", operationName + "Result")) + .build()) { + return codec.deserializeShape(response.body().asByteBuffer(), builder); + } + } + + private static final HttpErrorDeserializer.ErrorPayloadParser XML_ERROR_PAYLOAD_PARSER = + new HttpErrorDeserializer.ErrorPayloadParser() { + @Override + public CallException parsePayload( + Context context, + Codec codec, + HttpErrorDeserializer.KnownErrorFactory knownErrorFactory, + ShapeId serviceId, + TypeRegistry typeRegistry, + HttpResponse response, + ByteBuffer buffer + ) { + var deserializer = codec.createDeserializer(buffer); + String code = XmlUtil.parseErrorCodeName(deserializer); + var nameSpace = serviceId.getNamespace(); + var id = ShapeId.fromOptionalNamespace(nameSpace, code); + var builder = typeRegistry.createBuilder(id, ModeledException.class); + if (builder != null) { + return knownErrorFactory.createError(context, codec, response, builder); + } + return null; + } + + @Override + public ShapeId extractErrorType( + Document document, + String namespace + ) { + return null; + } + }; + + private static final class XmlKnownErrorFactory implements HttpErrorDeserializer.KnownErrorFactory { + @Override + public ModeledException createError( + Context context, + Codec codec, + HttpResponse response, + ShapeBuilder builder + ) { + ByteBuffer bytes = DataStream.ofPublisher( + response.body(), + response.contentType(), + response.contentLength(-1)).asByteBuffer(); + return codec.deserializeShape(bytes, builder); + } + } + + public static final class Factory implements ClientProtocolFactory { + + @Override + public ShapeId id() { + return AwsQueryTrait.ID; + } + + @Override + public ClientProtocol createProtocol(ProtocolSettings settings, AwsQueryTrait trait) { + return new AwsQueryClientProtocol( + Objects.requireNonNull(settings.service(), "service is a required protocol setting"), + Objects.requireNonNull(settings.serviceVersion(), + "serviceVersion is a required protocol setting for AWS Query.")); + } + } +} diff --git a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryFormSerializer.java b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryFormSerializer.java new file mode 100644 index 000000000..76fcaf961 --- /dev/null +++ b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryFormSerializer.java @@ -0,0 +1,651 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.aws.client.awsquery; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Arrays; +import java.util.Base64; +import java.util.function.BiConsumer; +import software.amazon.smithy.java.core.schema.Schema; +import software.amazon.smithy.java.core.schema.SerializableStruct; +import software.amazon.smithy.java.core.schema.TraitKey; +import software.amazon.smithy.java.core.serde.MapSerializer; +import software.amazon.smithy.java.core.serde.SerializationException; +import software.amazon.smithy.java.core.serde.ShapeSerializer; +import software.amazon.smithy.java.core.serde.TimestampFormatter; +import software.amazon.smithy.java.core.serde.document.Document; +import software.amazon.smithy.model.traits.TimestampFormatTrait; + +final class AwsQueryFormSerializer implements ShapeSerializer { + private static final byte[] ACTION_PREFIX = "Action=".getBytes(StandardCharsets.UTF_8); + private static final byte[] VERSION_PREFIX = "&Version=".getBytes(StandardCharsets.UTF_8); + private static final byte[] MEMBER = "member".getBytes(StandardCharsets.UTF_8); + private static final byte[] ENTRY = "entry".getBytes(StandardCharsets.UTF_8); + private static final byte[] KEY = "key".getBytes(StandardCharsets.UTF_8); + private static final byte[] VALUE = "value".getBytes(StandardCharsets.UTF_8); + + private final FormUrlEncodedSink sink; + + private byte[][] prefixCache = new byte[8][]; + private int[] prefixLengths = new int[8]; + private int prefixDepth = 0; + + private final ListItemSerializer listSerializer = new ListItemSerializer(); + private final QueryMapSerializer mapSerializer = new QueryMapSerializer(); + + AwsQueryFormSerializer(String action, String version) { + this.sink = new FormUrlEncodedSink(); + sink.writeBytes(ACTION_PREFIX, 0, ACTION_PREFIX.length); + sink.writeAscii(action); + sink.writeBytes(VERSION_PREFIX, 0, VERSION_PREFIX.length); + sink.writeAscii(version); + } + + ByteBuffer finish() { + return sink.finish(); + } + + private void writeParam(byte[] key, String value) { + sink.writeByte('&'); + writeCurrentPrefix(); + if (prefixDepth > 0) { + sink.writeByte('.'); + } + sink.writeBytes(key, 0, key.length); + sink.writeByte('='); + sink.writeUrlEncoded(value); + } + + private void writeParam(String key, String value) { + sink.writeByte('&'); + writeCurrentPrefix(); + if (prefixDepth > 0) { + sink.writeByte('.'); + } + sink.writeUrlEncoded(key); + sink.writeByte('='); + sink.writeUrlEncoded(value); + } + + private void writeCurrentPrefix() { + for (int i = 0; i < prefixDepth; i++) { + if (i > 0) { + sink.writeByte('.'); + } + sink.writeBytes(prefixCache[i], 0, prefixLengths[i]); + } + } + + private void pushPrefix(String prefix) { + if (prefixDepth >= prefixCache.length) { + prefixCache = Arrays.copyOf(prefixCache, prefixCache.length * 2); + prefixLengths = Arrays.copyOf(prefixLengths, prefixLengths.length * 2); + } + byte[] encoded = encodePrefix(prefix); + prefixCache[prefixDepth] = encoded; + prefixLengths[prefixDepth++] = encoded.length; + } + + private void pushPrefix(byte[] prefix) { + if (prefixDepth >= prefixCache.length) { + prefixCache = Arrays.copyOf(prefixCache, prefixCache.length * 2); + prefixLengths = Arrays.copyOf(prefixLengths, prefixLengths.length * 2); + } + prefixCache[prefixDepth] = prefix; + prefixLengths[prefixDepth++] = prefix.length; + } + + private void pushIndexedPrefix(byte[] base, int index) { + if (prefixDepth >= prefixCache.length) { + prefixCache = Arrays.copyOf(prefixCache, prefixCache.length * 2); + prefixLengths = Arrays.copyOf(prefixLengths, prefixLengths.length * 2); + } + byte[] encoded = encodeIndexedPrefix(base, index); + prefixCache[prefixDepth] = encoded; + prefixLengths[prefixDepth++] = encoded.length; + } + + private void pushIndexPrefix(int index) { + if (prefixDepth >= prefixCache.length) { + prefixCache = Arrays.copyOf(prefixCache, prefixCache.length * 2); + prefixLengths = Arrays.copyOf(prefixLengths, prefixLengths.length * 2); + } + byte[] encoded = encodeIndex(index); + prefixCache[prefixDepth] = encoded; + prefixLengths[prefixDepth++] = encoded.length; + } + + private void popPrefix() { + prefixDepth--; + } + + private byte[] encodePrefix(String prefix) { + FormUrlEncodedSink tmp = new FormUrlEncodedSink(prefix.length() * 3); + tmp.writeUrlEncoded(prefix); + ByteBuffer bb = tmp.finish(); + byte[] result = new byte[bb.remaining()]; + bb.get(result); + return result; + } + + @SuppressWarnings("deprecation") + private byte[] encodeIndexedPrefix(byte[] base, int index) { + String indexStr = Integer.toString(index); + byte[] result = new byte[base.length + 1 + indexStr.length()]; + System.arraycopy(base, 0, result, 0, base.length); + result[base.length] = '.'; + indexStr.getBytes(0, indexStr.length(), result, base.length + 1); + return result; + } + + @SuppressWarnings("deprecation") + private byte[] encodeIndex(int index) { + String indexStr = Integer.toString(index); + byte[] result = new byte[indexStr.length()]; + indexStr.getBytes(0, indexStr.length(), result, 0); + return result; + } + + private static String getMemberName(Schema schema) { + var xmlName = schema.getTrait(TraitKey.XML_NAME_TRAIT); + if (xmlName != null) { + return xmlName.getValue(); + } + return schema.memberName(); + } + + @Override + public void writeStruct(Schema schema, SerializableStruct struct) { + if (schema.isMember()) { + String memberName = getMemberName(schema); + if (memberName != null) { + pushPrefix(memberName); + struct.serializeMembers(this); + popPrefix(); + return; + } + } + struct.serializeMembers(this); + } + + @Override + public void writeList(Schema schema, T listState, int size, BiConsumer consumer) { + boolean flattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + Schema memberSchema = schema.listMember(); + + if (schema.isMember()) { + pushPrefix(getMemberName(schema)); + } + + if (size == 0) { + writeEmptyValue(); + if (schema.isMember()) { + popPrefix(); + } + return; + } + + byte[] memberNameBytes; + if (flattened) { + memberNameBytes = null; + } else { + var xmlName = memberSchema.getTrait(TraitKey.XML_NAME_TRAIT); + memberNameBytes = xmlName != null ? xmlName.getValue().getBytes(StandardCharsets.UTF_8) : MEMBER; + } + + listSerializer.reset(memberNameBytes, flattened); + consumer.accept(listState, listSerializer); + + if (schema.isMember()) { + popPrefix(); + } + } + + private void writeEmptyValue() { + sink.writeByte('&'); + writeCurrentPrefix(); + sink.writeByte('='); + } + + private final class ListItemSerializer implements ShapeSerializer { + private byte[] memberNameBytes; + private boolean flattened; + private int index; + + void reset(byte[] memberNameBytes, boolean flattened) { + this.memberNameBytes = memberNameBytes; + this.flattened = flattened; + this.index = 1; + } + + private void pushIndexedMemberPrefix() { + if (flattened) { + pushIndexPrefix(index); + } else { + pushIndexedPrefix(memberNameBytes, index); + } + } + + @Override + public void writeStruct(Schema schema, SerializableStruct struct) { + pushIndexedMemberPrefix(); + index++; + struct.serializeMembers(AwsQueryFormSerializer.this); + popPrefix(); + } + + @Override + public void writeList(Schema schema, T listState, int size, BiConsumer consumer) { + pushIndexedMemberPrefix(); + index++; + AwsQueryFormSerializer.this.writeList(schema, listState, size, consumer); + popPrefix(); + } + + @Override + public void writeMap(Schema schema, T mapState, int size, BiConsumer consumer) { + pushIndexedMemberPrefix(); + index++; + AwsQueryFormSerializer.this.writeMap(schema, mapState, size, consumer); + popPrefix(); + } + + @Override + public void writeBoolean(Schema schema, boolean value) { + writeIndexedParam(value ? "true" : "false"); + } + + @Override + public void writeByte(Schema schema, byte value) { + writeIndexedParam(Byte.toString(value)); + } + + @Override + public void writeShort(Schema schema, short value) { + writeIndexedParam(Short.toString(value)); + } + + @Override + public void writeInteger(Schema schema, int value) { + writeIndexedParam(Integer.toString(value)); + } + + @Override + public void writeLong(Schema schema, long value) { + writeIndexedParam(Long.toString(value)); + } + + @Override + public void writeFloat(Schema schema, float value) { + if (Float.isNaN(value)) { + writeIndexedParam("NaN"); + } else if (Float.isInfinite(value)) { + writeIndexedParam(value > 0 ? "Infinity" : "-Infinity"); + } else { + writeIndexedParam(Float.toString(value)); + } + } + + @Override + public void writeDouble(Schema schema, double value) { + if (Double.isNaN(value)) { + writeIndexedParam("NaN"); + } else if (Double.isInfinite(value)) { + writeIndexedParam(value > 0 ? "Infinity" : "-Infinity"); + } else { + writeIndexedParam(Double.toString(value)); + } + } + + @Override + public void writeBigInteger(Schema schema, BigInteger value) { + writeIndexedParam(value.toString()); + } + + @Override + public void writeBigDecimal(Schema schema, BigDecimal value) { + writeIndexedParam(value.toPlainString()); + } + + @Override + public void writeString(Schema schema, String value) { + writeIndexedParam(value); + } + + @Override + public void writeBlob(Schema schema, ByteBuffer value) { + byte[] bytes = new byte[value.remaining()]; + value.duplicate().get(bytes); + writeIndexedParam(Base64.getEncoder().encodeToString(bytes)); + } + + @Override + public void writeTimestamp(Schema schema, Instant value) { + TimestampFormatter formatter = TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME); + writeIndexedParam(formatter.writeString(value)); + } + + @Override + public void writeDocument(Schema schema, Document value) { + throw new SerializationException("AWS Query protocol does not support document types"); + } + + @Override + public void writeNull(Schema schema) { + index++; + } + + private void writeIndexedParam(String value) { + sink.writeByte('&'); + writeCurrentPrefix(); + if (prefixDepth > 0) { + sink.writeByte('.'); + } + if (flattened) { + sink.writeInt(index); + } else { + sink.writeBytes(memberNameBytes, 0, memberNameBytes.length); + sink.writeByte('.'); + sink.writeInt(index); + } + sink.writeByte('='); + sink.writeUrlEncoded(value); + index++; + } + } + + @Override + public void writeMap(Schema schema, T mapState, int size, BiConsumer consumer) { + boolean flattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + Schema keySchema = schema.mapKeyMember(); + Schema valueSchema = schema.mapValueMember(); + + if (schema.isMember()) { + pushPrefix(getMemberName(schema)); + } + + var keyXmlName = keySchema.getTrait(TraitKey.XML_NAME_TRAIT); + var valueXmlName = valueSchema.getTrait(TraitKey.XML_NAME_TRAIT); + + byte[] keyNameBytes = keyXmlName != null ? keyXmlName.getValue().getBytes(StandardCharsets.UTF_8) : KEY; + byte[] valueNameBytes = valueXmlName != null ? valueXmlName.getValue().getBytes(StandardCharsets.UTF_8) : VALUE; + byte[] entryNameBytes = flattened ? null : ENTRY; + + mapSerializer.reset(entryNameBytes, keyNameBytes, valueNameBytes, flattened); + consumer.accept(mapState, mapSerializer); + + if (schema.isMember()) { + popPrefix(); + } + } + + private final class QueryMapSerializer implements MapSerializer { + private byte[] entryNameBytes; + private byte[] keyNameBytes; + private byte[] valueNameBytes; + private boolean flattened; + private int index; + + void reset(byte[] entryNameBytes, byte[] keyNameBytes, byte[] valueNameBytes, boolean flattened) { + this.entryNameBytes = entryNameBytes; + this.keyNameBytes = keyNameBytes; + this.valueNameBytes = valueNameBytes; + this.flattened = flattened; + this.index = 1; + } + + @Override + public void writeEntry( + Schema keySchema, + String key, + T state, + BiConsumer valueSerializer + ) { + if (flattened) { + pushIndexPrefix(index); + } else { + pushIndexedPrefix(entryNameBytes, index); + } + + writeParam(keyNameBytes, key); + + pushPrefix(valueNameBytes); + valueSerializer.accept(state, mapValueSerializer); + popPrefix(); + + popPrefix(); + index++; + } + } + + private final MapValueSerializer mapValueSerializer = new MapValueSerializer(); + + private final class MapValueSerializer implements ShapeSerializer { + @Override + public void writeStruct(Schema schema, SerializableStruct struct) { + struct.serializeMembers(AwsQueryFormSerializer.this); + } + + @Override + public void writeList(Schema schema, T listState, int size, BiConsumer consumer) { + boolean flattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + Schema memberSchema = schema.listMember(); + + if (size == 0) { + writeEmptyValue(); + return; + } + + byte[] memberNameBytes; + if (flattened) { + memberNameBytes = null; + } else { + var xmlName = memberSchema.getTrait(TraitKey.XML_NAME_TRAIT); + memberNameBytes = xmlName != null ? xmlName.getValue().getBytes(StandardCharsets.UTF_8) : MEMBER; + } + + listSerializer.reset(memberNameBytes, flattened); + consumer.accept(listState, listSerializer); + } + + @Override + public void writeMap(Schema schema, T mapState, int size, BiConsumer consumer) { + boolean flattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + Schema keySchema = schema.mapKeyMember(); + Schema valueSchema = schema.mapValueMember(); + + var keyXmlName = keySchema.getTrait(TraitKey.XML_NAME_TRAIT); + var valueXmlName = valueSchema.getTrait(TraitKey.XML_NAME_TRAIT); + + byte[] keyNameBytes = keyXmlName != null ? keyXmlName.getValue().getBytes(StandardCharsets.UTF_8) : KEY; + byte[] valueNameBytes = + valueXmlName != null ? valueXmlName.getValue().getBytes(StandardCharsets.UTF_8) : VALUE; + byte[] entryNameBytes = flattened ? null : ENTRY; + + mapSerializer.reset(entryNameBytes, keyNameBytes, valueNameBytes, flattened); + consumer.accept(mapState, mapSerializer); + } + + @Override + public void writeBoolean(Schema schema, boolean value) { + writeValueParam(value ? "true" : "false"); + } + + @Override + public void writeByte(Schema schema, byte value) { + writeValueParam(Byte.toString(value)); + } + + @Override + public void writeShort(Schema schema, short value) { + writeValueParam(Short.toString(value)); + } + + @Override + public void writeInteger(Schema schema, int value) { + writeValueParam(Integer.toString(value)); + } + + @Override + public void writeLong(Schema schema, long value) { + writeValueParam(Long.toString(value)); + } + + @Override + public void writeFloat(Schema schema, float value) { + if (Float.isNaN(value)) { + writeValueParam("NaN"); + } else if (Float.isInfinite(value)) { + writeValueParam(value > 0 ? "Infinity" : "-Infinity"); + } else { + writeValueParam(Float.toString(value)); + } + } + + @Override + public void writeDouble(Schema schema, double value) { + if (Double.isNaN(value)) { + writeValueParam("NaN"); + } else if (Double.isInfinite(value)) { + writeValueParam(value > 0 ? "Infinity" : "-Infinity"); + } else { + writeValueParam(Double.toString(value)); + } + } + + @Override + public void writeBigInteger(Schema schema, BigInteger value) { + writeValueParam(value.toString()); + } + + @Override + public void writeBigDecimal(Schema schema, BigDecimal value) { + writeValueParam(value.toPlainString()); + } + + @Override + public void writeString(Schema schema, String value) { + writeValueParam(value); + } + + @Override + public void writeBlob(Schema schema, ByteBuffer value) { + byte[] bytes = new byte[value.remaining()]; + value.duplicate().get(bytes); + writeValueParam(Base64.getEncoder().encodeToString(bytes)); + } + + @Override + public void writeTimestamp(Schema schema, Instant value) { + TimestampFormatter formatter = TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME); + writeValueParam(formatter.writeString(value)); + } + + @Override + public void writeDocument(Schema schema, Document value) { + throw new SerializationException("AWS Query protocol does not support document types"); + } + + @Override + public void writeNull(Schema schema) {} + + private void writeValueParam(String value) { + sink.writeByte('&'); + writeCurrentPrefix(); + sink.writeByte('='); + sink.writeUrlEncoded(value); + } + } + + @Override + public void writeBoolean(Schema schema, boolean value) { + writeParam(getMemberName(schema), value ? "true" : "false"); + } + + @Override + public void writeByte(Schema schema, byte value) { + writeParam(getMemberName(schema), Byte.toString(value)); + } + + @Override + public void writeShort(Schema schema, short value) { + writeParam(getMemberName(schema), Short.toString(value)); + } + + @Override + public void writeInteger(Schema schema, int value) { + writeParam(getMemberName(schema), Integer.toString(value)); + } + + @Override + public void writeLong(Schema schema, long value) { + writeParam(getMemberName(schema), Long.toString(value)); + } + + @Override + public void writeFloat(Schema schema, float value) { + String memberName = getMemberName(schema); + if (Float.isNaN(value)) { + writeParam(memberName, "NaN"); + } else if (Float.isInfinite(value)) { + writeParam(memberName, value > 0 ? "Infinity" : "-Infinity"); + } else { + writeParam(memberName, Float.toString(value)); + } + } + + @Override + public void writeDouble(Schema schema, double value) { + String memberName = getMemberName(schema); + if (Double.isNaN(value)) { + writeParam(memberName, "NaN"); + } else if (Double.isInfinite(value)) { + writeParam(memberName, value > 0 ? "Infinity" : "-Infinity"); + } else { + writeParam(memberName, Double.toString(value)); + } + } + + @Override + public void writeBigInteger(Schema schema, BigInteger value) { + writeParam(getMemberName(schema), value.toString()); + } + + @Override + public void writeBigDecimal(Schema schema, BigDecimal value) { + writeParam(getMemberName(schema), value.toPlainString()); + } + + @Override + public void writeString(Schema schema, String value) { + writeParam(getMemberName(schema), value); + } + + @Override + public void writeBlob(Schema schema, ByteBuffer value) { + byte[] bytes = new byte[value.remaining()]; + value.duplicate().get(bytes); + writeParam(getMemberName(schema), Base64.getEncoder().encodeToString(bytes)); + } + + @Override + public void writeTimestamp(Schema schema, Instant value) { + TimestampFormatter formatter = TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME); + writeParam(getMemberName(schema), formatter.writeString(value)); + } + + @Override + public void writeDocument(Schema schema, Document value) { + throw new SerializationException("AWS Query protocol does not support document types"); + } + + @Override + public void writeNull(Schema schema) {} +} diff --git a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryXmlResponseDeserializer.java b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryXmlResponseDeserializer.java new file mode 100644 index 000000000..da6e56424 --- /dev/null +++ b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryXmlResponseDeserializer.java @@ -0,0 +1,31 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.aws.client.awsquery; + +import java.nio.ByteBuffer; +import java.util.List; +import software.amazon.smithy.java.core.schema.SerializableStruct; +import software.amazon.smithy.java.core.schema.ShapeBuilder; +import software.amazon.smithy.java.xml.XmlCodec; + +final class AwsQueryXmlResponseDeserializer { + + private final ByteBuffer source; + private final String operationName; + + AwsQueryXmlResponseDeserializer(ByteBuffer source, String operationName) { + this.source = source; + this.operationName = operationName; + } + + T deserialize(ShapeBuilder builder) { + try (var codec = XmlCodec.builder() + .wrapperElements(List.of(operationName + "Response", operationName + "Result")) + .build()) { + return codec.deserializeShape(source, builder); + } + } +} diff --git a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/FormUrlEncodedSink.java b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/FormUrlEncodedSink.java new file mode 100644 index 000000000..6587ea542 --- /dev/null +++ b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/FormUrlEncodedSink.java @@ -0,0 +1,127 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.aws.client.awsquery; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +final class FormUrlEncodedSink { + private static final byte[] HEX = { + '0', + '1', + '2', + '3', + '4', + '5', + '6', + '7', + '8', + '9', + 'A', + 'B', + 'C', + 'D', + 'E', + 'F' + }; + + private byte[] bytes; + private int pos; + + FormUrlEncodedSink() { + this.bytes = new byte[256]; + this.pos = 0; + } + + FormUrlEncodedSink(int initialCapacity) { + this.bytes = new byte[initialCapacity]; + this.pos = 0; + } + + void writeByte(int b) { + ensureCapacity(1); + bytes[pos++] = (byte) b; + } + + void writeBytes(byte[] b, int off, int len) { + ensureCapacity(len); + System.arraycopy(b, off, bytes, pos, len); + pos += len; + } + + @SuppressWarnings("deprecation") + void writeAscii(String s) { + int len = s.length(); + ensureCapacity(len); + s.getBytes(0, len, bytes, pos); + pos += len; + } + + void writeUrlEncoded(String s) { + int len = s.length(); + ensureCapacity(len * 3); + for (int i = 0; i < len; i++) { + char c = s.charAt(i); + if (isUnreserved(c)) { + bytes[pos++] = (byte) c; + } else if (c < 0x80) { + writePercentEncoded(c); + } else if (c < 0x800) { + writePercentEncoded(0xC0 | (c >> 6)); + writePercentEncoded(0x80 | (c & 0x3F)); + } else if (Character.isHighSurrogate(c) && i + 1 < len) { + char low = s.charAt(++i); + if (Character.isLowSurrogate(low)) { + int cp = Character.toCodePoint(c, low); + writePercentEncoded(0xF0 | (cp >> 18)); + writePercentEncoded(0x80 | ((cp >> 12) & 0x3F)); + writePercentEncoded(0x80 | ((cp >> 6) & 0x3F)); + writePercentEncoded(0x80 | (cp & 0x3F)); + } + } else { + writePercentEncoded(0xE0 | (c >> 12)); + writePercentEncoded(0x80 | ((c >> 6) & 0x3F)); + writePercentEncoded(0x80 | (c & 0x3F)); + } + } + } + + @SuppressWarnings("deprecation") + void writeInt(int value) { + String s = Integer.toString(value); + int len = s.length(); + ensureCapacity(len); + s.getBytes(0, len, bytes, pos); + pos += len; + } + + ByteBuffer finish() { + return ByteBuffer.wrap(bytes, 0, pos); + } + + private static boolean isUnreserved(char c) { + return (c >= 'A' && c <= 'Z') + || (c >= 'a' && c <= 'z') + || (c >= '0' && c <= '9') + || c == '-' + || c == '.' + || c == '_' + || c == '~'; + } + + private void writePercentEncoded(int b) { + bytes[pos++] = '%'; + bytes[pos++] = HEX[(b >> 4) & 0xF]; + bytes[pos++] = HEX[b & 0xF]; + } + + private void ensureCapacity(int len) { + int required = pos + len; + if (required > bytes.length) { + bytes = Arrays.copyOf(bytes, Math.max(required, bytes.length + (bytes.length >> 1))); + } + } +} diff --git a/aws/client/aws-client-awsquery/src/main/resources/META-INF/services/software.amazon.smithy.java.client.core.ClientProtocolFactory b/aws/client/aws-client-awsquery/src/main/resources/META-INF/services/software.amazon.smithy.java.client.core.ClientProtocolFactory new file mode 100644 index 000000000..29dfb11f9 --- /dev/null +++ b/aws/client/aws-client-awsquery/src/main/resources/META-INF/services/software.amazon.smithy.java.client.core.ClientProtocolFactory @@ -0,0 +1 @@ +software.amazon.smithy.java.aws.client.awsquery.AwsQueryClientProtocol$Factory diff --git a/client/client-core/src/main/java/software/amazon/smithy/java/client/core/ProtocolSettings.java b/client/client-core/src/main/java/software/amazon/smithy/java/client/core/ProtocolSettings.java index e8b3aa7bd..fc3d76313 100644 --- a/client/client-core/src/main/java/software/amazon/smithy/java/client/core/ProtocolSettings.java +++ b/client/client-core/src/main/java/software/amazon/smithy/java/client/core/ProtocolSettings.java @@ -12,21 +12,36 @@ */ public final class ProtocolSettings { private final ShapeId service; + private final String serviceVersion; private ProtocolSettings(Builder builder) { this.service = builder.service; + this.serviceVersion = builder.serviceVersion; } public ShapeId service() { return service; } + /** + * Gets the service version string. + * + *

The service version is required by some protocols (e.g., AWS Query) + * that include the version in the request body. + * + * @return the service version, or null if not set + */ + public String serviceVersion() { + return serviceVersion; + } + public static Builder builder() { return new Builder(); } public static final class Builder { private ShapeId service; + private String serviceVersion; private Builder() {} @@ -35,6 +50,17 @@ public Builder service(ShapeId service) { return this; } + /** + * Sets the service version string. + * + * @param serviceVersion the service version + * @return the builder + */ + public Builder serviceVersion(String serviceVersion) { + this.serviceVersion = serviceVersion; + return this; + } + public ProtocolSettings build() { return new ProtocolSettings(this); } diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlCodec.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlCodec.java index 0928c772d..54b76c698 100644 --- a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlCodec.java +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlCodec.java @@ -7,6 +7,7 @@ import java.io.OutputStream; import java.nio.ByteBuffer; +import java.util.List; import javax.xml.stream.XMLEventFactory; import javax.xml.stream.XMLInputFactory; import javax.xml.stream.XMLOutputFactory; @@ -27,6 +28,7 @@ public final class XmlCodec implements Codec { private final XMLOutputFactory xmlOutputFactory; private final XmlInfo xmlInfo = new XmlInfo(); private final XMLEventFactory eventFactory = XMLEventFactory.newInstance(); + private final List wrapperElements; private XmlCodec(Builder builder) { xmlInputFactory = XMLInputFactory.newInstance(); @@ -35,6 +37,7 @@ private XmlCodec(Builder builder) { xmlInputFactory.setProperty(XMLInputFactory.IS_REPLACING_ENTITY_REFERENCES, false); xmlInputFactory.setProperty(XMLInputFactory.IS_COALESCING, false); xmlOutputFactory = XMLOutputFactory.newInstance(); + this.wrapperElements = builder.wrapperElements; } /** @@ -59,7 +62,11 @@ public ShapeSerializer createSerializer(OutputStream sink) { public ShapeDeserializer createDeserializer(ByteBuffer source) { try { var reader = xmlInputFactory.createXMLStreamReader(ByteBufferUtils.byteBufferInputStream(source)); - return XmlDeserializer.topLevel(xmlInfo, eventFactory, new XmlReader.StreamReader(reader, xmlInputFactory)); + return XmlDeserializer.topLevel( + xmlInfo, + eventFactory, + new XmlReader.StreamReader(reader, xmlInputFactory), + wrapperElements); } catch (XMLStreamException e) { throw new RuntimeException(e); } @@ -69,16 +76,36 @@ public ShapeDeserializer createDeserializer(ByteBuffer source) { * Builder used to create an XML codec. */ public static final class Builder { + private List wrapperElements = List.of(); private Builder() {} + /** + * Configure wrapper elements to skip during deserialization. + * + *

When deserializing, these elements are skipped in order at the top level only + * before reading the actual content. This is useful for protocols like AWS Query + * where responses are wrapped in elements like {@code } + * and {@code }. + * + *

The elements must match exactly (not by suffix) and are only skipped at + * the top level, not for nested structures. + * + * @param wrapperElements the list of wrapper element names to skip, in order + * @return the builder + */ + public Builder wrapperElements(List wrapperElements) { + this.wrapperElements = wrapperElements; + return this; + } + /** * Create the codec and ensure all required settings are present. * * @return the codec. * @throws NullPointerException if any required settings are missing. */ - public Codec build() { + public XmlCodec build() { return new XmlCodec(this); } } diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java index c9f843206..475e0e4d0 100644 --- a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java @@ -11,6 +11,7 @@ import java.time.DateTimeException; import java.time.Instant; import java.util.Base64; +import java.util.List; import javax.xml.stream.XMLEventFactory; import javax.xml.stream.XMLStreamException; import software.amazon.smithy.java.core.schema.Schema; @@ -29,13 +30,23 @@ final class XmlDeserializer implements ShapeDeserializer { private final XMLEventFactory eventFactory; private final InnerDeserializer innerDeserializer; private final boolean isTopLevel; + private final List wrapperElements; static XmlDeserializer topLevel( XmlInfo xmlInfo, XMLEventFactory eventFactory, XmlReader reader ) throws XMLStreamException { - return new XmlDeserializer(xmlInfo, eventFactory, reader, true); + return new XmlDeserializer(xmlInfo, eventFactory, reader, true, List.of()); + } + + static XmlDeserializer topLevel( + XmlInfo xmlInfo, + XMLEventFactory eventFactory, + XmlReader reader, + List wrapperElements + ) throws XMLStreamException { + return new XmlDeserializer(xmlInfo, eventFactory, reader, true, wrapperElements); } static XmlDeserializer flattened( @@ -43,14 +54,20 @@ static XmlDeserializer flattened( XMLEventFactory eventFactory, XmlReader reader ) throws XMLStreamException { - return new XmlDeserializer(xmlInfo, eventFactory, reader, false); + return new XmlDeserializer(xmlInfo, eventFactory, reader, false, List.of()); } - private XmlDeserializer(XmlInfo xmlInfo, XMLEventFactory eventFactory, XmlReader reader, boolean isTopLevel) - throws XMLStreamException { + private XmlDeserializer( + XmlInfo xmlInfo, + XMLEventFactory eventFactory, + XmlReader reader, + boolean isTopLevel, + List wrapperElements + ) throws XMLStreamException { this.xmlInfo = xmlInfo; this.reader = reader; this.isTopLevel = isTopLevel; + this.wrapperElements = wrapperElements; this.eventFactory = eventFactory; this.innerDeserializer = new InnerDeserializer(); } @@ -83,6 +100,12 @@ private void enter(Schema schema) { return; } + if (!wrapperElements.isEmpty()) { + // Skip wrapper elements (exact match), then deserialize content without root validation. + skipWrapperElements(); + return; + } + var name = reader.nextMemberElement(); String expected; var trait = schema.getTrait(TraitKey.XML_NAME_TRAIT); @@ -105,6 +128,24 @@ private void enter(Schema schema) { } } + private void skipWrapperElements() throws XMLStreamException { + // Navigate through wrapper elements (exact match) + // After this, we should be positioned so the next nextMemberElement returns content + for (String wrapperName : wrapperElements) { + var name = reader.nextMemberElement(); + if (name == null) { + return; + } + if (!name.equals(wrapperName)) { + // Not the expected wrapper element - protocol mismatch + throw new SerializationException( + "Expected wrapper element '" + wrapperName + "', found '" + name + "'"); + } + // Continue to next wrapper level + } + // Now positioned inside the innermost wrapper, ready for readStruct to read members + } + private void exit() { try { reader.closeElement(); diff --git a/codegen/plugins/client-codegen/src/main/java/software/amazon/smithy/java/codegen/client/generators/ClientInterfaceGenerator.java b/codegen/plugins/client-codegen/src/main/java/software/amazon/smithy/java/codegen/client/generators/ClientInterfaceGenerator.java index 479104873..52a874120 100644 --- a/codegen/plugins/client-codegen/src/main/java/software/amazon/smithy/java/codegen/client/generators/ClientInterfaceGenerator.java +++ b/codegen/plugins/client-codegen/src/main/java/software/amazon/smithy/java/codegen/client/generators/ClientInterfaceGenerator.java @@ -172,6 +172,7 @@ final class RequestOverrideBuilder extends ${requestOverride:T}.OverrideBuilder< new DefaultProtocolGenerator( writer, settings.service(), + directive.service().getVersion(), defaultProtocolTrait, directive.context())); writer.putContext("clientPlugin", ClientPlugin.class); @@ -388,6 +389,7 @@ public void run() { private record DefaultProtocolGenerator( JavaWriter writer, ShapeId service, + String serviceVersion, Trait defaultProtocolTrait, CodeGenerationContext context) implements Runnable { @@ -400,6 +402,7 @@ public void run() { var template = """ private static final ${protocolSettings:T} protocolSettings = ${protocolSettings:T}.builder() .service(${shapeId:T}.from(${service:S})) + .serviceVersion(${serviceVersion:S}) .build(); private static final ${trait:T} protocolTrait = ${initializer:C}; """; @@ -409,6 +412,7 @@ public void run() { writer.putContext("initializer", writer.consumer(w -> initializer.accept(w, defaultProtocolTrait))); writer.putContext("shapeId", ShapeId.class); writer.putContext("service", service); + writer.putContext("serviceVersion", serviceVersion); writer.write(template); writer.popState(); } diff --git a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestExtension.java b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestExtension.java index 6614aa3bb..3873ad0df 100644 --- a/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestExtension.java +++ b/protocol-test-harness/src/main/java/software/amazon/smithy/java/protocoltests/harness/ProtocolTestExtension.java @@ -242,7 +242,10 @@ private static Model applyServiceTransformations(ServiceShape service) { if (protocolFactory == null) { continue; } - var protocolSettings = ProtocolSettings.builder().service(service.getId()).build(); + var protocolSettings = ProtocolSettings.builder() + .service(service.getId()) + .serviceVersion(service.getVersion()) + .build(); var instance = protocolFactory.createProtocol(protocolSettings, protocolTraitEntry.getValue()); protocols.put(protocolTraitEntry.getKey(), instance); } diff --git a/settings.gradle.kts b/settings.gradle.kts index 562f62f7e..03df9af0e 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -76,6 +76,7 @@ include(":aws:client:aws-client-core") include(":aws:client:aws-client-http") include(":aws:client:aws-client-restjson") include(":aws:client:aws-client-restxml") +include(":aws:client:aws-client-awsquery") include(":aws:client:aws-client-rulesengine") include(":aws:integrations:aws-lambda-endpoint") include(":aws:server:aws-server-restjson")