Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/main/java/io/qdrant/client/QdrantGrpcClient.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package io.qdrant.client;

import io.grpc.CallCredentials;
import io.grpc.Channel;
import io.grpc.ClientInterceptors;
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
Expand All @@ -21,6 +23,7 @@ public class QdrantGrpcClient implements AutoCloseable {
private static final Logger logger = LoggerFactory.getLogger(QdrantGrpcClient.class);
@Nullable private final CallCredentials callCredentials;
private final ManagedChannel channel;
private final Channel interceptedChannel;
private final boolean shutdownChannelOnClose;
@Nullable private final Duration timeout;

Expand All @@ -31,6 +34,8 @@ public class QdrantGrpcClient implements AutoCloseable {
@Nullable Duration timeout) {
this.callCredentials = callCredentials;
this.channel = channel;
this.interceptedChannel =
ClientInterceptors.intercept(channel, RequestHeaders.newInterceptor());
this.shutdownChannelOnClose = shutdownChannelOnClose;
this.timeout = timeout;
}
Expand Down Expand Up @@ -136,7 +141,7 @@ public ManagedChannel channel() {
* @return a new instance of {@link QdrantFutureStub}
*/
public QdrantGrpc.QdrantFutureStub qdrant() {
return QdrantGrpc.newFutureStub(channel)
return QdrantGrpc.newFutureStub(interceptedChannel)
.withCallCredentials(callCredentials)
.withDeadline(
timeout != null ? Deadline.after(timeout.toMillis(), TimeUnit.MILLISECONDS) : null);
Expand All @@ -148,7 +153,7 @@ public QdrantGrpc.QdrantFutureStub qdrant() {
* @return a new instance of {@link PointsFutureStub}
*/
public PointsFutureStub points() {
return PointsGrpc.newFutureStub(channel)
return PointsGrpc.newFutureStub(interceptedChannel)
.withCallCredentials(callCredentials)
.withDeadline(
timeout != null ? Deadline.after(timeout.toMillis(), TimeUnit.MILLISECONDS) : null);
Expand All @@ -160,7 +165,7 @@ public PointsFutureStub points() {
* @return a new instance of {@link CollectionsFutureStub}
*/
public CollectionsFutureStub collections() {
return CollectionsGrpc.newFutureStub(channel)
return CollectionsGrpc.newFutureStub(interceptedChannel)
.withCallCredentials(callCredentials)
.withDeadline(
timeout != null ? Deadline.after(timeout.toMillis(), TimeUnit.MILLISECONDS) : null);
Expand All @@ -172,7 +177,7 @@ public CollectionsFutureStub collections() {
* @return a new instance of {@link SnapshotsFutureStub}
*/
public SnapshotsFutureStub snapshots() {
return SnapshotsGrpc.newFutureStub(channel)
return SnapshotsGrpc.newFutureStub(interceptedChannel)
.withCallCredentials(callCredentials)
.withDeadline(
timeout != null ? Deadline.after(timeout.toMillis(), TimeUnit.MILLISECONDS) : null);
Expand Down
85 changes: 85 additions & 0 deletions src/main/java/io/qdrant/client/RequestHeaders.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package io.qdrant.client;

import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.Context;
import io.grpc.ForwardingClientCall;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
* Utilities for attaching per-request headers to gRPC calls.
*
* <pre>{@code
* Context ctx = RequestHeaders.withHeader(Context.current(), "x-request-id", "abc-123");
* ctx.run(() -> client.listCollectionsAsync());
* }</pre>
*/
public final class RequestHeaders {

static final Context.Key<Map<String, String>> HEADERS_KEY = Context.key("qdrant-request-headers");

private RequestHeaders() {}

/**
* Returns a new {@link Context} that carries key/value as a gRPC metadata header on every request
* started within that context.
*
* @param ctx the parent context
* @param key the header name
* @param value the header value
* @return a child context with the header attached
*/
public static Context withHeader(Context ctx, String key, String value) {
return withHeaders(ctx, Collections.singletonMap(key, value));
}

/**
* Returns a new {@link Context} that carries all entries of headers as gRPC metadata on every
* request started within that context.
*
* @param ctx the parent context
* @param headers the headers to attach
* @return a child context with the headers attached
*/
public static Context withHeaders(Context ctx, Map<String, String> headers) {
if (headers == null || headers.isEmpty()) {
return ctx;
}
Map<String, String> merged = new HashMap<>();
Map<String, String> current = HEADERS_KEY.get(ctx);
if (current != null) merged.putAll(current);
merged.putAll(headers);
return ctx.withValue(HEADERS_KEY, merged);
}

/** Returns a {@link ClientInterceptor} that injects per-request headers from the context. */
static ClientInterceptor newInterceptor() {
return new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No way around this boilerplate 🙁

MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(
next.newCall(method, callOptions)) {
@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
Map<String, String> extra = HEADERS_KEY.get();
if (extra != null) {
for (Map.Entry<String, String> entry : extra.entrySet()) {
headers.put(
Metadata.Key.of(entry.getKey(), Metadata.ASCII_STRING_MARSHALLER),
entry.getValue());
}
}
super.start(responseListener, headers);
}
};
}
};
}
}
Loading