Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import software.amazon.smithy.java.core.serde.event.EventEncoderFactory;
import software.amazon.smithy.java.core.serde.event.EventStreamingException;
import software.amazon.smithy.java.core.serde.event.FrameEncoder;
import software.amazon.smithy.java.core.serde.event.FrameTransformer;

/**
* A {@link EventEncoderFactory} for AWS events.
Expand All @@ -24,19 +25,22 @@ public final class AwsEventEncoderFactory implements EventEncoderFactory<AwsEven
private final Schema schema;
private final Codec codec;
private final String payloadMediaType;
private final FrameTransformer<AwsEventFrame> transformer;
private final Function<Throwable, EventStreamingException> exceptionHandler;

private AwsEventEncoderFactory(
InitialEventType initialEventType,
Schema schema,
Codec codec,
String payloadMediaType,
FrameTransformer<AwsEventFrame> transformer,
Function<Throwable, EventStreamingException> exceptionHandler
) {
this.initialEventType = Objects.requireNonNull(initialEventType, "initialEventType");
this.schema = Objects.requireNonNull(schema, "schema").isMember() ? schema.memberTarget() : schema;
this.codec = Objects.requireNonNull(codec, "codec");
this.payloadMediaType = Objects.requireNonNull(payloadMediaType, "payloadMediaType");
this.transformer = Objects.requireNonNull(transformer, "transformer");
this.exceptionHandler = Objects.requireNonNull(exceptionHandler, "exceptionHandler");
}

Expand All @@ -53,12 +57,14 @@ public static AwsEventEncoderFactory forInputStream(
InputEventStreamingApiOperation<?, ?, ?> operation,
Codec codec,
String payloadMediaType,
FrameTransformer<AwsEventFrame> transformer,
Function<Throwable, EventStreamingException> exceptionHandler
) {
return new AwsEventEncoderFactory(InitialEventType.INITIAL_REQUEST,
operation.inputStreamMember(),
codec,
payloadMediaType,
transformer,
exceptionHandler);
}

Expand All @@ -75,18 +81,25 @@ public static AwsEventEncoderFactory forOutputStream(
OutputEventStreamingApiOperation<?, ?, ?> operation,
Codec codec,
String payloadMediaType,
FrameTransformer<AwsEventFrame> transformer,
Function<Throwable, EventStreamingException> exceptionHandler
) {
return new AwsEventEncoderFactory(InitialEventType.INITIAL_RESPONSE,
operation.outputStreamMember(),
codec,
payloadMediaType,
transformer,
exceptionHandler);
}

@Override
public EventEncoder<AwsEventFrame> newEventEncoder() {
return new AwsEventShapeEncoder(initialEventType, schema, codec, payloadMediaType, exceptionHandler);
return new AwsEventShapeEncoder(initialEventType,
schema,
codec,
payloadMediaType,
transformer,
exceptionHandler);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.time.Instant;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Flow;
import java.util.function.Supplier;
import software.amazon.eventstream.HeaderValue;
import software.amazon.eventstream.Message;
Expand All @@ -21,6 +20,7 @@
import software.amazon.smithy.java.core.serde.ShapeDeserializer;
import software.amazon.smithy.java.core.serde.SpecificShapeDeserializer;
import software.amazon.smithy.java.core.serde.event.EventDecoder;
import software.amazon.smithy.java.core.serde.event.EventStream;

/**
* A decoder for AWS events
Expand All @@ -36,7 +36,6 @@ public final class AwsEventShapeDecoder<E extends SerializableStruct, IR extends
private final Supplier<ShapeBuilder<E>> eventBuilder;
private final Schema eventSchema;
private final Codec codec;
private volatile Flow.Publisher<SerializableStruct> publisher;

AwsEventShapeDecoder(
InitialEventType initialEventType,
Expand All @@ -54,19 +53,9 @@ public final class AwsEventShapeDecoder<E extends SerializableStruct, IR extends

@Override
public SerializableStruct decode(AwsEventFrame frame) {
var message = frame.unwrap();
var eventType = getEventType(message);
if (initialEventType.value().equals(eventType)) {
return decodeInitialResponse(frame);
}
return decodeEvent(frame);
}

@Override
public void onPrepare(Flow.Publisher<SerializableStruct> publisher) {
this.publisher = publisher;
}

private E decodeEvent(AwsEventFrame frame) {
var message = frame.unwrap();
var eventType = getEventType(message);
Expand All @@ -85,12 +74,13 @@ private E decodeEvent(AwsEventFrame frame) {
return builder.build();
}

private IR decodeInitialResponse(AwsEventFrame frame) {
@Override
public IR decodeInitialEvent(AwsEventFrame frame, EventStream<?> eventStream) {
var message = frame.unwrap();
var builder = initialEventBuilder.get();
var publisherMember = getPublisherMember(builder.schema());
var publisherMember = getEventStreamMember(builder.schema());
// Set the publisher member
var responseDeserializer = new InitialResponseDeserializer(publisherMember, publisher);
var responseDeserializer = new InitialResponseDeserializer(publisherMember, eventStream);
builder.deserialize(responseDeserializer);
// Deserialize the rest of the members if any
var headers = message.getHeaders();
Expand All @@ -100,7 +90,7 @@ private IR decodeInitialResponse(AwsEventFrame frame) {
return builder.build();
}

private Schema getPublisherMember(Schema schema) {
private Schema getEventStreamMember(Schema schema) {
for (var member : schema.members()) {
if (member.memberTarget().hasTrait(TraitKey.STREAMING_TRAIT)) {
return member;
Expand All @@ -115,16 +105,16 @@ private String getEventType(Message message) {

static class InitialResponseDeserializer extends SpecificShapeDeserializer {
private final Schema publisherMember;
private final Flow.Publisher<? extends SerializableStruct> publisher;
private final EventStream<?> eventStream;

InitialResponseDeserializer(Schema publisherMember, Flow.Publisher<? extends SerializableStruct> publisher) {
InitialResponseDeserializer(Schema publisherMember, EventStream<?> eventStream) {
this.publisherMember = publisherMember;
this.publisher = publisher;
this.eventStream = eventStream;
}

@Override
public Flow.Publisher<? extends SerializableStruct> readEventStream(Schema schema) {
return publisher;
public EventStream<? extends SerializableStruct> readEventStream(Schema schema) {
return eventStream;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import software.amazon.smithy.java.core.serde.SpecificShapeSerializer;
import software.amazon.smithy.java.core.serde.event.EventEncoder;
import software.amazon.smithy.java.core.serde.event.EventStreamingException;
import software.amazon.smithy.java.core.serde.event.FrameTransformer;
import software.amazon.smithy.model.shapes.ShapeId;

public final class AwsEventShapeEncoder implements EventEncoder<AwsEventFrame> {
Expand All @@ -35,13 +36,15 @@ public final class AwsEventShapeEncoder implements EventEncoder<AwsEventFrame> {
private final String payloadMediaType;
private final Map<String, BiFunction<OutputStream, Map<String, HeaderValue>, ShapeSerializer>> possibleTypes;
private final Map<ShapeId, Schema> possibleExceptions;
private final FrameTransformer<AwsEventFrame> frameTransformer;
private final Function<Throwable, EventStreamingException> exceptionHandler;

public AwsEventShapeEncoder(
InitialEventType initialEventType,
Schema eventSchema,
Codec codec,
String payloadMediaType,
FrameTransformer<AwsEventFrame> frameTransformer,
Function<Throwable, EventStreamingException> exceptionHandler
) {
this.initialEventType = Objects.requireNonNull(initialEventType, "initialEventType");
Expand All @@ -51,6 +54,7 @@ public AwsEventShapeEncoder(
codec,
initialEventType.value());
this.possibleExceptions = possibleExceptions(Objects.requireNonNull(eventSchema, "eventSchema"));
this.frameTransformer = Objects.requireNonNull(frameTransformer, "frameTransformer");
this.exceptionHandler = Objects.requireNonNull(exceptionHandler, "exceptionHandler");
}

Expand All @@ -62,7 +66,8 @@ public AwsEventFrame encode(SerializableStruct item) {
headers.put(":message-type", HeaderValue.fromString("event"));
headers.put(":event-type", HeaderValue.fromString(typeHolder.get()));
headers.put(":content-type", HeaderValue.fromString(payloadMediaType));
return new AwsEventFrame(new Message(headers, payload));
var frame = new AwsEventFrame(new Message(headers, payload));
return frameTransformer.apply(frame);
}

private byte[] encodeInput(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public AwsFrameDecoder(FrameTransformer<AwsEventFrame> transformer) {
public List<AwsEventFrame> decode(ByteBuffer buffer) {
decoder.feed(buffer);
var messages = decoder.getDecodedMessages();
var result = new ArrayList<AwsEventFrame>();
var result = new ArrayList<AwsEventFrame>(messages.size());
for (var message : messages) {
var event = new AwsEventFrame(message);
var transformed = transformer.apply(event);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@

package software.amazon.smithy.java.aws.events;

import java.nio.ByteBuffer;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import software.amazon.smithy.java.core.schema.Schema;
import software.amazon.smithy.java.core.schema.SerializableStruct;
import software.amazon.smithy.java.core.schema.TraitKey;
import software.amazon.smithy.java.core.serde.event.EventDecoderFactory;
import software.amazon.smithy.java.core.serde.event.EventEncoderFactory;
import software.amazon.smithy.java.core.serde.event.EventStreamFrameDecodingProcessor;
import software.amazon.smithy.java.core.serde.event.EventStreamFrameEncodingProcessor;
import software.amazon.smithy.java.core.serde.event.EventStream;
import software.amazon.smithy.java.core.serde.event.InternalEventStreamReader;
import software.amazon.smithy.java.core.serde.event.InternalEventStreamWriter;
import software.amazon.smithy.java.io.datastream.DataStream;

/**
Expand All @@ -24,47 +22,36 @@ public final class RpcEventStreamsUtil {

private RpcEventStreamsUtil() {}

public static Flow.Publisher<ByteBuffer> bodyForEventStreaming(
@SuppressWarnings("unchecked")
public static DataStream bodyForEventStreaming(
EventEncoderFactory<AwsEventFrame> eventStreamEncodingFactory,
SerializableStruct input
) {
Flow.Publisher<SerializableStruct> eventStream = input.getMemberValue(streamingMember(input.schema()));
return EventStreamFrameEncodingProcessor.create(eventStream, eventStreamEncodingFactory, input);
}

// TODO: Make more synchronous
public static <O extends SerializableStruct> O deserializeResponse(
EventDecoderFactory<AwsEventFrame> eventDecoderFactory,
DataStream bodyDataStream
) {
var result = new CompletableFuture<O>();
var processor = EventStreamFrameDecodingProcessor.create(bodyDataStream, eventDecoderFactory);

// A subscriber to serialize the initial event.
processor.subscribe(new Flow.Subscriber<>() {
@Override
public void onSubscribe(Flow.Subscription subscription) {
subscription.request(1);
}

EventStream<SerializableStruct> eventStream = input.getMemberValue(streamingMember(input.schema()));
InternalEventStreamWriter<SerializableStruct, SerializableStruct, AwsEventFrame> writer =
InternalEventStreamWriter.toInternal(eventStream);
writer.bootstrap(new InternalEventStreamWriter.Bootstrap<>() {
@Override
@SuppressWarnings("unchecked")
public void onNext(SerializableStruct item) {
result.complete((O) item);
public EventEncoderFactory<AwsEventFrame> encoder() {
return eventStreamEncodingFactory;
}

@Override
public void onError(Throwable throwable) {
result.completeExceptionally(throwable);
}

@Override
public void onComplete() {
result.completeExceptionally(new RuntimeException("Unexpected event stream completion"));
public SerializableStruct initialEvent() {
return input;
}
});
return writer.toDataStream();
}

return result.join();
public static <O extends SerializableStruct> O deserializeResponse(
EventDecoderFactory<AwsEventFrame> eventDecoderFactory,
DataStream bodyDataStream
) {
var reader = InternalEventStreamReader.<O, SerializableStruct, AwsEventFrame>newReader(bodyDataStream,
eventDecoderFactory,
true);
return reader.readInitialEvent();
}

private static Schema streamingMember(Schema schema) {
Expand All @@ -75,5 +62,4 @@ private static Schema streamingMember(Schema schema) {
}
throw new IllegalArgumentException("No streaming member found");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public void testDecodeInitialResponse() {
var frame = new AwsEventFrame(message);

// Act
var struct = createDecoder().decode(frame);
var struct = createDecoder().decodeInitialEvent(frame, null);

// Assert
assertInstanceOf(TestOperationOutput.class, struct);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import software.amazon.smithy.java.aws.events.model.TestOperationInput;
import software.amazon.smithy.java.core.serde.Codec;
import software.amazon.smithy.java.core.serde.event.EventStreamingException;
import software.amazon.smithy.java.core.serde.event.FrameTransformer;
import software.amazon.smithy.java.json.JsonCodec;

class AwsEventShapeEncoderTest {
Expand Down Expand Up @@ -135,6 +136,7 @@ static AwsEventShapeEncoder createEncoder() {
TestOperation.instance().inputStreamMember(), // event schema
createJsonCodec(), // codec
"text/json",
FrameTransformer.identity(),
(e) -> new EventStreamingException("InternalServerException", "Internal Server Error"));
}

Expand Down
Loading