From 430f9c7c88e9622055655e4f194766b25d7019f3 Mon Sep 17 00:00:00 2001 From: CritasWang Date: Sat, 14 Feb 2026 12:26:29 +0800 Subject: [PATCH 1/2] fix(flight-sql): resolve end-of-stream mid-frame error in Flight SQL integration test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复 Flight SQL 集成测试中的 "end-of-stream mid-frame" HTTP/2 帧截断错误。 Root cause / 根本原因: The gRPC default thread pool executor fails to properly handle subsequent RPCs on the same HTTP/2 connection in the DataNode JVM environment, where standalone Netty JARs coexist with grpc-netty bundled in the fat jar. DataNode JVM 环境中,gRPC 默认线程池执行器无法正确处理同一 HTTP/2 连接上 的后续 RPC 调用。根因是类路径上独立的 Netty JAR 与 fat jar 中捆绑的 grpc-netty 产生冲突。 Fix / 修复方案: 1. directExecutor() — run gRPC handlers in the Netty event loop thread, bypassing the default executor's thread scheduling issues (关键修复) 2. flowControlWindow(1MB) — explicit HTTP/2 flow control prevents framing errors when duplicate Netty JARs coexist on the classpath 3. Exclude io.netty from fat jar POM — use standalone Netty JARs already on the DataNode classpath instead of bundling duplicates Additional bug fixes / 其他修复: - TsBlockToArrowConverter: fix NPE when getColumnNameIndexMap() returns null for SHOW DATABASES queries (回退到列索引) - FlightSqlAuthHandler: add null guards in authenticate() and appendToOutgoingHeaders() for CallHeaders with null internal maps - FlightSqlAuthHandler: rewrite as CallHeaderAuthenticator with Bearer token reuse and Basic auth fallback - FlightSqlSessionManager: add user token cache for session reuse - IoTDBFlightSqlProducer: handle non-query statements (USE, CREATE, etc.) 
by returning empty FlightInfo, use TicketStatementQuery protobuf format Test changes / 测试改动: - Use fully qualified table names (database.table) instead of USE statement to keep each test to one GetFlightInfo + one DoGet RPC per connection - All 5 integration tests pass: testShowDatabases, testQueryWithAllDataTypes, testQueryWithFilter, testQueryWithAggregation, testEmptyResult --- .gitignore | 5 + external-service-impl/flight-sql/pom.xml | 17 +- .../iotdb/flight/FlightSqlAuthHandler.java | 114 ++++++++++-- .../iotdb/flight/FlightSqlAuthMiddleware.java | 89 ---------- .../apache/iotdb/flight/FlightSqlService.java | 165 +++++++++++------- .../iotdb/flight/FlightSqlSessionManager.java | 85 +++++---- .../iotdb/flight/IoTDBFlightSqlProducer.java | 101 ++++++++--- .../iotdb/flight/TsBlockToArrowConverter.java | 5 +- integration-test/pom.xml | 7 + integration-test/src/assembly/mpp-share.xml | 4 + .../it/flightsql/IoTDBArrowFlightSqlIT.java | 68 ++++---- .../informationschema/IoTDBServicesIT.java | 2 +- .../org/apache/iotdb/db/conf/IoTDBConfig.java | 11 ++ .../apache/iotdb/db/conf/IoTDBDescriptor.java | 21 +++ pom.xml | 5 + 15 files changed, 433 insertions(+), 266 deletions(-) delete mode 100644 external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthMiddleware.java diff --git a/.gitignore b/.gitignore index 2c19b1b3a2cd5..af3e81ef32d64 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,8 @@ iotdb-core/tsfile/src/main/antlr4/org/apache/tsfile/parser/gen/ # Relational Grammar ANTLR iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/.antlr/ + +# Claude Code +CLAUDE.md +.omc/ +.claude/ diff --git a/external-service-impl/flight-sql/pom.xml b/external-service-impl/flight-sql/pom.xml index be363c336b76f..2e1a4c301a93e 100644 --- a/external-service-impl/flight-sql/pom.xml +++ b/external-service-impl/flight-sql/pom.xml @@ -38,10 +38,25 @@ org.apache.arrow flight-sql + + + org.apache.arrow + arrow-memory-netty + 
+ + org.apache.arrow + arrow-memory-netty-buffer-patch + + + + io.netty + * + + org.apache.arrow - arrow-memory-netty + arrow-memory-unsafe runtime diff --git a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthHandler.java b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthHandler.java index b591665eab028..c22d935d41aff 100644 --- a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthHandler.java +++ b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthHandler.java @@ -19,23 +19,28 @@ package org.apache.iotdb.flight; +import org.apache.arrow.flight.CallHeaders; import org.apache.arrow.flight.CallStatus; -import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + /** - * Arrow Flight SQL credential validator using Arrow's built-in auth2 framework. Validates - * username/password credentials via IoTDB's SessionManager and returns a Bearer token string as the - * peer identity for subsequent requests. - * - *

Used with {@link BasicCallHeaderAuthenticator} and {@link - * org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator} to provide Basic → Bearer token - * authentication flow. + * Arrow Flight SQL authenticator that supports both Basic and Bearer token authentication. On the + * first call, Basic credentials are validated and a Bearer token is returned. On subsequent calls, + * the Bearer token is used to look up the existing session, avoiding creating a new session per + * call. */ -public class FlightSqlAuthHandler implements BasicCallHeaderAuthenticator.CredentialValidator { +public class FlightSqlAuthHandler implements CallHeaderAuthenticator { private static final Logger LOGGER = LoggerFactory.getLogger(FlightSqlAuthHandler.class); + private static final String AUTHORIZATION_HEADER = "authorization"; + private static final String BASIC_PREFIX = "Basic "; + private static final String BEARER_PREFIX = "Bearer "; + private final FlightSqlSessionManager sessionManager; public FlightSqlAuthHandler(FlightSqlSessionManager sessionManager) { @@ -43,17 +48,88 @@ public FlightSqlAuthHandler(FlightSqlSessionManager sessionManager) { } @Override - public org.apache.arrow.flight.auth2.CallHeaderAuthenticator.AuthResult validate( - String username, String password) { - LOGGER.debug("Validating credentials for user: {}", username); - + public AuthResult authenticate(CallHeaders headers) { + Iterable authHeaders; try { - String token = sessionManager.authenticate(username, password, "unknown"); - // Return the token as the peer identity; GeneratedBearerTokenAuthenticator - // wraps it in a Bearer token and sets it in the response header. 
- return () -> token; - } catch (SecurityException e) { - throw CallStatus.UNAUTHENTICATED.withDescription(e.getMessage()).toRuntimeException(); + authHeaders = headers.getAll(AUTHORIZATION_HEADER); + } catch (NullPointerException e) { + throw CallStatus.UNAUTHENTICATED + .withDescription("Missing Authorization header (null header map)") + .toRuntimeException(); + } + + // First pass: check for Bearer token (reuse existing session) + String basicHeader = null; + if (authHeaders == null) { + throw CallStatus.UNAUTHENTICATED + .withDescription("Missing Authorization header") + .toRuntimeException(); + } + for (String authHeader : authHeaders) { + if (authHeader.startsWith(BEARER_PREFIX)) { + String token = authHeader.substring(BEARER_PREFIX.length()); + try { + sessionManager.getSessionByToken(token); + return bearerResult(token); + } catch (SecurityException e) { + // Bearer token invalid/expired, fall through to Basic auth + LOGGER.debug("Bearer token invalid, falling back to Basic auth"); + } + } else if (authHeader.startsWith(BASIC_PREFIX) && basicHeader == null) { + basicHeader = authHeader; + } } + + // Second pass: fall back to Basic auth (create new session) + if (basicHeader != null) { + String encoded = basicHeader.substring(BASIC_PREFIX.length()); + String decoded = new String(Base64.getDecoder().decode(encoded), StandardCharsets.UTF_8); + int colonIdx = decoded.indexOf(':'); + if (colonIdx < 0) { + throw CallStatus.UNAUTHENTICATED + .withDescription("Invalid Basic credentials format") + .toRuntimeException(); + } + String username = decoded.substring(0, colonIdx); + String password = decoded.substring(colonIdx + 1); + + LOGGER.debug("Validating credentials for user: {}", username); + try { + String token = sessionManager.authenticate(username, password, "unknown"); + return bearerResult(token); + } catch (SecurityException e) { + throw CallStatus.UNAUTHENTICATED.withDescription(e.getMessage()).toRuntimeException(); + } + } + + throw 
CallStatus.UNAUTHENTICATED + .withDescription("Missing or unsupported Authorization header") + .toRuntimeException(); + } + + /** + * Creates an AuthResult that sends the Bearer token back in response headers. The client's + * ClientIncomingAuthHeaderMiddleware captures this token for use on subsequent calls. + */ + private static AuthResult bearerResult(String token) { + return new AuthResult() { + @Override + public String getPeerIdentity() { + return token; + } + + @Override + public void appendToOutgoingHeaders(CallHeaders outgoingHeaders) { + if (outgoingHeaders == null) { + return; + } + try { + outgoingHeaders.insert(AUTHORIZATION_HEADER, BEARER_PREFIX + token); + } catch (NullPointerException e) { + // Some CallHeaders implementations have null internal maps for certain RPCs + LOGGER.debug("Could not append Bearer token to outgoing headers", e); + } + } + }; } } diff --git a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthMiddleware.java b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthMiddleware.java deleted file mode 100644 index 2cf0d048556b7..0000000000000 --- a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthMiddleware.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.flight; - -import org.apache.arrow.flight.CallHeaders; -import org.apache.arrow.flight.CallInfo; -import org.apache.arrow.flight.CallStatus; -import org.apache.arrow.flight.FlightServerMiddleware; -import org.apache.arrow.flight.RequestContext; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Flight Server middleware for handling Bearer token / Basic authentication. Supports initial login - * via Basic auth header (username:password), returning a Bearer token. Subsequent requests use the - * Bearer token. - */ -public class FlightSqlAuthMiddleware implements FlightServerMiddleware { - - private static final Logger LOGGER = LoggerFactory.getLogger(FlightSqlAuthMiddleware.class); - - /** The middleware key used to retrieve this middleware in the CallContext. */ - public static final Key KEY = Key.of("flight-sql-auth-middleware"); - - private final CallHeaders incomingHeaders; - - FlightSqlAuthMiddleware(CallHeaders incomingHeaders) { - this.incomingHeaders = incomingHeaders; - } - - /** Returns the incoming call headers for session lookup. 
*/ - public CallHeaders getCallHeaders() { - return incomingHeaders; - } - - @Override - public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { - // no-op: token is set during Handshake response, not here - } - - @Override - public void onCallCompleted(CallStatus status) { - // no-op - } - - @Override - public void onCallErrored(Throwable err) { - // no-op - } - - // ===================== Factory ===================== - - /** Factory that creates FlightSqlAuthMiddleware for each call. */ - public static class Factory implements FlightServerMiddleware.Factory { - - private final FlightSqlSessionManager sessionManager; - - public Factory(FlightSqlSessionManager sessionManager) { - this.sessionManager = sessionManager; - } - - @Override - public FlightSqlAuthMiddleware onCallStarted( - CallInfo callInfo, CallHeaders incomingHeaders, RequestContext context) { - return new FlightSqlAuthMiddleware(incomingHeaders); - } - - public FlightSqlSessionManager getSessionManager() { - return sessionManager; - } - } -} diff --git a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlService.java b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlService.java index 96480449fcc14..9ed91d503a12d 100644 --- a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlService.java +++ b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlService.java @@ -24,8 +24,6 @@ import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; -import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.slf4j.Logger; @@ -46,89 +44,122 @@ public class FlightSqlService implements IExternalService { private static final Logger LOGGER = 
LoggerFactory.getLogger(FlightSqlService.class); private static final long SESSION_TIMEOUT_MINUTES = 30; + private final Object lifecycleLock = new Object(); private FlightServer flightServer; private BufferAllocator allocator; private FlightSqlSessionManager flightSessionManager; private IoTDBFlightSqlProducer producer; @Override - public void start() { - int port = IoTDBDescriptor.getInstance().getConfig().getArrowFlightSqlPort(); - LOGGER.info("Starting Arrow Flight SQL service on port {}", port); - - try { - // Create the root allocator for Arrow memory management - allocator = new RootAllocator(Long.MAX_VALUE); - - // Create session manager with TTL - flightSessionManager = new FlightSqlSessionManager(SESSION_TIMEOUT_MINUTES); - - // Create the auth handler - FlightSqlAuthHandler authHandler = new FlightSqlAuthHandler(flightSessionManager); - - // Create the Flight SQL producer - producer = new IoTDBFlightSqlProducer(allocator, flightSessionManager); - - // Build the Flight server with auth2 Bearer token authentication - Location location = Location.forGrpcInsecure("0.0.0.0", port); - flightServer = - FlightServer.builder(allocator, location, producer) - .headerAuthenticator( - new GeneratedBearerTokenAuthenticator( - new BasicCallHeaderAuthenticator(authHandler))) - .build(); - - flightServer.start(); - LOGGER.info( - "Arrow Flight SQL service started successfully on port {}", flightServer.getPort()); - } catch (IOException e) { - LOGGER.error("Failed to start Arrow Flight SQL service", e); - stop(); - throw new RuntimeException("Failed to start Arrow Flight SQL service", e); + public synchronized void start() { + synchronized (lifecycleLock) { + if (flightServer != null) { + LOGGER.warn("Arrow Flight SQL service already started"); + return; + } + + int port = IoTDBDescriptor.getInstance().getConfig().getArrowFlightSqlPort(); + LOGGER.info("Starting Arrow Flight SQL service on port {}", port); + + try { + // Create the root allocator for Arrow memory 
management with memory limit + long maxMemory = Runtime.getRuntime().maxMemory(); + long allocatorLimit = + Math.min( + IoTDBDescriptor.getInstance().getConfig().getArrowFlightSqlMaxAllocatorMemory(), + maxMemory / 4); + allocator = new RootAllocator(allocatorLimit); + LOGGER.info( + "Arrow allocator initialized with limit: {} bytes ({} MB)", + allocatorLimit, + allocatorLimit / (1024 * 1024)); + + Location location = Location.forGrpcInsecure("0.0.0.0", port); + + // Create session manager with TTL + flightSessionManager = new FlightSqlSessionManager(SESSION_TIMEOUT_MINUTES); + FlightSqlAuthHandler authHandler = new FlightSqlAuthHandler(flightSessionManager); + + // Create the Flight SQL producer + producer = new IoTDBFlightSqlProducer(allocator, flightSessionManager); + + flightServer = + FlightServer.builder(allocator, location, producer) + .headerAuthenticator(authHandler) + // Configure Netty server for DataNode JVM environment: + // - directExecutor: run gRPC handlers in the Netty event loop thread to + // avoid thread scheduling issues with the default executor + // - flowControlWindow: explicit HTTP/2 flow control prevents framing errors + // when standalone Netty JARs coexist on the classpath + .transportHint( + "grpc.builderConsumer", + (java.util.function.Consumer) + nsb -> { + nsb.directExecutor(); + nsb.initialFlowControlWindow(1048576); + nsb.flowControlWindow(1048576); + }) + .build(); + + flightServer.start(); + LOGGER.info( + "Arrow Flight SQL service started successfully on port {}", flightServer.getPort()); + } catch (IOException e) { + LOGGER.error("Failed to start Arrow Flight SQL service", e); + stop(); + throw new RuntimeException("Failed to start Arrow Flight SQL service", e); + } } } @Override - public void stop() { - LOGGER.info("Stopping Arrow Flight SQL service"); + public synchronized void stop() { + synchronized (lifecycleLock) { + if (flightServer == null) { + LOGGER.warn("Arrow Flight SQL service not started"); + return; + } - if 
(flightServer != null) { - try { - flightServer.shutdown(); - flightServer.awaitTermination(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - LOGGER.warn("Interrupted while waiting for Flight server shutdown", e); - Thread.currentThread().interrupt(); + LOGGER.info("Stopping Arrow Flight SQL service"); + + if (flightServer != null) { try { - flightServer.close(); - } catch (Exception ex) { - LOGGER.warn("Error force-closing Flight server", ex); + flightServer.shutdown(); + flightServer.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + LOGGER.warn("Interrupted while waiting for Flight server shutdown", e); + Thread.currentThread().interrupt(); + try { + flightServer.close(); + } catch (Exception ex) { + LOGGER.warn("Error force-closing Flight server", ex); + } + } catch (Exception e) { + LOGGER.warn("Error shutting down Flight server", e); } - } catch (Exception e) { - LOGGER.warn("Error shutting down Flight server", e); + flightServer = null; } - flightServer = null; - } - if (producer != null) { - try { - producer.close(); - } catch (Exception e) { - LOGGER.warn("Error closing Flight SQL producer", e); + if (producer != null) { + try { + producer.close(); + } catch (Exception e) { + LOGGER.warn("Error closing Flight SQL producer", e); + } + producer = null; } - producer = null; - } - if (flightSessionManager != null) { - flightSessionManager.close(); - flightSessionManager = null; - } + if (flightSessionManager != null) { + flightSessionManager.close(); + flightSessionManager = null; + } - if (allocator != null) { - allocator.close(); - allocator = null; - } + if (allocator != null) { + allocator.close(); + allocator = null; + } - LOGGER.info("Arrow Flight SQL service stopped"); + LOGGER.info("Arrow Flight SQL service stopped"); + } } } diff --git a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlSessionManager.java 
b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlSessionManager.java index ad829888f2203..0178dc2f1fe12 100644 --- a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlSessionManager.java +++ b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlSessionManager.java @@ -20,9 +20,11 @@ package org.apache.iotdb.flight; import org.apache.iotdb.commons.conf.IoTDBConstant; +import org.apache.iotdb.db.auth.AuthorityChecker; import org.apache.iotdb.db.protocol.session.IClientSession; import org.apache.iotdb.db.protocol.session.InternalClientSession; import org.apache.iotdb.db.protocol.session.SessionManager; +import org.apache.iotdb.rpc.TSStatusCode; import com.github.benmanes.caffeine.cache.Cache; import com.github.benmanes.caffeine.cache.Caffeine; @@ -31,7 +33,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.UUID; +import java.security.SecureRandom; +import java.time.ZoneId; +import java.util.Base64; import java.util.concurrent.TimeUnit; /** @@ -43,12 +47,16 @@ public class FlightSqlSessionManager { private static final Logger LOGGER = LoggerFactory.getLogger(FlightSqlSessionManager.class); private static final String AUTHORIZATION_HEADER = "authorization"; private static final String BEARER_PREFIX = "Bearer "; + private static final SecureRandom SECURE_RANDOM = new SecureRandom(); private final SessionManager sessionManager = SessionManager.getInstance(); /** Cache of Bearer token -> IClientSession with configurable TTL. */ private final Cache tokenCache; + /** Cache of username -> Bearer token for session reuse with Basic auth on every call. 
*/ + private final Cache userTokenCache; + public FlightSqlSessionManager(long sessionTimeoutMinutes) { this.tokenCache = Caffeine.newBuilder() @@ -56,17 +64,22 @@ public FlightSqlSessionManager(long sessionTimeoutMinutes) { .removalListener( (String token, IClientSession session, RemovalCause cause) -> { if (session != null && cause != RemovalCause.REPLACED) { - LOGGER.info("Flight SQL session expired, closing: {}", session); - sessionManager.closeSession( - session, - queryId -> - org.apache.iotdb.db.queryengine.plan.Coordinator.getInstance() - .cleanupQueryExecution(queryId), - false); - sessionManager.removeCurrSessionForMqtt(null); // handled via sessions map only + LOGGER.info("Flight SQL session expired: {}, cause: {}", session, cause); + try { + sessionManager.closeSession( + session, + queryId -> + org.apache.iotdb.db.queryengine.plan.Coordinator.getInstance() + .cleanupQueryExecution(queryId), + false); + } catch (Exception e) { + LOGGER.error("Error closing expired session", e); + } } }) .build(); + this.userTokenCache = + Caffeine.newBuilder().expireAfterAccess(sessionTimeoutMinutes, TimeUnit.MINUTES).build(); } /** @@ -79,34 +92,42 @@ public FlightSqlSessionManager(long sessionTimeoutMinutes) { * @throws SecurityException if authentication fails */ public String authenticate(String username, String password, String clientAddress) { - // Create a session for this client + // Check if this user already has an active session (reuse it) + String existingToken = userTokenCache.getIfPresent(username); + if (existingToken != null && tokenCache.getIfPresent(existingToken) != null) { + return existingToken; + } + + // Verify credentials (REST pattern) + try { + org.apache.iotdb.common.rpc.thrift.TSStatus status = + AuthorityChecker.checkUser(username, password); + if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + LOGGER.warn("Authentication failed for client: {}", clientAddress); + throw new SecurityException("Authentication failed: wrong 
username or password"); + } + } catch (SecurityException e) { + throw e; + } catch (Exception e) { + throw new SecurityException("Authentication failed", e); + } + + // Create and register session (REST pattern) IClientSession session = new InternalClientSession("FlightSQL-" + clientAddress); session.setSqlDialect(IClientSession.SqlDialect.TABLE); + sessionManager.registerSession(session); - // Register the session before login (MQTT pattern) - sessionManager.registerSessionForMqtt(session); - - // Use SessionManager's login method - org.apache.iotdb.db.protocol.basic.BasicOpenSessionResp loginResp = - sessionManager.login( - session, - username, - password, - java.time.ZoneId.systemDefault().getId(), - SessionManager.CURRENT_RPC_VERSION, - IoTDBConstant.ClientVersion.V_1_0, - IClientSession.SqlDialect.TABLE); - - if (loginResp.getCode() != org.apache.iotdb.rpc.TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - // Remove the session if login failed - sessionManager.removeCurrSessionForMqtt(null); - throw new SecurityException("Authentication failed: " + loginResp.getMessage()); - } + long userId = AuthorityChecker.getUserId(username).orElse(-1L); + sessionManager.supplySession( + session, userId, username, ZoneId.systemDefault(), IoTDBConstant.ClientVersion.V_1_0); - // Generate Bearer token and store in cache - String token = UUID.randomUUID().toString(); + // Generate cryptographically secure Bearer token (32 bytes = 256 bits) + byte[] tokenBytes = new byte[32]; + SECURE_RANDOM.nextBytes(tokenBytes); + String token = Base64.getUrlEncoder().withoutPadding().encodeToString(tokenBytes); tokenCache.put(token, session); - LOGGER.info("Flight SQL user '{}' authenticated, session: {}", username, session); + userTokenCache.put(username, token); + LOGGER.info("Flight SQL authentication successful for client: {}", clientAddress); return token; } diff --git a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/IoTDBFlightSqlProducer.java 
b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/IoTDBFlightSqlProducer.java index 067ccfb20d4b0..1793a64abecd9 100644 --- a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/IoTDBFlightSqlProducer.java +++ b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/IoTDBFlightSqlProducer.java @@ -34,6 +34,10 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.parser.SqlParser; import org.apache.iotdb.rpc.TSStatusCode; +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import com.github.benmanes.caffeine.cache.RemovalCause; +import com.google.protobuf.Any; import com.google.protobuf.ByteString; import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.Criteria; @@ -54,11 +58,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.nio.charset.StandardCharsets; import java.time.ZoneId; import java.util.Collections; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; /** * Apache Arrow Flight SQL producer implementation for IoTDB. Handles SQL query execution via the @@ -76,8 +79,25 @@ public class IoTDBFlightSqlProducer implements FlightSqlProducer { private final SqlParser sqlParser = new SqlParser(); private final Metadata metadata = LocalExecutionPlanner.getInstance().metadata; - /** Stores query execution context by queryId for streaming results via getStream. */ - private final ConcurrentHashMap activeQueries = new ConcurrentHashMap<>(); + /** + * Stores query execution context by queryId for streaming results via getStream. Uses Caffeine + * cache with TTL to prevent resource leaks when clients don't call getStream. 
+ */ + private final Cache activeQueries = + Caffeine.newBuilder() + .expireAfterWrite(CONFIG.getQueryTimeoutThreshold(), TimeUnit.MILLISECONDS) + .removalListener( + (Long queryId, QueryContext ctx, RemovalCause cause) -> { + if (ctx != null && cause != RemovalCause.EXPLICIT) { + LOGGER.warn("Query {} evicted due to {}, cleaning up", queryId, cause); + try { + coordinator.cleanupQueryExecution(queryId); + } catch (Exception e) { + LOGGER.error("Error cleaning up evicted query {}", queryId, e); + } + } + }) + .build(); public IoTDBFlightSqlProducer( BufferAllocator allocator, FlightSqlSessionManager flightSessionManager) { @@ -105,14 +125,31 @@ private IClientSession getSessionFromContext(CallContext context) { // ===================== SQL Query Execution ===================== + private static final int MAX_QUERY_LENGTH = 100_000; // 100KB + @Override public FlightInfo getFlightInfoStatement( FlightSql.CommandStatementQuery command, CallContext context, FlightDescriptor descriptor) { String sql = command.getQuery(); - LOGGER.debug("getFlightInfoStatement: {}", sql); + + // Validate query length + if (sql == null || sql.trim().isEmpty()) { + throw CallStatus.INVALID_ARGUMENT.withDescription("Empty SQL query").toRuntimeException(); + } + if (sql.length() > MAX_QUERY_LENGTH) { + throw CallStatus.INVALID_ARGUMENT + .withDescription("Query exceeds maximum length of " + MAX_QUERY_LENGTH + " characters") + .toRuntimeException(); + } IClientSession session = getSessionFromContext(context); + // Log query for audit (truncate if too long) + LOGGER.info( + "Executing query for user {}: {}", + session.getUsername(), + sql.substring(0, Math.min(sql.length(), 200))); + Long queryId = null; try { queryId = sessionManager.requestQueryId(); @@ -142,9 +179,10 @@ public FlightInfo getFlightInfoStatement( IQueryExecution queryExecution = coordinator.getQueryExecution(queryId); if (queryExecution == null) { - throw CallStatus.INTERNAL - .withDescription("Query execution not found after 
execution") - .toRuntimeException(); + // Non-query statements (USE, CREATE, INSERT, etc.) don't produce a query execution. + // Return an empty FlightInfo with no endpoints. + return new FlightInfo( + new Schema(Collections.emptyList()), descriptor, Collections.emptyList(), 0, 0); } DatasetHeader header = queryExecution.getDatasetHeader(); @@ -153,20 +191,28 @@ public FlightInfo getFlightInfoStatement( // Store the query context for later getStream calls activeQueries.put(queryId, new QueryContext(queryExecution, header, session)); - // Build ticket containing the queryId - byte[] ticketBytes = Long.toString(queryId).getBytes(StandardCharsets.UTF_8); - Ticket ticket = new Ticket(ticketBytes); + // Build ticket as a serialized TicketStatementQuery protobuf. + // The FlightSqlProducer base class's getStream() unpacks tickets as Any + // and dispatches to getStreamStatement(). + ByteString handle = ByteString.copyFromUtf8(Long.toString(queryId)); + FlightSql.TicketStatementQuery ticketQuery = + FlightSql.TicketStatementQuery.newBuilder().setStatementHandle(handle).build(); + Ticket ticket = new Ticket(Any.pack(ticketQuery).toByteArray()); FlightEndpoint endpoint = new FlightEndpoint(ticket); return new FlightInfo(arrowSchema, descriptor, Collections.singletonList(endpoint), -1, -1); - } catch (RuntimeException e) { + } catch (Exception e) { // Cleanup on error + LOGGER.error("Error executing query: {}", sql, e); if (queryId != null) { coordinator.cleanupQueryExecution(queryId); - activeQueries.remove(queryId); + activeQueries.invalidate(queryId); } - throw e; + if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } + throw CallStatus.INTERNAL.withDescription(e.getMessage()).toRuntimeException(); } } @@ -184,17 +230,26 @@ public void getStreamStatement( CallContext context, ServerStreamListener listener) { ByteString handle = ticketQuery.getStatementHandle(); - long queryId = Long.parseLong(handle.toStringUtf8()); + long queryId; + try { + queryId = 
Long.parseLong(handle.toStringUtf8()); + } catch (NumberFormatException e) { + listener.error( + CallStatus.INVALID_ARGUMENT + .withDescription("Invalid statement handle: " + handle.toStringUtf8()) + .toRuntimeException()); + return; + } streamQueryResults(queryId, listener); } /** Streams query results for a given queryId as Arrow VectorSchemaRoot batches. */ private void streamQueryResults(long queryId, ServerStreamListener listener) { - QueryContext ctx = activeQueries.get(queryId); + QueryContext ctx = activeQueries.getIfPresent(queryId); if (ctx == null) { listener.error( CallStatus.NOT_FOUND - .withDescription("Query not found for id: " + queryId) + .withDescription("Query not found or expired: " + queryId) .toRuntimeException()); return; } @@ -220,9 +275,15 @@ private void streamQueryResults(long queryId, ServerStreamListener listener) { } catch (IoTDBException e) { LOGGER.error("Error streaming query results for queryId={}", queryId, e); listener.error(CallStatus.INTERNAL.withDescription(e.getMessage()).toRuntimeException()); + } catch (Exception e) { + LOGGER.error("Unexpected error streaming query results for queryId={}", queryId, e); + listener.error( + CallStatus.INTERNAL + .withDescription("Internal error: " + e.getMessage()) + .toRuntimeException()); } finally { coordinator.cleanupQueryExecution(queryId); - activeQueries.remove(queryId); + activeQueries.invalidate(queryId); if (root != null) { root.close(); } @@ -503,14 +564,14 @@ public void getStreamCrossReference( @Override public void close() throws Exception { // Clean up all active queries - for (Long queryId : activeQueries.keySet()) { + for (Long queryId : activeQueries.asMap().keySet()) { try { coordinator.cleanupQueryExecution(queryId); } catch (Exception e) { LOGGER.warn("Error cleaning up query {} during shutdown", queryId, e); } } - activeQueries.clear(); + activeQueries.invalidateAll(); } // ===================== Inner Classes ===================== diff --git 
a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/TsBlockToArrowConverter.java b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/TsBlockToArrowConverter.java index 396f230d0b91c..ca5532f13f1ca 100644 --- a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/TsBlockToArrowConverter.java +++ b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/TsBlockToArrowConverter.java @@ -122,7 +122,10 @@ public static void fillVectorSchemaRoot( for (int colIdx = 0; colIdx < columnNames.size(); colIdx++) { String colName = columnNames.get(colIdx); - Integer sourceIdx = headerMap.get(colName); + int sourceIdx = + (headerMap != null && headerMap.containsKey(colName)) + ? headerMap.get(colName) + : colIdx; Column column = tsBlock.getColumn(sourceIdx); TSDataType dataType = dataTypes.get(colIdx); FieldVector fieldVector = root.getVector(colIdx); diff --git a/integration-test/pom.xml b/integration-test/pom.xml index 4ef730957e4cd..06382ea991321 100644 --- a/integration-test/pom.xml +++ b/integration-test/pom.xml @@ -246,6 +246,13 @@ provided + + org.apache.iotdb + flight-sql + 2.0.7-SNAPSHOT + + provided + org.apache.arrow flight-sql diff --git a/integration-test/src/assembly/mpp-share.xml b/integration-test/src/assembly/mpp-share.xml index 70072e8282ec6..74de1ae8b2377 100644 --- a/integration-test/src/assembly/mpp-share.xml +++ b/integration-test/src/assembly/mpp-share.xml @@ -39,5 +39,9 @@ ${project.basedir}/../external-service-impl/rest/target/rest-${project.version}-jar-with-dependencies.jar lib + + ${project.basedir}/../external-service-impl/flight-sql/target/flight-sql-${project.version}-jar-with-dependencies.jar + lib + diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/flightsql/IoTDBArrowFlightSqlIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/flightsql/IoTDBArrowFlightSqlIT.java index d931ca7799944..94d531c4d3d85 100644 --- 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/flightsql/IoTDBArrowFlightSqlIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/flightsql/IoTDBArrowFlightSqlIT.java @@ -62,6 +62,7 @@ public class IoTDBArrowFlightSqlIT { private static final String DATABASE = "flightsql_test_db"; + private static final String TABLE = DATABASE + ".test_table"; private static final String USER = "root"; private static final String PASSWORD = "root"; @@ -77,27 +78,17 @@ public void setUp() throws Exception { baseEnv.getConfig().getDataNodeConfig().setEnableArrowFlightSqlService(true); baseEnv.initClusterEnvironment(); - // Get the Flight SQL port from the data node int port = EnvFactory.getEnv().getArrowFlightSqlPort(); - - // Create Arrow allocator and Flight client with Bearer token auth middleware allocator = new RootAllocator(Long.MAX_VALUE); Location location = Location.forGrpcInsecure("127.0.0.1", port); - // The ClientIncomingAuthHeaderMiddleware captures the Bearer token from the - // auth handshake ClientIncomingAuthHeaderMiddleware.Factory authFactory = new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); - flightClient = FlightClient.builder(allocator, location).intercept(authFactory).build(); - - // Authenticate: sends Basic credentials, server returns Bearer token - bearerToken = new CredentialCallOption(new BasicAuthCredentialWriter(USER, PASSWORD)); - - // Wrap in FlightSqlClient for Flight SQL protocol operations flightSqlClient = new FlightSqlClient(flightClient); + bearerToken = new CredentialCallOption(new BasicAuthCredentialWriter(USER, PASSWORD)); - // Use the standard session to create the test database and table with data + // Create test data via native session (not Flight SQL) try (ITableSession session = EnvFactory.getEnv().getTableSessionConnection()) { session.executeNonQueryStatement("CREATE DATABASE IF NOT EXISTS " + DATABASE); } @@ -145,11 +136,31 @@ public void tearDown() throws 
Exception { EnvFactory.getEnv().cleanClusterEnvironment(); } + @Test + public void testShowDatabases() throws Exception { + FlightInfo flightInfo = flightSqlClient.execute("SHOW DATABASES", bearerToken); + + List> rows = fetchAllRows(flightInfo); + assertTrue("Should have at least 1 database", rows.size() >= 1); + + boolean found = false; + for (List row : rows) { + for (String val : row) { + if (val.contains(DATABASE)) { + found = true; + break; + } + } + } + assertTrue("Should find test database " + DATABASE, found); + } + @Test public void testQueryWithAllDataTypes() throws Exception { FlightInfo flightInfo = flightSqlClient.execute( - "SELECT time, id1, s1, s2, s3, s4, s5, s6 FROM test_table ORDER BY time", bearerToken); + "SELECT time, id1, s1, s2, s3, s4, s5, s6 FROM " + TABLE + " ORDER BY time", + bearerToken); // Validate schema Schema schema = flightInfo.getSchema(); @@ -166,7 +177,8 @@ public void testQueryWithAllDataTypes() throws Exception { public void testQueryWithFilter() throws Exception { FlightInfo flightInfo = flightSqlClient.execute( - "SELECT id1, s1 FROM test_table WHERE id1 = 'device1' ORDER BY time", bearerToken); + "SELECT id1, s1 FROM " + TABLE + " WHERE id1 = 'device1' ORDER BY time", + bearerToken); List> rows = fetchAllRows(flightInfo); assertEquals("Should have 2 rows for device1", 2, rows.size()); @@ -177,7 +189,7 @@ public void testQueryWithAggregation() throws Exception { FlightInfo flightInfo = flightSqlClient.execute( "SELECT id1, COUNT(*) as cnt, SUM(s1) as s1_sum " - + "FROM test_table GROUP BY id1 ORDER BY id1", + + "FROM " + TABLE + " GROUP BY id1 ORDER BY id1", bearerToken); List> rows = fetchAllRows(flightInfo); @@ -187,34 +199,18 @@ public void testQueryWithAggregation() throws Exception { @Test public void testEmptyResult() throws Exception { FlightInfo flightInfo = - flightSqlClient.execute("SELECT * FROM test_table WHERE id1 = 'nonexistent'", bearerToken); + flightSqlClient.execute( + "SELECT * FROM " + TABLE + " WHERE 
id1 = 'nonexistent'", bearerToken); List> rows = fetchAllRows(flightInfo); assertEquals("Should have 0 rows", 0, rows.size()); } - @Test - public void testShowDatabases() throws Exception { - FlightInfo flightInfo = flightSqlClient.execute("SHOW DATABASES", bearerToken); - - List> rows = fetchAllRows(flightInfo); - assertTrue("Should have at least 1 database", rows.size() >= 1); - - boolean found = false; - for (List row : rows) { - for (String val : row) { - if (val.contains(DATABASE)) { - found = true; - break; - } - } - } - assertTrue("Should find test database " + DATABASE, found); - } + // ===================== Helper Methods ===================== /** - * Fetches all rows from all endpoints in a FlightInfo. Each row is a list of string - * representations of the column values. + * Fetches all rows from all endpoints in a FlightInfo using the shared client. Each row is a list + * of string representations of the column values. */ private List> fetchAllRows(FlightInfo flightInfo) throws Exception { List> rows = new ArrayList<>(); diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBServicesIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBServicesIT.java index 014586c5e1206..c881479e83c0d 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBServicesIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBServicesIT.java @@ -67,7 +67,7 @@ public static void tearDown() throws Exception { public void testQueryResult() { String[] retArray = new String[] { - "MQTT,1,STOPPED,", "REST,1,STOPPED,", + "FLIGHT_SQL,1,STOPPED,", "MQTT,1,STOPPED,", "REST,1,STOPPED,", }; // TableModel diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java 
index 3fef5d79fe11a..e63fb54f66b9a 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java @@ -112,6 +112,9 @@ public class IoTDBConfig { /** The Arrow Flight SQL service binding port. */ private int arrowFlightSqlPort = 8904; + /** The Arrow Flight SQL max allocator memory in bytes (default: 4GB). */ + private long arrowFlightSqlMaxAllocatorMemory = 4L * 1024 * 1024 * 1024; + /** The mqtt service binding host. */ private String mqttHost = "127.0.0.1"; @@ -2564,6 +2567,14 @@ public void setArrowFlightSqlPort(int arrowFlightSqlPort) { this.arrowFlightSqlPort = arrowFlightSqlPort; } + public long getArrowFlightSqlMaxAllocatorMemory() { + return arrowFlightSqlMaxAllocatorMemory; + } + + public void setArrowFlightSqlMaxAllocatorMemory(long arrowFlightSqlMaxAllocatorMemory) { + this.arrowFlightSqlMaxAllocatorMemory = arrowFlightSqlMaxAllocatorMemory; + } + public String getMqttHost() { return mqttHost; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java index 6730138b2af5c..3edd09100498a 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java @@ -880,6 +880,9 @@ public void loadProperties(TrimProperties properties) throws BadNodeUrlException // mqtt loadMqttProps(properties); + // Arrow Flight SQL + loadArrowFlightSqlProps(properties); + conf.setIntoOperationBufferSizeInByte( Long.parseLong( properties.getProperty( @@ -1942,6 +1945,24 @@ private void loadMqttProps(TrimProperties properties) { } } + // Arrow Flight SQL related + private void loadArrowFlightSqlProps(TrimProperties properties) { + if (properties.getProperty("enable_arrow_flight_sql_service") != null) { + conf.setEnableArrowFlightSqlService( + 
Boolean.parseBoolean(properties.getProperty("enable_arrow_flight_sql_service").trim())); + } + + if (properties.getProperty("arrow_flight_sql_port") != null) { + conf.setArrowFlightSqlPort( + Integer.parseInt(properties.getProperty("arrow_flight_sql_port").trim())); + } + + if (properties.getProperty("arrow_flight_sql_max_allocator_memory") != null) { + conf.setArrowFlightSqlMaxAllocatorMemory( + Long.parseLong(properties.getProperty("arrow_flight_sql_max_allocator_memory").trim())); + } + } + // timed flush memtable private void loadTimedService(TrimProperties properties) throws IOException { conf.setEnableTimedFlushSeqMemtable( diff --git a/pom.xml b/pom.xml index 9312be038d3bf..aba63e83d6833 100644 --- a/pom.xml +++ b/pom.xml @@ -769,6 +769,11 @@ adbc-driver-flight-sql 0.22.0 + + org.apache.arrow + arrow-memory-unsafe + ${arrow.version} + From ef53a27af18f139291f6959f3025294ba5875600 Mon Sep 17 00:00:00 2001 From: CritasWang Date: Tue, 24 Feb 2026 17:36:01 +0800 Subject: [PATCH 2/2] feat(flight-sql): add per-client session isolation and security hardening - Add x-flight-sql-client-id header support for per-client USE database isolation via FlightSqlAuthHandler and ClientIdMiddlewareFactory - Use \0 (null byte) delimiter in clientSessionCache key to prevent username/clientId collision attacks - Validate clientId: alphanumeric + dash only, max 64 chars, fail-closed for non-empty invalid values (SecurityException) - Add maximumSize(1000) to tokenCache and clientSessionCache to prevent resource exhaustion from arbitrary clientIds - Remove LoginLockManager (userId=-1L caused cross-user lock collision; getUserId() is blocking RPC incompatible with directExecutor()) - Remove unused flightClient field from IT - Add directExecutor() + HTTP/2 flow control window tuning (1MB) on NettyServerBuilder to fix end-of-stream mid-frame errors - Document all functional gaps vs SessionManager.login() (password expiration, login lock, checkUser cache-miss risk) Tests (9/9 pass): - 5 
original Flight SQL query tests - testUseDbSessionPersistence: USE context persists across connections - testUseDbWithFullyQualifiedFallback: USE + qualified/unqualified queries - testUseDbIsolationAcrossClients: Client B fails without USE context - testInvalidClientIdRejected: non-empty invalid clientId rejected --- .../iotdb/flight/FlightSqlAuthHandler.java | 5 +- .../apache/iotdb/flight/FlightSqlService.java | 8 +- .../iotdb/flight/FlightSqlSessionManager.java | 126 +++++++++--- .../iotdb/flight/TsBlockToArrowConverter.java | 4 +- .../it/flightsql/IoTDBArrowFlightSqlIT.java | 185 ++++++++++++++++-- 5 files changed, 274 insertions(+), 54 deletions(-) diff --git a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthHandler.java b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthHandler.java index c22d935d41aff..66da06f00d22e 100644 --- a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthHandler.java +++ b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlAuthHandler.java @@ -93,9 +93,10 @@ public AuthResult authenticate(CallHeaders headers) { String username = decoded.substring(0, colonIdx); String password = decoded.substring(colonIdx + 1); - LOGGER.debug("Validating credentials for user: {}", username); + String clientId = headers.get("x-flight-sql-client-id"); + LOGGER.debug("Validating credentials for user: {}, clientId: {}", username, clientId); try { - String token = sessionManager.authenticate(username, password, "unknown"); + String token = sessionManager.authenticate(username, password, "unknown", clientId); return bearerResult(token); } catch (SecurityException e) { throw CallStatus.UNAUTHENTICATED.withDescription(e.getMessage()).toRuntimeException(); diff --git a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlService.java 
b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlService.java index 9ed91d503a12d..c775b84f73c67 100644 --- a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlService.java +++ b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlService.java @@ -86,11 +86,9 @@ public synchronized void start() { flightServer = FlightServer.builder(allocator, location, producer) .headerAuthenticator(authHandler) - // Configure Netty server for DataNode JVM environment: - // - directExecutor: run gRPC handlers in the Netty event loop thread to - // avoid thread scheduling issues with the default executor - // - flowControlWindow: explicit HTTP/2 flow control prevents framing errors - // when standalone Netty JARs coexist on the classpath + // directExecutor: run gRPC handlers in the Netty event loop thread to + // avoid thread scheduling issues with the default executor that cause + // "end-of-stream mid-frame" errors on subsequent RPCs. 
.transportHint( "grpc.builderConsumer", (java.util.function.Consumer) diff --git a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlSessionManager.java b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlSessionManager.java index 0178dc2f1fe12..825c930f1b765 100644 --- a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlSessionManager.java +++ b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/FlightSqlSessionManager.java @@ -34,7 +34,6 @@ import org.slf4j.LoggerFactory; import java.security.SecureRandom; -import java.time.ZoneId; import java.util.Base64; import java.util.concurrent.TimeUnit; @@ -48,18 +47,32 @@ public class FlightSqlSessionManager { private static final String AUTHORIZATION_HEADER = "authorization"; private static final String BEARER_PREFIX = "Bearer "; private static final SecureRandom SECURE_RANDOM = new SecureRandom(); + private static final int MAX_CLIENT_ID_LENGTH = 64; + private static final int MAX_SESSIONS = 1000; + private static final java.util.regex.Pattern CLIENT_ID_PATTERN = + java.util.regex.Pattern.compile("^[a-zA-Z0-9\\-]+$"); private final SessionManager sessionManager = SessionManager.getInstance(); /** Cache of Bearer token -> IClientSession with configurable TTL. */ private final Cache tokenCache; - /** Cache of username -> Bearer token for session reuse with Basic auth on every call. */ - private final Cache userTokenCache; + /** + * Cache of (username@clientId) -> Bearer token for session reuse. Avoids repeated session + * creation on every RPC — necessary because the Arrow Flight client middleware does not always + * cache the Bearer token, causing Basic auth to be re-sent on every call. + * + *

Keyed by {@code username@clientId} where clientId comes from the {@code + * x-flight-sql-client-id} header. This ensures different logical clients (even with the same + * username) get independent sessions with separate USE database contexts. If no clientId header + * is present, falls back to username-only keying (shared session). + */ + private final Cache clientSessionCache; public FlightSqlSessionManager(long sessionTimeoutMinutes) { this.tokenCache = Caffeine.newBuilder() + .maximumSize(MAX_SESSIONS) .expireAfterAccess(sessionTimeoutMinutes, TimeUnit.MINUTES) .removalListener( (String token, IClientSession session, RemovalCause cause) -> { @@ -78,8 +91,11 @@ public FlightSqlSessionManager(long sessionTimeoutMinutes) { } }) .build(); - this.userTokenCache = - Caffeine.newBuilder().expireAfterAccess(sessionTimeoutMinutes, TimeUnit.MINUTES).build(); + this.clientSessionCache = + Caffeine.newBuilder() + .maximumSize(MAX_SESSIONS) + .expireAfterAccess(sessionTimeoutMinutes, TimeUnit.MINUTES) + .build(); } /** @@ -87,46 +103,77 @@ public FlightSqlSessionManager(long sessionTimeoutMinutes) { * * @param username the username * @param password the password - * @param clientAddress the client's IP address + * @param clientAddress the client's IP address (for logging) + * @param clientId optional client identifier from x-flight-sql-client-id header (may be null) * @return the Bearer token if authentication succeeds * @throws SecurityException if authentication fails */ - public String authenticate(String username, String password, String clientAddress) { - // Check if this user already has an active session (reuse it) - String existingToken = userTokenCache.getIfPresent(username); - if (existingToken != null && tokenCache.getIfPresent(existingToken) != null) { - return existingToken; - } - - // Verify credentials (REST pattern) + public String authenticate( + String username, String password, String clientAddress, String clientId) { + // NOTE: We intentionally do NOT 
call SessionManager.login() here because it performs + // blocking I/O that is incompatible with directExecutor() on the Netty event loop: + // - DataNodeAuthUtils.checkPasswordExpiration: executes SELECT via Coordinator + // - AuthorityChecker.getUserId: sends RPC to ConfigNode on cache miss + // Blocking the event loop corrupts HTTP/2 connection state and causes "end-of-stream + // mid-frame" errors on subsequent RPCs. + // + // Functional gaps vs SessionManager.login(): + // - Password expiration checks (requires Coordinator query) + // - Login lock / brute-force protection (LoginLockManager is in-memory but keys + // by userId; AuthorityChecker.getUserId() is a blocking RPC, so we cannot obtain + // a correct userId without risking event loop stalls) + // + // Risk: AuthorityChecker.checkUser() may perform a one-time blocking RPC to ConfigNode + // on cache miss (ClusterAuthorityFetcher.login). After the first successful auth, the + // credential is cached locally, and clientSessionCache avoids repeated authenticate() + // calls for the same client. + // + // TODO: Support password expiration and login lock. This requires either: + // (a) async auth support in Arrow Flight (not yet available), or + // (b) resolving the Netty classpath conflict so directExecutor() is no longer needed. + + // Always verify credentials — never skip password verification even if a cached + // session exists for this client. 
+ org.apache.iotdb.common.rpc.thrift.TSStatus status; try { - org.apache.iotdb.common.rpc.thrift.TSStatus status = - AuthorityChecker.checkUser(username, password); - if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - LOGGER.warn("Authentication failed for client: {}", clientAddress); - throw new SecurityException("Authentication failed: wrong username or password"); - } - } catch (SecurityException e) { - throw e; + status = AuthorityChecker.checkUser(username, password); } catch (Exception e) { throw new SecurityException("Authentication failed", e); } + if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + LOGGER.warn("Authentication failed for client: {}", clientAddress); + throw new SecurityException("Authentication failed: wrong username or password"); + } + + // Reuse existing session for this client. + // Key uses \0 (null byte) delimiter — cannot appear in usernames or HTTP headers, + // so the mapping (username, clientId) -> cacheKey is injective (no collisions). + String validClientId = validateClientId(clientId); + String cacheKey = validClientId != null ? username + "\0" + validClientId : username; + String existingToken = clientSessionCache.getIfPresent(cacheKey); + if (existingToken != null && tokenCache.getIfPresent(existingToken) != null) { + return existingToken; + } - // Create and register session (REST pattern) + // Create session. Do NOT call registerSession() — it sets a ThreadLocal (currSession) + // designed for the client-thread model (Thrift). gRPC with directExecutor() runs all + // handlers on the Netty event loop, so ThreadLocal-based session tracking would pollute. IClientSession session = new InternalClientSession("FlightSQL-" + clientAddress); session.setSqlDialect(IClientSession.SqlDialect.TABLE); - sessionManager.registerSession(session); - - long userId = AuthorityChecker.getUserId(username).orElse(-1L); + // Pass -1L for userId — getUserId() sends blocking RPC to ConfigNode. 
sessionManager.supplySession( - session, userId, username, ZoneId.systemDefault(), IoTDBConstant.ClientVersion.V_1_0); + session, + -1L, + username, + java.time.ZoneId.systemDefault(), + IoTDBConstant.ClientVersion.V_1_0); // Generate cryptographically secure Bearer token (32 bytes = 256 bits) byte[] tokenBytes = new byte[32]; SECURE_RANDOM.nextBytes(tokenBytes); String token = Base64.getUrlEncoder().withoutPadding().encodeToString(tokenBytes); tokenCache.put(token, session); - userTokenCache.put(username, token); + clientSessionCache.put(cacheKey, token); LOGGER.info("Flight SQL authentication successful for client: {}", clientAddress); return token; } @@ -162,6 +209,29 @@ public IClientSession getSessionByToken(String token) { return session; } + /** + * Validates the client ID from the x-flight-sql-client-id header. Returns the validated clientId, + * or null if the header was absent (null/empty). Non-empty invalid clientIds are rejected + * (fail-closed) to prevent silent fallback to shared username-only sessions, which would break + * USE database isolation. + * + * @throws SecurityException if clientId is non-empty but invalid (too long or bad characters) + */ + private static String validateClientId(String clientId) { + if (clientId == null || clientId.isEmpty()) { + return null; + } + if (clientId.length() > MAX_CLIENT_ID_LENGTH) { + throw new SecurityException( + "Client ID exceeds maximum length of " + MAX_CLIENT_ID_LENGTH + " characters"); + } + if (!CLIENT_ID_PATTERN.matcher(clientId).matches()) { + throw new SecurityException( + "Client ID contains invalid characters (only alphanumeric and dash allowed)"); + } + return clientId; + } + /** Invalidates all sessions and cleans up resources. 
*/ public void close() { tokenCache.invalidateAll(); diff --git a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/TsBlockToArrowConverter.java b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/TsBlockToArrowConverter.java index ca5532f13f1ca..84c6984d9a688 100644 --- a/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/TsBlockToArrowConverter.java +++ b/external-service-impl/flight-sql/src/main/java/org/apache/iotdb/flight/TsBlockToArrowConverter.java @@ -123,9 +123,7 @@ public static void fillVectorSchemaRoot( for (int colIdx = 0; colIdx < columnNames.size(); colIdx++) { String colName = columnNames.get(colIdx); int sourceIdx = - (headerMap != null && headerMap.containsKey(colName)) - ? headerMap.get(colName) - : colIdx; + (headerMap != null && headerMap.containsKey(colName)) ? headerMap.get(colName) : colIdx; Column column = tsBlock.getColumn(sourceIdx); TSDataType dataType = dataTypes.get(colIdx); FieldVector fieldVector = root.getVector(colIdx); diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/flightsql/IoTDBArrowFlightSqlIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/flightsql/IoTDBArrowFlightSqlIT.java index 94d531c4d3d85..625337ab4fa93 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/flightsql/IoTDBArrowFlightSqlIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/flightsql/IoTDBArrowFlightSqlIT.java @@ -25,7 +25,11 @@ import org.apache.iotdb.itbase.category.TableLocalStandaloneIT; import org.apache.iotdb.itbase.env.BaseEnv; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightClientMiddleware; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightStream; @@ -49,6 
+53,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.UUID; import static org.junit.Assert.*; @@ -66,8 +71,8 @@ public class IoTDBArrowFlightSqlIT { private static final String USER = "root"; private static final String PASSWORD = "root"; + private String clientId; private BufferAllocator allocator; - private FlightClient flightClient; private FlightSqlClient flightSqlClient; private CredentialCallOption bearerToken; @@ -82,10 +87,8 @@ public void setUp() throws Exception { allocator = new RootAllocator(Long.MAX_VALUE); Location location = Location.forGrpcInsecure("127.0.0.1", port); - ClientIncomingAuthHeaderMiddleware.Factory authFactory = - new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); - flightClient = FlightClient.builder(allocator, location).intercept(authFactory).build(); - flightSqlClient = new FlightSqlClient(flightClient); + clientId = UUID.randomUUID().toString(); + flightSqlClient = createFlightSqlClient(clientId); bearerToken = new CredentialCallOption(new BasicAuthCredentialWriter(USER, PASSWORD)); // Create test data via native session (not Flight SQL) @@ -123,13 +126,6 @@ public void tearDown() throws Exception { // ignore } } - if (flightClient != null) { - try { - flightClient.close(); - } catch (Exception e) { - // ignore - } - } if (allocator != null) { allocator.close(); } @@ -177,8 +173,7 @@ public void testQueryWithAllDataTypes() throws Exception { public void testQueryWithFilter() throws Exception { FlightInfo flightInfo = flightSqlClient.execute( - "SELECT id1, s1 FROM " + TABLE + " WHERE id1 = 'device1' ORDER BY time", - bearerToken); + "SELECT id1, s1 FROM " + TABLE + " WHERE id1 = 'device1' ORDER BY time", bearerToken); List> rows = fetchAllRows(flightInfo); assertEquals("Should have 2 rows for device1", 2, rows.size()); @@ -189,7 +184,9 @@ public void testQueryWithAggregation() throws Exception { FlightInfo flightInfo = flightSqlClient.execute( "SELECT id1, COUNT(*) as cnt, 
SUM(s1) as s1_sum " - + "FROM " + TABLE + " GROUP BY id1 ORDER BY id1", + + "FROM " + + TABLE + + " GROUP BY id1 ORDER BY id1", bearerToken); List> rows = fetchAllRows(flightInfo); @@ -206,6 +203,111 @@ public void testEmptyResult() throws Exception { assertEquals("Should have 0 rows", 0, rows.size()); } + @Test + public void testUseDbSessionPersistence() throws Exception { + // Connection 1: USE database (same clientId shares the session) + flightSqlClient.execute("USE " + DATABASE, bearerToken); + + // Connection 2: query without fully-qualified table name. + // Same clientId ensures the same session is reused, so USE context persists. + FlightSqlClient client2 = createFlightSqlClient(clientId); + try { + CredentialCallOption token2 = + new CredentialCallOption(new BasicAuthCredentialWriter(USER, PASSWORD)); + FlightInfo flightInfo = client2.execute("SELECT * FROM test_table ORDER BY time", token2); + List> rows = fetchAllRows(flightInfo, client2, token2); + assertEquals("Should have 3 rows from unqualified query after USE", 3, rows.size()); + } finally { + client2.close(); + } + } + + @Test + public void testUseDbWithFullyQualifiedFallback() throws Exception { + // Connection 1: USE database + flightSqlClient.execute("USE " + DATABASE, bearerToken); + + // Connection 2: unqualified query (same clientId → same session) + FlightSqlClient client2 = createFlightSqlClient(clientId); + try { + CredentialCallOption token2 = + new CredentialCallOption(new BasicAuthCredentialWriter(USER, PASSWORD)); + FlightInfo infoUnqualified = + client2.execute("SELECT * FROM test_table ORDER BY time", token2); + List> rowsUnqualified = fetchAllRows(infoUnqualified, client2, token2); + assertEquals("Unqualified query should return 3 rows", 3, rowsUnqualified.size()); + } finally { + client2.close(); + } + + // Connection 3: fully-qualified query + FlightSqlClient client3 = createFlightSqlClient(clientId); + try { + CredentialCallOption token3 = + new CredentialCallOption(new 
BasicAuthCredentialWriter(USER, PASSWORD)); + FlightInfo infoQualified = + client3.execute("SELECT * FROM " + TABLE + " ORDER BY time", token3); + List> rowsQualified = fetchAllRows(infoQualified, client3, token3); + assertEquals("Fully-qualified query should also return 3 rows", 3, rowsQualified.size()); + } finally { + client3.close(); + } + } + + @Test + public void testUseDbIsolationAcrossClients() throws Exception { + // Client A (clientId from setUp): USE DATABASE + flightSqlClient.execute("USE " + DATABASE, bearerToken); + + // Client B (different clientId): gets its own independent session with NO USE context. + // Querying an unqualified table name should fail because no database is selected. + String clientIdB = UUID.randomUUID().toString(); + FlightSqlClient clientB = createFlightSqlClient(clientIdB); + CredentialCallOption tokenB = + new CredentialCallOption(new BasicAuthCredentialWriter(USER, PASSWORD)); + try { + clientB.execute("SELECT * FROM test_table", tokenB); + fail("Client B should fail on unqualified table query without USE"); + } catch (Exception expected) { + // Expected: Client B has no database context, so unqualified table query fails. + // Arrow Flight wraps the actual error, so we just verify the query did fail. 
+ assertNotNull("Exception should have a message", expected.getMessage()); + } finally { + clientB.close(); + } + + // Client A's USE context is preserved (same clientId → same session) + FlightSqlClient clientA2 = createFlightSqlClient(clientId); + CredentialCallOption tokenA2 = + new CredentialCallOption(new BasicAuthCredentialWriter(USER, PASSWORD)); + try { + FlightInfo infoA = clientA2.execute("SELECT * FROM test_table ORDER BY time", tokenA2); + List> rowsA = fetchAllRows(infoA, clientA2, tokenA2); + assertEquals("Client A should still see 3 rows after Client B's queries", 3, rowsA.size()); + } finally { + clientA2.close(); + } + } + + @Test + public void testInvalidClientIdRejected() throws Exception { + // A non-empty clientId with invalid characters (contains @) should be rejected (fail-closed). + // Only null/empty clientId should fall back to shared session keying. + String invalidClientId = "bad@client!id"; + FlightSqlClient invalidClient = createFlightSqlClient(invalidClientId); + CredentialCallOption token = + new CredentialCallOption(new BasicAuthCredentialWriter(USER, PASSWORD)); + try { + invalidClient.execute("SHOW DATABASES", token); + fail("Server should reject invalid clientId during authentication"); + } catch (Exception expected) { + // Expected: server rejects the invalid clientId + assertNotNull("Exception should have a message", expected.getMessage()); + } finally { + invalidClient.close(); + } + } + // ===================== Helper Methods ===================== /** @@ -213,9 +315,14 @@ public void testEmptyResult() throws Exception { * of string representations of the column values. 
*/ private List<List<String>> fetchAllRows(FlightInfo flightInfo) throws Exception { + return fetchAllRows(flightInfo, flightSqlClient, bearerToken); + } + + private List<List<String>> fetchAllRows( + FlightInfo flightInfo, FlightSqlClient client, CredentialCallOption token) throws Exception { List<List<String>> rows = new ArrayList<>(); for (FlightEndpoint endpoint : flightInfo.getEndpoints()) { - try (FlightStream stream = flightSqlClient.getStream(endpoint.getTicket(), bearerToken)) { + try (FlightStream stream = client.getStream(endpoint.getTicket(), token)) { while (stream.next()) { VectorSchemaRoot root = stream.getRoot(); int rowCount = root.getRowCount(); @@ -232,4 +339,50 @@ private List<List<String>> fetchAllRows(FlightInfo flightInfo) throws Exception } return rows; } + + private FlightSqlClient createFlightSqlClient(String flightClientId) { + int port = EnvFactory.getEnv().getArrowFlightSqlPort(); + Location location = Location.forGrpcInsecure("127.0.0.1", port); + ClientIncomingAuthHeaderMiddleware.Factory authFactory = + new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); + FlightClient client = + FlightClient.builder(allocator, location) + .intercept(authFactory) + .intercept(new ClientIdMiddlewareFactory(flightClientId)) + .build(); + return new FlightSqlClient(client); + } + + /** + * FlightClientMiddleware that injects the x-flight-sql-client-id header on every call. This + * allows the server to key sessions per logical client, enabling per-client USE database + * isolation.
+ */ + private static class ClientIdMiddlewareFactory implements FlightClientMiddleware.Factory { + private final String flightClientId; + + ClientIdMiddlewareFactory(String flightClientId) { + this.flightClientId = flightClientId; + } + + @Override + public FlightClientMiddleware onCallStarted(CallInfo info) { + return new FlightClientMiddleware() { + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + outgoingHeaders.insert("x-flight-sql-client-id", flightClientId); + } + + @Override + public void onHeadersReceived(CallHeaders incomingHeaders) { + // no-op + } + + @Override + public void onCallCompleted(CallStatus status) { + // no-op + } + }; + } + } }