Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public interface Authenticator {
*
* @return a {@link CompletableFuture} that will complete with the access token
*/
CompletableFuture<String> asyncToken();
CompletableFuture<String> tokenAsync();

/**
* Returns the authentication scheme.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ public String token() {
}

@Override
public CompletableFuture<String> asyncToken() {
public CompletableFuture<String> tokenAsync() {

TokenResponse currentToken = token.get();

if (!isExpired(currentToken))
return completedFuture(token.get().accessToken());

return client.asyncToken(new TokenRequest(username, password, apiKey)).thenApply(identityTokenResponse -> {
return client.tokenAsync(new TokenRequest(username, password, apiKey)).thenApply(identityTokenResponse -> {
token.getAndSet(identityTokenResponse);
return identityTokenResponse.accessToken();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ protected CP4DRestClient(Builder<?, ?> builder) {
*
* @return a {@link CompletableFuture} that contains the token and related metadata
*/
public abstract CompletableFuture<TokenResponse> asyncToken(TokenRequest request);
public abstract CompletableFuture<TokenResponse> tokenAsync(TokenRequest request);

/**
* Creates a new {@link Builder} by loading the first available {@code CP4D*RestClientBuilderFactory} discovered via {@link ServiceLoader},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public TokenResponse token(TokenRequest request) {
}

@Override
public CompletableFuture<TokenResponse> asyncToken(TokenRequest request) {
public CompletableFuture<TokenResponse> tokenAsync(TokenRequest request) {
return asyncHttpClient
.send(createTokenRequest(request), BodyHandlers.ofString())
.thenCompose(response -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public TokenResponse token(TokenRequest request) {
}

@Override
public CompletableFuture<TokenResponse> asyncToken(TokenRequest request) {
public CompletableFuture<TokenResponse> tokenAsync(TokenRequest request) {
return asyncHttpClient
.send(createTokenRequest(request), BodyHandlers.ofString())
.thenApplyAsync(this::parseTokenResponse, ExecutorProvider.cpuExecutor())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public TokenResponse token(TokenRequest request) {
}

@Override
public CompletableFuture<TokenResponse> asyncToken(TokenRequest request) {
public CompletableFuture<TokenResponse> tokenAsync(TokenRequest request) {
return completedFuture(createTokenResponse(request));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public TokenResponse token(String apiKey, String grantType) {
}

@Override
public CompletableFuture<TokenResponse> asyncToken(String apiKey, String grantType) {
public CompletableFuture<TokenResponse> tokenAsync(String apiKey, String grantType) {
return asyncHttpClient.send(createHttpRequest(apiKey, grantType), BodyHandlers.ofString())
.thenApplyAsync(response -> fromJson(response.body(), TokenResponse.class), ExecutorProvider.cpuExecutor())
.thenApplyAsync(Function.identity(), ExecutorProvider.ioExecutor());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ public String token() {
}

@Override
public CompletableFuture<String> asyncToken() {
public CompletableFuture<String> tokenAsync() {

TokenResponse currentToken = token.get();

if (!isExpired(currentToken))
return completedFuture(currentToken.accessToken());

return client.asyncToken(apiKey, grantType).thenApply(identityTokenResponse -> {
return client.tokenAsync(apiKey, grantType).thenApply(identityTokenResponse -> {
token.getAndSet(identityTokenResponse);
return identityTokenResponse.accessToken();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ protected IBMCloudRestClient(Builder<?, ?> builder) {
* @param grantType the grant type to use
* @return a {@link CompletableFuture} that completes with the IAM response or exceptionally on error
*/
public abstract CompletableFuture<TokenResponse> asyncToken(String apiKey, String grantType);
public abstract CompletableFuture<TokenResponse> tokenAsync(String apiKey, String grantType);

/**
* Creates a new {@link Builder} using the first available {@link IBMCloudRestClientBuilderFactory} discovered via {@link ServiceLoader}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public AuthenticationInterceptor(Authenticator authenticator) {

@Override
public <T> CompletableFuture<HttpResponse<T>> intercept(HttpRequest request, BodyHandler<T> bodyHandler, int index, AsyncChain chain) {
return authenticator.asyncToken()
return authenticator.tokenAsync()
.thenCompose(token -> chain.proceed(requestWithAuthHeader(request, token), bodyHandler));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class Async {
@Test
void should_send_request_with_bearer_token() throws Exception {

when(mockAuthenticator.asyncToken()).thenReturn(completedFuture("my_super_token"));
when(mockAuthenticator.tokenAsync()).thenReturn(completedFuture("my_super_token"));
when(mockAuthenticator.scheme()).thenReturn("Bearer");

withWatsonxServiceMock(() -> {
Expand All @@ -164,7 +164,7 @@ void should_send_request_with_bearer_token() throws Exception {
void should_send_request_with_zen_api_key() throws Exception {

var cp4dAuthenticatorMock = mock(CP4DAuthenticator.class);
when(cp4dAuthenticatorMock.asyncToken()).thenReturn(completedFuture("#1234"));
when(cp4dAuthenticatorMock.tokenAsync()).thenReturn(completedFuture("#1234"));
when(cp4dAuthenticatorMock.scheme()).thenReturn("ZenApiKey");
when(cp4dAuthenticatorMock.isAuthMode(AuthMode.ZEN_API_KEY)).thenReturn(true);

Expand All @@ -189,7 +189,7 @@ void should_send_request_with_zen_api_key() throws Exception {
@Test
void should_throw_exception_when_bearer_token_is_invalid() {

when(mockAuthenticator.asyncToken()).thenThrow(new RuntimeException("error"));
when(mockAuthenticator.tokenAsync()).thenThrow(new RuntimeException("error"));

withWatsonxServiceMock(() -> {

Expand All @@ -212,7 +212,7 @@ void should_throw_exception_when_bearer_token_is_invalid() {
@SuppressWarnings("unchecked")
void should_execute_with_custom_executor() throws Exception {

when(mockAuthenticator.asyncToken()).thenReturn(completedFuture("my_super_token"));
when(mockAuthenticator.tokenAsync()).thenReturn(completedFuture("my_super_token"));
var threadNames = new ArrayList<>();

var ioExecutor = Executors.newSingleThreadScheduledExecutor(r -> new Thread(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ void should_return_a_valid_token() throws Exception {
.apiKey("api_key")
.build();

assertEquals("access-token", assertDoesNotThrow(() -> authenticator.asyncToken().get()));
assertEquals("access-token", assertDoesNotThrow(() -> authenticator.tokenAsync().get()));
});
}

Expand Down Expand Up @@ -543,7 +543,7 @@ void should_use_iam_authentication() {
.authMode(AuthMode.IAM)
.build();

String token = authenticator.asyncToken().get();
String token = authenticator.tokenAsync().get();
assertEquals("access-token", token);
verify(mockSecureHttpClient, times(2)).sendAsync(reqCaptor.capture(), any());

Expand Down Expand Up @@ -587,9 +587,9 @@ void should_use_cached_token() throws Exception {
.build();

// Execute the http request.
assertDoesNotThrow(() -> authenticator.asyncToken().get());
assertDoesNotThrow(() -> authenticator.tokenAsync().get());
// Get the value from the cache.
assertDoesNotThrow(() -> authenticator.asyncToken().get());
assertDoesNotThrow(() -> authenticator.tokenAsync().get());

verify(mockSecureHttpClient, times(1)).sendAsync(any(), any());
});
Expand Down Expand Up @@ -626,10 +626,10 @@ else if (url.endsWith("/v1/preauth/validateAuth GET"))
.authMode(AuthMode.IAM)
.build();

var token = authenticator.asyncToken().get();
var token = authenticator.tokenAsync().get();
assertEquals("iam-access-token", token);

token = authenticator.asyncToken().get();
token = authenticator.tokenAsync().get();
assertEquals("iam-access-token", token);

ArgumentCaptor<HttpRequest> captor = ArgumentCaptor.forClass(HttpRequest.class);
Expand Down Expand Up @@ -663,7 +663,7 @@ void should_use_zen_api_key_authentication_mode() {
.build();

var encoded = Base64.encodeBase64String("username:api_key".getBytes());
assertEquals(encoded, assertDoesNotThrow(() -> authenticator.asyncToken().get()));
assertEquals(encoded, assertDoesNotThrow(() -> authenticator.tokenAsync().get()));
assertTrue(authenticator.isAuthMode(AuthMode.ZEN_API_KEY));
}

Expand Down Expand Up @@ -706,7 +706,7 @@ void should_use_the_correct_executors() throws Exception {
.apiKey("api_key")
.build();

assertDoesNotThrow(() -> authenticator.asyncToken()
assertDoesNotThrow(() -> authenticator.tokenAsync()
.thenRunAsync(() -> threadNames.add(Thread.currentThread().getName()), ioExecutor)
.thenRunAsync(() -> threadNames.add(Thread.currentThread().getName()), cpuExecutor)
.get(3, TimeUnit.SECONDS));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ void should_return_a_valid_token() throws Exception {
.apiKey("my_super_api_key")
.build();

assertEquals("my_super_token", assertDoesNotThrow(() -> authenticator.asyncToken().get()));
assertEquals("my_super_token", assertDoesNotThrow(() -> authenticator.tokenAsync().get()));
assertEquals("https://iam.cloud.ibm.com/identity/token", mockHttpRequest.getValue().uri().toString());
assertEquals("application/x-www-form-urlencoded",
mockHttpRequest.getValue().headers().firstValue("Content-Type").get());
Expand All @@ -311,9 +311,9 @@ void should_use_a_cached_token() throws Exception {
.build();

// Execute the http request.
assertDoesNotThrow(() -> authenticator.asyncToken().get());
assertDoesNotThrow(() -> authenticator.tokenAsync().get());
// Get the value from the cache.
assertDoesNotThrow(() -> authenticator.asyncToken().get());
assertDoesNotThrow(() -> authenticator.tokenAsync().get());

verify(mockSecureHttpClient, times(1)).sendAsync(any(), any());
});
Expand All @@ -335,7 +335,7 @@ void should_use_custom_parameters() throws Exception {
.apiKey("my_super_api_key")
.build();

assertEquals("my_super_token", assertDoesNotThrow(() -> authenticator.asyncToken().get()));
assertEquals("my_super_token", assertDoesNotThrow(() -> authenticator.tokenAsync().get()));
assertEquals("http://mytest.com/identity/token", mockHttpRequest.getValue().uri().toString());
assertEquals("application/x-www-form-urlencoded",
mockHttpRequest.getValue().headers().firstValue("Content-Type").get());
Expand Down Expand Up @@ -378,7 +378,7 @@ void should_use_the_correct_executors() throws Exception {
.apiKey("my_super_api_key")
.build();

assertDoesNotThrow(() -> authenticator.asyncToken()
assertDoesNotThrow(() -> authenticator.tokenAsync()
.thenRunAsync(() -> threadNames.add(Thread.currentThread().getName()), ioExecutor)
.thenRunAsync(() -> threadNames.add(Thread.currentThread().getName()), cpuExecutor)
.get(3, TimeUnit.SECONDS));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,22 @@
import static java.util.Objects.nonNull;
import static java.util.Objects.requireNonNull;
import static java.util.Objects.requireNonNullElse;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.runAsync;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.time.Duration;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.List;
import java.util.StringJoiner;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import com.ibm.watsonx.ai.WatsonxService.ProjectService;
import com.ibm.watsonx.ai.chat.ChatRequest;
import com.ibm.watsonx.ai.chat.ChatResponse;
import com.ibm.watsonx.ai.chat.ChatUtility;
import com.ibm.watsonx.ai.core.Json;
import com.ibm.watsonx.ai.core.auth.Authenticator;
import com.ibm.watsonx.ai.core.provider.ExecutorProvider;
import com.ibm.watsonx.ai.core.spi.json.TypeToken;
import com.ibm.watsonx.ai.file.FileDeleteRequest;
import com.ibm.watsonx.ai.file.FileService;
Expand Down Expand Up @@ -292,10 +286,25 @@ public <T> List<BatchResult<T>> submitAndFetch(BatchCreateRequest request, Class

while (status != Status.COMPLETED && status != Status.FAILED) {

if (LocalTime.now().isAfter(endTime))
if (LocalTime.now().isAfter(endTime)) {

cancel(
BatchCancelRequest.builder()
.batchId(batchData.id())
.projectId(projectSpace.projectId())
.spaceId(projectSpace.spaceId())
.transactionId(request.transactionId())
.build());

deleteFile(
removeUploadedFile ? batchData.inputFileId() : null,
null,
request.transactionId());

throw new RuntimeException(
"The execution of the batch operation for the file \"%s\" took longer than the timeout set by %s milliseconds"
.formatted(request.inputFileId(), timeout.toMillis()));
}

try {

Expand All @@ -322,8 +331,7 @@ public <T> List<BatchResult<T>> submitAndFetch(BatchCreateRequest request, Class
deleteFile(
removeUploadedFile ? batchData.inputFileId() : null,
null,
request.transactionId(),
timeout
request.transactionId()
);
throw new RuntimeException("The batch operation failed: %s".formatted(batchData));
}
Expand All @@ -337,8 +345,7 @@ public <T> List<BatchResult<T>> submitAndFetch(BatchCreateRequest request, Class
deleteFile(
removeUploadedFile ? batchData.inputFileId() : null,
removeOutputFile ? batchData.outputFileId() : null,
request.transactionId(),
timeout
request.transactionId()
);

return result;
Expand Down Expand Up @@ -499,25 +506,20 @@ public BatchData cancel(BatchCancelRequest request) {

/**
* Deletes the input and/or output files associated with a completed batch job.
* <p>
* Each non-null file identifier is deleted concurrently. The method blocks until all deletions complete or the timeout expires.
*
* @param inputFileId the identifier of the input file to delete, or {@code null} to skip
* @param outputFileId the identifier of the output file to delete, or {@code null} to skip
* @param transactionId optional transaction identifier to propagate to the delete requests
* @param timeout the maximum time to wait for all deletions to complete
*/
private void deleteFile(String inputFileId, String outputFileId, String transactionId, Duration timeout) {

var futures = new ArrayList<CompletableFuture<Void>>();
private void deleteFile(String inputFileId, String outputFileId, String transactionId) {

if (nonNull(inputFileId)) {
var request = FileDeleteRequest.builder()
.fileId(inputFileId)
.transactionId(transactionId)
.build();

futures.add(runAsync(() -> fileService.delete(request), ExecutorProvider.callbackExecutor()));
fileService.deleteAsync(request);
}

if (nonNull(outputFileId)) {
Expand All @@ -526,10 +528,8 @@ private void deleteFile(String inputFileId, String outputFileId, String transact
.transactionId(transactionId)
.build();

futures.add(runAsync(() -> fileService.delete(request), ExecutorProvider.callbackExecutor()));
fileService.deleteAsync(request);
}

allOf(futures.toArray(new CompletableFuture[0])).join();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import static java.util.Objects.nonNull;
import static java.util.Objects.requireNonNull;
import static java.util.Objects.requireNonNullElse;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import com.ibm.watsonx.ai.chat.model.ChatMessage;
import com.ibm.watsonx.ai.chat.model.ChatParameters;
Expand Down Expand Up @@ -208,7 +208,7 @@ public Builder messages(ChatMessage... messages) {
*/
public Builder messages(List<? extends ChatMessage> messages) {
if (nonNull(messages))
this.messages = new LinkedList<>(messages);
this.messages = new ArrayList<>(messages);
return this;
}

Expand Down Expand Up @@ -236,7 +236,7 @@ public Builder addMessages(List<? extends ChatMessage> messages) {
if (isNull(messages) || messages.isEmpty())
return this;

this.messages = requireNonNullElse(this.messages, new LinkedList<>());
this.messages = requireNonNullElse(this.messages, new ArrayList<>());
this.messages.addAll(messages);
return this;
}
Expand Down
Loading