diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index fc96f2f389c..c25b027e4ec 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -106,6 +106,7 @@ import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.ast.tree.Transpose; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.Union; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.ast.tree.Window; @@ -897,6 +898,11 @@ public LogicalPlan visitMultisearch(Multisearch node, AnalysisContext context) { throw getOnlyForCalciteException("Multisearch"); } + @Override + public LogicalPlan visitUnion(Union node, AnalysisContext context) { + throw getOnlyForCalciteException("Union"); + } + private LogicalSort buildSort( LogicalPlan child, AnalysisContext context, Integer count, List sortFields) { ExpressionReferenceOptimizer optimizer = diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 7f02bb3ef1b..be02547a2da 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -93,6 +93,7 @@ import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.ast.tree.Transpose; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.Union; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.ast.tree.Window; @@ -472,6 +473,10 @@ public T visitMultisearch(Multisearch node, C context) { return visitChildren(node, context); } + public T visitUnion(Union node, C context) { + return visitChildren(node, context); + } + public T visitAddTotals(AddTotals node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Union.java b/core/src/main/java/org/opensearch/sql/ast/tree/Union.java new file mode 100644 index 00000000000..a96831567cb --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Union.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + +/** Logical plan node for Union operation. Combines results from multiple datasets (UNION ALL). */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +@AllArgsConstructor +public class Union extends UnresolvedPlan { + private final List datasets; + + private Integer maxout; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + List newDatasets = + ImmutableList.builder().add(child).addAll(datasets).build(); + return new Union(newDatasets, maxout); + } + + @Override + public List getChild() { + return datasets; + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitUnion(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index e4e036da3a6..8b9210290fb 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -159,6 +159,7 @@ import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.Trendline.TrendlineType; +import org.opensearch.sql.ast.tree.Union; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.ast.tree.Window; @@ -2607,6 +2608,40 @@ private String findTimestampField(RelDataType rowType) { return null; } + @Override + public RelNode visitUnion(Union node, CalcitePlanContext context) { + List inputNodes = new ArrayList<>(); + + for (UnresolvedPlan dataset : node.getDatasets()) { + UnresolvedPlan prunedDataset = dataset.accept(new EmptySourcePropagateVisitor(), null); + prunedDataset.accept(this, context); + inputNodes.add(context.relBuilder.build()); + } + + if (inputNodes.size() < 2) { + throw new IllegalArgumentException( + "Union command requires at least two datasets. Provided: " + inputNodes.size()); + } + + List unifiedInputs = + SchemaUnifier.buildUnifiedSchemaWithTypeCoercion(inputNodes, context); + + for (RelNode input : unifiedInputs) { + context.relBuilder.push(input); + } + context.relBuilder.union(true, unifiedInputs.size()); // true = UNION ALL + + if (node.getMaxout() != null) { + context.relBuilder.push( + LogicalSystemLimit.create( + LogicalSystemLimit.SystemLimitType.SUBSEARCH_MAXOUT, + context.relBuilder.build(), + context.relBuilder.literal(node.getMaxout()))); + } + + return context.relBuilder.peek(); + } + /* * Unsupported Commands of PPL with Calcite for OpenSearch 3.0.0-beta */ diff --git a/core/src/main/java/org/opensearch/sql/calcite/SchemaUnifier.java b/core/src/main/java/org/opensearch/sql/calcite/SchemaUnifier.java index 05380ce8c48..e01cbe3992d 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/SchemaUnifier.java +++ b/core/src/main/java/org/opensearch/sql/calcite/SchemaUnifier.java @@ -14,10 +14,16 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.type.SqlTypeName; /** - * Utility class for unifying schemas across multiple RelNodes. Throws an exception when type - * conflicts are detected. + * Utility class for unifying schemas across multiple RelNodes. Supports two strategies: + * + *
    + *
  • Conflict resolution (multisearch): throws on type mismatch, fills missing fields with NULL + *
  • Type coercion (union): widens compatible types (e.g. INTEGER→BIGINT), falls back to VARCHAR + * for incompatible types, fills missing fields with NULL + *
*/ public class SchemaUnifier { @@ -147,4 +153,236 @@ RelDataType getType() { return type; } } + + /** + * Builds unified schema with type coercion for UNION command. Coerces compatible types to a + * common supertype (e.g. int+float→float), falls back to VARCHAR for incompatible types, and + * fills missing fields with NULL. + */ + public static List buildUnifiedSchemaWithTypeCoercion( + List inputs, CalcitePlanContext context) { + if (inputs.isEmpty() || inputs.size() == 1) { + return inputs; + } + + List coercedInputs = coerceUnionTypes(inputs, context); + return unifySchemasForUnion(coercedInputs, context); + } + + /** + * Aligns schemas by projecting NULL for missing fields and CAST for type mismatches. Uses + * force=true to clear collation traits and prevent EnumerableMergeUnion cast exception. + */ + private static List unifySchemasForUnion( + List inputs, CalcitePlanContext context) { + List unifiedSchema = buildUnifiedSchemaForUnion(inputs); + List fieldNames = + unifiedSchema.stream().map(SchemaField::getName).collect(Collectors.toList()); + + List projectedNodes = new ArrayList<>(); + for (RelNode node : inputs) { + List projection = buildProjectionForUnion(node, unifiedSchema, context); + RelNode projectedNode = + context.relBuilder.push(node).project(projection, fieldNames, true).build(); + projectedNodes.add(projectedNode); + } + return projectedNodes; + } + + private static List buildUnifiedSchemaForUnion(List nodes) { + List schema = new ArrayList<>(); + Map seenFields = new HashMap<>(); + + for (RelNode node : nodes) { + for (RelDataTypeField field : node.getRowType().getFieldList()) { + if (!seenFields.containsKey(field.getName())) { + schema.add(new SchemaField(field.getName(), field.getType())); + seenFields.put(field.getName(), field.getType()); + } + } + } + return schema; + } + + private static List buildProjectionForUnion( + RelNode node, List unifiedSchema, CalcitePlanContext context) { + Map nodeFieldMap = + node.getRowType().getFieldList().stream() + .collect(Collectors.toMap(RelDataTypeField::getName, field -> field)); + + List projection = new ArrayList<>(); + for (SchemaField schemaField : unifiedSchema) { + RelDataTypeField nodeField = nodeFieldMap.get(schemaField.getName()); + + if (nodeField != null) { + RexNode fieldRef = context.rexBuilder.makeInputRef(node, nodeField.getIndex()); + if (!nodeField.getType().equals(schemaField.getType())) { + projection.add(context.rexBuilder.makeCast(schemaField.getType(), fieldRef)); + } else { + projection.add(fieldRef); + } + } else { + projection.add(context.rexBuilder.makeNullLiteral(schemaField.getType())); + } + } + return projection; + } + + /** Casts fields to their common supertypes across all inputs when types differ. */ + private static List coerceUnionTypes(List inputs, CalcitePlanContext context) { + Map> fieldTypeMap = new HashMap<>(); + for (RelNode input : inputs) { + for (RelDataTypeField field : input.getRowType().getFieldList()) { + String fieldName = field.getName(); + SqlTypeName typeName = field.getType().getSqlTypeName(); + if (typeName != null) { + fieldTypeMap.computeIfAbsent(fieldName, k -> new ArrayList<>()).add(typeName); + } + } + } + + Map targetTypeMap = new HashMap<>(); + for (Map.Entry> entry : fieldTypeMap.entrySet()) { + String fieldName = entry.getKey(); + List types = entry.getValue(); + + SqlTypeName commonType = types.getFirst(); + for (int i = 1; i < types.size(); i++) { + commonType = findCommonTypeForUnion(commonType, types.get(i)); + } + targetTypeMap.put(fieldName, commonType); + } + + boolean needsCoercion = false; + for (RelNode input : inputs) { + for (RelDataTypeField field : input.getRowType().getFieldList()) { + SqlTypeName targetType = targetTypeMap.get(field.getName()); + if (targetType != null && field.getType().getSqlTypeName() != targetType) { + needsCoercion = true; + break; + } + } + if (needsCoercion) break; + } + + if (!needsCoercion) { + return inputs; + } + + List coercedInputs = new ArrayList<>(); + for (RelNode input : inputs) { + List projections = new ArrayList<>(); + List projectionNames = new ArrayList<>(); + boolean needsProjection = false; + + for (RelDataTypeField field : input.getRowType().getFieldList()) { + String fieldName = field.getName(); + SqlTypeName currentType = field.getType().getSqlTypeName(); + SqlTypeName targetType = targetTypeMap.get(fieldName); + + RexNode fieldRef = context.rexBuilder.makeInputRef(input, field.getIndex()); + + if (currentType != targetType && targetType != null) { + projections.add(context.relBuilder.cast(fieldRef, targetType)); + needsProjection = true; + } else { + projections.add(fieldRef); + } + projectionNames.add(fieldName); + } + + if (needsProjection) { + context.relBuilder.push(input); + context.relBuilder.project(projections, projectionNames, true); + coercedInputs.add(context.relBuilder.build()); + } else { + coercedInputs.add(input); + } + } + + return coercedInputs; + } + + /** + * Returns the wider type for two SqlTypeNames. Within the same family, returns the wider type + * (e.g. INTEGER+BIGINT-->BIGINT). Across families, falls back to VARCHAR. + */ + private static SqlTypeName findCommonTypeForUnion(SqlTypeName type1, SqlTypeName type2) { + if (type1 == type2) { + return type1; + } + + if (type1 == SqlTypeName.NULL) { + return type2; + } + if (type2 == SqlTypeName.NULL) { + return type1; + } + + if (isNumericTypeForUnion(type1) && isNumericTypeForUnion(type2)) { + return getWiderNumericTypeForUnion(type1, type2); + } + + if (isStringTypeForUnion(type1) && isStringTypeForUnion(type2)) { + return SqlTypeName.VARCHAR; + } + + if (isTemporalTypeForUnion(type1) && isTemporalTypeForUnion(type2)) { + return getWiderTemporalTypeForUnion(type1, type2); + } + + return SqlTypeName.VARCHAR; + } + + private static boolean isNumericTypeForUnion(SqlTypeName typeName) { + return typeName == SqlTypeName.TINYINT + || typeName == SqlTypeName.SMALLINT + || typeName == SqlTypeName.INTEGER + || typeName == SqlTypeName.BIGINT + || typeName == SqlTypeName.FLOAT + || typeName == SqlTypeName.REAL + || typeName == SqlTypeName.DOUBLE + || typeName == SqlTypeName.DECIMAL; + } + + private static boolean isStringTypeForUnion(SqlTypeName typeName) { + return typeName == SqlTypeName.CHAR || typeName == SqlTypeName.VARCHAR; + } + + private static boolean isTemporalTypeForUnion(SqlTypeName typeName) { + return typeName == SqlTypeName.DATE + || typeName == SqlTypeName.TIMESTAMP + || typeName == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE; + } + + private static SqlTypeName getWiderNumericTypeForUnion(SqlTypeName type1, SqlTypeName type2) { + int rank1 = getNumericTypeRankForUnion(type1); + int rank2 = getNumericTypeRankForUnion(type2); + return rank1 >= rank2 ? type1 : type2; + } + + private static int getNumericTypeRankForUnion(SqlTypeName typeName) { + return switch (typeName) { + case TINYINT -> 1; + case SMALLINT -> 2; + case INTEGER -> 3; + case BIGINT -> 4; + case DECIMAL -> 5; + case REAL -> 6; + case FLOAT -> 7; + case DOUBLE -> 8; + default -> 0; + }; + } + + private static SqlTypeName getWiderTemporalTypeForUnion(SqlTypeName type1, SqlTypeName type2) { + if (type1 == SqlTypeName.TIMESTAMP || type2 == SqlTypeName.TIMESTAMP) { + return SqlTypeName.TIMESTAMP; + } + if (type1 == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE + || type2 == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE) { + return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE; + } + return SqlTypeName.DATE; + } } diff --git a/docs/category.json b/docs/category.json index 5e9b6f954a5..2342ada464d 100644 --- a/docs/category.json +++ b/docs/category.json @@ -48,6 +48,7 @@ "user/ppl/cmd/top.md", "user/ppl/cmd/trendline.md", "user/ppl/cmd/transpose.md", + "user/ppl/cmd/union.md", "user/ppl/cmd/where.md", "user/ppl/functions/aggregations.md", "user/ppl/functions/collection.md", diff --git a/docs/user/ppl/cmd/union.md b/docs/user/ppl/cmd/union.md new file mode 100644 index 00000000000..8c148b998c2 --- /dev/null +++ b/docs/user/ppl/cmd/union.md @@ -0,0 +1,197 @@ + +# union + +The `union` command combines results from multiple datasets using UNION ALL semantics. It merges rows from two or more sources into a single result set, preserving all rows including duplicates. You can optionally apply subsequent processing, such as aggregation or sorting, to the combined results. Each dataset can be a subsearch with different filtering criteria, data transformations, and field selections, or a direct index reference. + +Union is particularly useful for combining data from multiple sources, creating comprehensive datasets from different criteria, and consolidating results while handling schema differences through automatic type coercion. + +Use union for: + +* **Multi-source data combination**: Merge data from different indexes or apply different filters to the same source. +* **Dataset consolidation**: Combine results from different queries while preserving all rows including duplicates. +* **Flexible dataset patterns**: Use subsearches or direct index references with optional maxout control. +* **Schema unification**: Automatically handle different schemas with type coercion for conflicting field types and NULL-fill for missing fields. + +## Syntax + +The `union` command has the following syntax: + +```syntax +union [maxout=] [ ...] +``` + +Each dataset can be: +- **Direct index reference**: `index_name`, `index_pattern*`, `index_alias` +- **Subsearch**: `[search source=index | ]` + +The following are examples of the `union` command syntax: + +```syntax +| union logs-*, security-logs +| union [search source=accounts | where age > 30], [search source=accounts | where age < 30] +| union maxout=100 [search source=logs | fields user, action], [search source=events | fields user, action] +| union [search source=accounts | where status="active"], [search source=accounts | where status="pending"] +``` + +## Parameters + +The `union` command supports the following parameters. + +| Parameter | Required/Optional | Description | +| --- | --- | --- | +| `maxout` | Optional | Maximum number of results to return from the union operation. Default: unlimited (0). | +| `` | Required | At least two datasets are required. Each dataset can be either a subsearch enclosed in square brackets (`[search source=index | ]`) or a direct index reference (for example, `accounts`, `logs-*`). All PPL commands are supported within subsearches. | +| `` | Optional | Commands applied to the merged results after the union operation (for example, `stats`, `sort`, or `head`). | + +## Example 1: Combining age groups for demographic analysis + +This example demonstrates how to merge customers from different age segments into a unified dataset. It combines `young` and `adult` customers into a single result set and adds categorization labels for further analysis: + +```ppl +| union [search source=accounts +| where age < 30 +| eval age_group = "young" +| fields firstname, age, age_group] [search source=accounts +| where age >= 30 +| eval age_group = "adult" +| fields firstname, age, age_group] +| sort age +``` + +The query returns the following results: + +```text +fetched rows / total rows = 4/4 ++-----------+-----+-----------+ +| firstname | age | age_group | +|-----------+-----+-----------| +| Nanette | 28 | young | +| Amber | 32 | adult | +| Dale | 33 | adult | +| Hattie | 36 | adult | ++-----------+-----+-----------+ +``` + + +## Example 2: Combining filtered subsets from the same index + +This example demonstrates how to combine multiple filtered subsets from the same index using union: + +```ppl +| union [search source=accounts | where balance > 30000] [search source=accounts | where age < 30] +| fields firstname, age, balance +| sort balance desc +``` + +The query returns the following results: + +```text +fetched rows / total rows = 3/3 ++-----------+-----+---------+ +| firstname | age | balance | +|-----------+-----+---------| +| Amber | 32 | 39225 | +| Nanette | 28 | 32838 | +| Nanette | 28 | 32838 | ++-----------+-----+---------+ +``` + +Note: Nanette appears twice because she meets both conditions (balance > 30000 AND age < 30), demonstrating UNION ALL semantics which preserve all rows including duplicates. + + +## Example 3: Mid-pipeline union (implicit first dataset) + +This example demonstrates using union mid-pipeline where the upstream result is implicitly included as the first dataset: + +```ppl +search source=accounts | where age > 30 | union [search source=accounts | where age < 30] +| fields firstname, age +| sort age +``` + +The query returns the following results: + +```text +fetched rows / total rows = 4/4 ++-----------+-----+ +| firstname | age | +|-----------+-----| +| Nanette | 28 | +| Amber | 32 | +| Dale | 33 | +| Hattie | 36 | ++-----------+-----+ +``` + +Note: The upstream result `where age > 30` is automatically the first dataset, then unioned with `where age < 30`. + + +## Example 4: Using maxout option to limit results + +This example demonstrates how to limit the total number of results returned from a union operation using the `maxout` option. Note that UNION ALL semantics preserve duplicate rows: + +```ppl +| union maxout=3 [search source=accounts +| where balance > 20000] [search source=accounts +| where age > 30] +| fields firstname, age, balance +``` + +The query returns the following results: + +```text +fetched rows / total rows = 3/3 ++-----------+-----+---------+ +| firstname | age | balance | +|-----------+-----+---------| +| Amber | 32 | 39225 | +| Nanette | 28 | 32838 | +| Amber | 32 | 39225 | ++-----------+-----+---------+ +``` + +Note: Amber appears twice because she meets both conditions (balance > 20000 AND age > 30), demonstrating UNION ALL semantics which preserve all rows including duplicates. + + +## Example 5: Segmenting accounts by balance tier + +This example demonstrates how to create account segments based on balance thresholds for comparative analysis. It separates `high_balance` accounts from `regular` accounts and labels them for easy comparison: + +```ppl +| union [search source=accounts +| where balance > 20000 +| eval query_type = "high_balance" +| fields firstname, balance, query_type] [search source=accounts +| where balance > 0 AND balance <= 20000 +| eval query_type = "regular" +| fields firstname, balance, query_type] +| sort balance desc +``` + +The query returns the following results: + +```text +fetched rows / total rows = 4/4 ++-----------+---------+--------------+ +| firstname | balance | query_type | +|-----------+---------+--------------| +| Amber | 39225 | high_balance | +| Nanette | 32838 | high_balance | +| Hattie | 5686 | regular | +| Dale | 4180 | regular | ++-----------+---------+--------------+ +``` + + +## Limitations + +The `union` command has the following limitations: + +* At least two datasets must be specified. +* When fields with the same name exist across datasets but have different types, the system automatically performs type coercion to find a common supertype: + * **Compatible numeric types** → wider numeric type (for example, `INTEGER` and `BIGINT` coerce to `BIGINT`; `INTEGER` and `FLOAT` coerce to `FLOAT`) + * **String types** → `VARCHAR` (for example, `CHAR` and `VARCHAR` coerce to `VARCHAR`) + * **Temporal types** → wider temporal type (for example, `DATE` and `TIMESTAMP` coerce to `TIMESTAMP`) + * **Incompatible types** (different type families) → `VARCHAR` fallback (for example, `INTEGER` and `VARCHAR` coerce to `VARCHAR`) +* Missing fields across datasets are automatically filled with `NULL` values to unify schemas. +* Direct index references must be valid index names, patterns, or aliases (for example, `accounts`, `logs-*`, `security-alias`). diff --git a/docs/user/ppl/index.md b/docs/user/ppl/index.md index 27f59fa4b95..37947113800 100644 --- a/docs/user/ppl/index.md +++ b/docs/user/ppl/index.md @@ -73,6 +73,7 @@ source=accounts | [appendcol command](cmd/appendcol.md) | 3.1 | experimental (since 3.1) | Append the result of a sub-search and attach it alongside the input search results. | | [lookup command](cmd/lookup.md) | 3.0 | experimental (since 3.0) | Add or replace data from a lookup index. | | [multisearch command](cmd/multisearch.md) | 3.4 | experimental (since 3.4) | Execute multiple search queries and combine their results. | +| [union command](cmd/union.md) | 3.7 | experimental (since 3.7) | Combine results from multiple datasets using UNION ALL semantics. | | [ml command](cmd/ml.md) | 2.5 | stable (since 2.5) | Apply machine learning algorithms to analyze data. | | [kmeans command](cmd/kmeans.md) | 1.3 | stable (since 1.3) | Apply the kmeans algorithm on the search result returned by a PPL command. | | [ad command](cmd/ad.md) | 1.3 | deprecated (since 2.5) | Apply Random Cut Forest algorithm on the search result returned by a PPL command. | diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java index a56715845ab..b0a1b757858 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java @@ -107,6 +107,7 @@ CalciteTopCommandIT.class, CalciteTrendlineCommandIT.class, CalciteTransposeCommandIT.class, + CalciteUnionCommandIT.class, CalciteVisualizationFormatIT.class, CalciteWhereCommandIT.class, CalcitePPLTpchIT.class, diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 844283d937e..0d8d4e8fa0c 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -2917,4 +2917,17 @@ public void testExplainConsecutiveSortsAfterAggIssue5125() throws IOException { "source=%s | stats count() as c by gender | sort gender | sort - gender", TEST_INDEX_BANK))); } + + @Test + public void testExplainUnion() throws IOException { + String query = + "| union " + + "[search source=opensearch-sql_test_index_account | where age < 30] " + + "[search source=opensearch-sql_test_index_account | where age >= 30] " + + "| stats count() by gender"; + + String actual = explainQueryYaml(query); + String expected = loadExpectedPlan("explain_union.yaml"); + assertYamlEqualsIgnoreId(expected, actual); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteUnionCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteUnionCommandIT.java new file mode 100644 index 00000000000..1dbd34357ab --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteUnionCommandIT.java @@ -0,0 +1,270 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.remote; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_LOCATIONS_TYPE_CONFLICT; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.ppl.PPLIntegTestCase; + +public class CalciteUnionCommandIT extends PPLIntegTestCase { + + @Override + public void init() throws Exception { + super.init(); + enableCalcite(); + loadIndex(Index.ACCOUNT); + loadIndex(Index.BANK); + loadIndex(Index.TIME_TEST_DATA); + loadIndex(Index.TIME_TEST_DATA2); + loadIndex(Index.LOCATIONS_TYPE_CONFLICT); + } + + @Test + public void testBasicUnionTwoSubsearches() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union " + + "[search source=%s | where age < 30 | eval age_group = \\\"young\\\"] " + + "[search source=%s | where age >= 30 | eval age_group = \\\"adult\\\"] " + + "| stats count by age_group | sort age_group", + TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT)); + + verifySchema(result, schema("count", null, "bigint"), schema("age_group", null, "string")); + verifyDataRows(result, rows(549L, "adult"), rows(451L, "young")); + } + + @Test + public void testUnionThreeSubsearches() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union [search source=%s | where state = \\\"IL\\\" | eval region" + + " = \\\"Illinois\\\"] [search source=%s | where state = \\\"TN\\\" | eval" + + " region = \\\"Tennessee\\\"] [search source=%s | where state = \\\"CA\\\" |" + + " eval region = \\\"California\\\"] | stats count by region | sort region", + TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT)); + + verifySchema(result, schema("count", null, "bigint"), schema("region", null, "string")); + verifyDataRows(result, rows(17L, "California"), rows(22L, "Illinois"), rows(25L, "Tennessee")); + } + + @Test + public void testUnionDirectTableNames() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union %s, %s | where account_number = 1 | fields firstname, city", + TEST_INDEX_ACCOUNT, TEST_INDEX_BANK)); + + verifySchema(result, schema("firstname", null, "string"), schema("city", null, "string")); + + verifyDataRows(result, rows("Amber", "Brogan"), rows("Amber JOHnny", "Brogan")); + } + + @Test + public void testUnionMixedDirectTableAndSubsearch() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union %s, [search source=%s | where age > 30] | stats count() as total", + TEST_INDEX_ACCOUNT, TEST_INDEX_BANK)); + + verifySchema(result, schema("total", null, "bigint")); + verifyDataRows(result, rows(1006L)); + } + + @Test + public void testUnionWithDifferentIndicesSchemaMerge() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union [search source=%s | where age > 35 | fields account_number," + + " firstname, balance] [search source=%s | where age > 35 | fields" + + " account_number, balance] | stats count() as total_count", + TEST_INDEX_ACCOUNT, TEST_INDEX_BANK)); + + verifySchema(result, schema("total_count", null, "bigint")); + verifyDataRows(result, rows(241L)); + } + + @Test + public void testUnionNumericCoercion_BigIntPlusInteger() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union [search source=%s | where account_number = 1 | fields balance] [search" + + " source=%s | where account_number = 1 | eval balance = 100 | fields balance]" + + " | head 2", + TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT)); + + verifySchema(result, schema("balance", null, "bigint")); + + assertEquals(2, result.getJSONArray("datarows").length()); + } + + @Test + public void testUnionIncompatibleTypes_MultipleFieldConflicts() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union [search source=%s | where account_number = 1 | fields firstname, age," + + " balance] [search source=%s | where place_id = 1001 | fields description," + + " age, place_id] | head 2", + TEST_INDEX_ACCOUNT, TEST_INDEX_LOCATIONS_TYPE_CONFLICT)); + + verifySchema( + result, + schema("firstname", null, "string"), + schema("age", null, "string"), + schema("balance", null, "bigint"), + schema("description", null, "string"), + schema("place_id", null, "int")); + + assertEquals(2, result.getJSONArray("datarows").length()); + } + + @Test + public void testUnionAllDatasetsDifferentSchemas() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union [search source=%s | where account_number = 1 | fields account_number," + + " balance] [search source=%s | where place_id = 1001 | fields description," + + " place_id] [search source=%s | where category = \\\"A\\\" | fields category," + + " value] | stats count() as total", + TEST_INDEX_ACCOUNT, + TEST_INDEX_LOCATIONS_TYPE_CONFLICT, + "opensearch-sql_test_index_time_data")); + + verifySchema(result, schema("total", null, "bigint")); + verifyDataRows(result, rows(28L)); + } + + @Test + public void testUnionMidPipeline_SingleExplicitDataset() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | where gender = \\\"M\\\" " + + "| union [search source=%s | where gender = \\\"F\\\"] " + + "| stats count() as total", + TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT)); + + verifySchema(result, schema("total", null, "bigint")); + verifyDataRows(result, rows(1000L)); + } + + @Test + public void testUnionWithExplicitOrdering() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union [search source=%s | where account_number = 1 | fields account_number," + + " balance] [search source=%s | where account_number = 6 | fields" + + " account_number, balance] | sort balance desc", + TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT)); + + verifySchema( + result, schema("account_number", null, "bigint"), schema("balance", null, "bigint")); + + verifyDataRows(result, rows(1L, 39225L), rows(6L, 5686L)); + } + + @Test + public void testUnionWithMaxout() throws IOException { + String ppl = + "| union maxout=5 " + + "[search source=%s | where gender = \\\"M\\\"] " + + "[search source=%s | where gender = \\\"F\\\"]"; + JSONObject result = executeQuery(String.format(ppl, TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT)); + + verifySchema( + result, + schema("account_number", null, "bigint"), + schema("firstname", null, "string"), + schema("address", null, "string"), + schema("balance", null, "bigint"), + schema("gender", null, "string"), + schema("city", null, "string"), + schema("employer", null, "string"), + schema("state", null, "string"), + schema("age", null, "bigint"), + schema("email", null, "string"), + schema("lastname", null, "string")); + + assertEquals(5, result.getJSONArray("datarows").length()); + } + + @Test + public void testUnionWithEmptySubsearch() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union " + + "[search source=%s | where age > 25] " + + "[search source=%s | where age > 200 | eval impossible = \\\"yes\\\"] " + + "| stats count", + TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT)); + + verifySchema(result, schema("count", null, "bigint")); + verifyDataRows(result, rows(733L)); + } + + @Test + public void testUnionWithAllEmptyDatasets() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union " + + "[search source=%s | where age > 1000] " + + "[search source=%s | where age > 1000] " + + "| stats count() as total", + TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT)); + + verifySchema(result, schema("total", null, "bigint")); + verifyDataRows(result, rows(0L)); + } + + @Test + public void testUnionPreservesDuplicatesExactCopy() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union " + + "[search source=%s | where account_number = 1] " + + "[search source=%s | where account_number = 1] " + + "[search source=%s | where account_number = 1] " + + "| stats count() as total", + TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT, TEST_INDEX_ACCOUNT)); + + verifySchema(result, schema("total", null, "bigint")); + verifyDataRows(result, rows(3L)); + } + + @Test + public void testUnionWithSingleSubsearchThrowsError() { + Exception exception = + assertThrows( + ResponseException.class, + () -> + executeQuery( + String.format( + "| union " + "[search source=%s | where age > 30]", TEST_INDEX_ACCOUNT))); + + assertTrue(exception.getMessage().contains("Union command requires at least two datasets")); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/NewAddedCommandsIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/NewAddedCommandsIT.java index ded727765f7..837865a3585 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/NewAddedCommandsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/NewAddedCommandsIT.java @@ -530,4 +530,20 @@ public void testMvExpandInvalidLimitNegative() throws IOException { assertThat(error.getString("type"), equalTo("SyntaxCheckException")); } } + + @Test + public void testUnionUnsupportedInV2() throws IOException { + JSONObject result; + try { + result = + executeQuery( + String.format( + "| union [search source=%s | where age < 30] [search source=%s | where age >=" + + " 30]", + TEST_INDEX_BANK, TEST_INDEX_BANK)); + } catch (ResponseException e) { + result = new JSONObject(TestUtils.getResponseBody(e.getResponse())); + } + verifyQuery(result); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/security/CalciteCrossClusterSearchIT.java b/integ-test/src/test/java/org/opensearch/sql/security/CalciteCrossClusterSearchIT.java index e55e406de7b..be64be8e019 100644 --- a/integ-test/src/test/java/org/opensearch/sql/security/CalciteCrossClusterSearchIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/security/CalciteCrossClusterSearchIT.java @@ -528,4 +528,15 @@ public void testCrossClusterMvExpandWithLimit() throws IOException { verifySchema(result, schema("username", "string"), schema("skills.name", "string")); verifyDataRows(result, rows("limituser", "a"), rows("limituser", "b")); } + + @Test + public void testCrossClusterUnion() throws IOException { + JSONObject result = + executeQuery( + String.format( + "| union [search source=%s | where age < 30] [search source=%s | where age >= 30] |" + + " stats count() by gender", + TEST_INDEX_BANK_REMOTE, TEST_INDEX_BANK_REMOTE)); + verifyColumn(result, columnName("count()"), columnName("gender")); + } } diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_union.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_union.yaml new file mode 100644 index 00000000000..8b2a4aecab3 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_union.yaml @@ -0,0 +1,20 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(count()=[$1], gender=[$0]) + LogicalAggregate(group=[{0}], count()=[COUNT()]) + LogicalProject(gender=[$4]) + LogicalUnion(all=[true]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10]) + LogicalFilter(condition=[<($8, 30)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10]) + LogicalFilter(condition=[>=($8, 30)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..1=[{inputs}], count()=[$t1], gender=[$t0]) + EnumerableAggregate(group=[{0}], count()=[COUNT()]) + EnumerableUnion(all=[true]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[gender, age], FILTER-><($1, 30), PROJECT->[gender]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","query":{"range":{"age":{"from":null,"to":30,"include_lower":true,"include_upper":false,"boost":1.0}}},"_source":{"includes":["gender"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[gender, age], FILTER->>=($1, 30), PROJECT->[gender]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","query":{"range":{"age":{"from":30,"to":null,"include_lower":true,"include_upper":true,"boost":1.0}}},"_source":{"includes":["gender"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_union.yaml b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_union.yaml new file mode 100644 index 00000000000..22a9bb6b5bd --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_union.yaml @@ -0,0 +1,22 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(count()=[$1], gender=[$0]) + LogicalAggregate(group=[{0}], count()=[COUNT()]) + LogicalProject(gender=[$4]) + LogicalUnion(all=[true]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10]) + LogicalFilter(condition=[<($8, 30)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10]) + LogicalFilter(condition=[>=($8, 30)]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..1=[{inputs}], count()=[$t1], gender=[$t0]) + EnumerableAggregate(group=[{0}], count()=[COUNT()]) + EnumerableUnion(all=[true]) + EnumerableCalc(expr#0..16=[{inputs}], expr#17=[30], expr#18=[<($t8, $t17)], gender=[$t4], $condition=[$t18]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + EnumerableCalc(expr#0..16=[{inputs}], expr#17=[30], expr#18=[>=($t8, $t17)], gender=[$t4], $condition=[$t18]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 36b52cf8a1b..2349f6176dd 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -171,6 +171,8 @@ TRAINING_DATA_SIZE: 'TRAINING_DATA_SIZE'; ANOMALY_SCORE_THRESHOLD: 'ANOMALY_SCORE_THRESHOLD'; APPEND: 'APPEND'; MULTISEARCH: 'MULTISEARCH'; +UNION: 'UNION'; +MAXOUT: 'MAXOUT'; COUNTFIELD: 'COUNTFIELD'; SHOWCOUNT: 'SHOWCOUNT'; LIMIT: 'LIMIT'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 7e3862a8683..09d90e4003f 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -49,6 +49,7 @@ pplCommands | searchCommand | multisearchCommand | graphLookupCommand + | unionCommand ; commands @@ -96,6 +97,7 @@ commands | fieldformatCommand | nomvCommand | graphLookupCommand + | unionCommand ; commandName @@ -139,6 +141,7 @@ commandName | ADDCOLTOTALS | APPEND | MULTISEARCH + | UNION | REX | APPENDPIPE | REPLACE @@ -596,6 +599,19 @@ multisearchCommand : MULTISEARCH (LT_SQR_PRTHS subSearch RT_SQR_PRTHS)+ ; +unionCommand + : UNION subsearchOptions? unionDataset (COMMA? unionDataset)* + ; + +subsearchOptions + : (MAXOUT EQUAL maxout=integerLiteral)? + ; + +unionDataset + : LT_SQR_PRTHS subSearch RT_SQR_PRTHS + | tableSource + ; + kmeansCommand : KMEANS (kmeansParameter)* ; @@ -1683,6 +1699,7 @@ searchableKeyWord | ANOMALY_SCORE_THRESHOLD | COUNTFIELD | SHOWCOUNT + | MAXOUT | PATH | INPUT | OUTPUT diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index ff0ef4bd8db..e85a9b508d8 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -120,6 +120,7 @@ import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.ast.tree.Transpose; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.Union; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Window; import org.opensearch.sql.calcite.plan.OpenSearchConstants; @@ -1339,6 +1340,37 @@ public UnresolvedPlan visitMultisearchCommand(OpenSearchPPLParser.MultisearchCom return new Multisearch(subsearches); } + @Override + public UnresolvedPlan visitUnionCommand(OpenSearchPPLParser.UnionCommandContext ctx) { + List datasets = new ArrayList<>(); + + Integer maxout = null; + if (ctx.subsearchOptions() != null) { + OpenSearchPPLParser.SubsearchOptionsContext opts = ctx.subsearchOptions(); + if (opts.maxout != null) { + maxout = Integer.parseInt(opts.maxout.getText()); + } + } + + for (OpenSearchPPLParser.UnionDatasetContext datasetCtx : ctx.unionDataset()) { + if (datasetCtx.subSearch() != null) { + datasets.add(visitSubSearch(datasetCtx.subSearch())); + } else if (datasetCtx.tableSource() != null) { + datasets.add( + new Relation( + Collections.singletonList(internalVisitExpression(datasetCtx.tableSource())))); + } + } + + // Allow 1+ here; total count (including implicit upstream) validated during planning + if (datasets.isEmpty()) { + throw new SyntaxCheckException( + "Union command requires at least one dataset. Provided: " + datasets.size()); + } + + return new Union(datasets, maxout); + } + @Override public UnresolvedPlan visitRexCommand(OpenSearchPPLParser.RexCommandContext ctx) { UnresolvedExpression field = internalVisitExpression(ctx.rexExpr().field); diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index fd1c10fea9c..94904d3ef18 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -107,6 +107,7 @@ import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.ast.tree.Transpose; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.ast.tree.Union; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.ast.tree.Window; @@ -793,32 +794,37 @@ public String visitAppend(Append node, String context) { @Override public String visitMultisearch(Multisearch node, String context) { + return anonymizeSubsearchCommand("multisearch", node.getSubsearches()); + } + + @Override + public String visitUnion(Union node, String context) { + return anonymizeSubsearchCommand("union", node.getDatasets()); + } + + private String anonymizeSubsearchCommand(String commandName, List subsearches) { + String keywords = + "source|fields|where|stats|head|tail|sort|eval|rename|" + + commandName + + "|search|table|identifier|\\*\\*\\*"; List anonymizedSubsearches = new ArrayList<>(); - for (UnresolvedPlan subsearch : node.getSubsearches()) { + for (UnresolvedPlan subsearch : subsearches) { String anonymizedSubsearch = anonymizeData(subsearch); anonymizedSubsearch = "search " + anonymizedSubsearch; anonymizedSubsearch = anonymizedSubsearch - .replaceAll("\\bsource=\\w+", "source=table") // Replace table names after source= - .replaceAll( - "\\b(?!source|fields|where|stats|head|tail|sort|eval|rename|multisearch|search|table|identifier|\\*\\*\\*)\\w+(?=\\s*[<>=!])", - "identifier") // Replace field names before operators - .replaceAll( - "\\b(?!source|fields|where|stats|head|tail|sort|eval|rename|multisearch|search|table|identifier|\\*\\*\\*)\\w+(?=\\s*,)", - "identifier") // Replace field names before commas - .replaceAll( - "fields" - + " \\+\\s*\\b(?!source|fields|where|stats|head|tail|sort|eval|rename|multisearch|search|table|identifier|\\*\\*\\*)\\w+", - "fields + identifier") // Replace field names after 'fields +' + .replaceAll("\\bsource=\\w+", "source=table") + .replaceAll("\\b(?!" + keywords + ")\\w+(?=\\s*[<>=!])", "identifier") + .replaceAll("\\b(?!" + keywords + ")\\w+(?=\\s*,)", "identifier") + .replaceAll("fields \\+\\s*\\b(?!" + keywords + ")\\w+", "fields + identifier") .replaceAll( - "fields" - + " \\+\\s*identifier,\\s*\\b(?!source|fields|where|stats|head|tail|sort|eval|rename|multisearch|search|table|identifier|\\*\\*\\*)\\w+", - "fields + identifier,identifier"); // Handle multiple fields + "fields \\+\\s*identifier,\\s*\\b(?!" + keywords + ")\\w+", + "fields + identifier,identifier"); anonymizedSubsearches.add(StringUtils.format("[%s]", anonymizedSubsearch)); } - return StringUtils.format("| multisearch %s", String.join(" ", anonymizedSubsearches)); + return StringUtils.format("| %s %s", commandName, String.join(" ", anonymizedSubsearches)); } @Override diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLUnionTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLUnionTest.java new file mode 100644 index 00000000000..a16e0e6a6be --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLUnionTest.java @@ -0,0 +1,591 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.calcite; + +import com.google.common.collect.ImmutableList; +import java.sql.Timestamp; +import java.util.List; +import lombok.RequiredArgsConstructor; +import org.apache.calcite.DataContext; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.linq4j.Enumerable; +import org.apache.calcite.linq4j.Linq4j; +import org.apache.calcite.plan.RelTraitDef; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelProtoDataType; +import org.apache.calcite.schema.ScannableTable; +import org.apache.calcite.schema.Schema; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.schema.Statistic; +import org.apache.calcite.schema.Statistics; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.test.CalciteAssert; +import org.apache.calcite.tools.Frameworks; +import org.apache.calcite.tools.Programs; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.junit.Test; + +public class CalcitePPLUnionTest extends CalcitePPLAbstractTest { + + public CalcitePPLUnionTest() { + super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL); + } + + @Override + protected Frameworks.ConfigBuilder config(CalciteAssert.SchemaSpec... schemaSpecs) { + final SchemaPlus rootSchema = Frameworks.createRootSchema(true); + final SchemaPlus schema = CalciteAssert.addSchema(rootSchema, schemaSpecs); + + ImmutableList timeData1 = + ImmutableList.of( + new Object[] { + Timestamp.valueOf("2025-08-01 03:47:41"), + 8762, + "A", + Timestamp.valueOf("2025-08-01 03:47:41") + }, + new Object[] { + Timestamp.valueOf("2025-08-01 01:14:11"), + 9015, + "B", + Timestamp.valueOf("2025-08-01 01:14:11") + }, + new Object[] { + Timestamp.valueOf("2025-07-31 23:40:33"), + 8676, + "A", + Timestamp.valueOf("2025-07-31 23:40:33") + }, + new Object[] { + Timestamp.valueOf("2025-07-31 21:07:03"), + 8490, + "B", + Timestamp.valueOf("2025-07-31 21:07:03") + }); + + ImmutableList timeData2 = + ImmutableList.of( + new Object[] { + Timestamp.valueOf("2025-08-01 04:00:00"), + 2001, + "E", + Timestamp.valueOf("2025-08-01 04:00:00") + }, + new Object[] { + Timestamp.valueOf("2025-08-01 02:30:00"), + 2002, + "F", + Timestamp.valueOf("2025-08-01 02:30:00") + }, + new Object[] { + Timestamp.valueOf("2025-08-01 01:00:00"), + 2003, + "E", + Timestamp.valueOf("2025-08-01 01:00:00") + }, + new Object[] { + Timestamp.valueOf("2025-07-31 22:15:00"), + 2004, + "F", + Timestamp.valueOf("2025-07-31 22:15:00") + }); + + ImmutableList nonTimeData = + ImmutableList.of( + new Object[] {1001, "Product A", 100.0}, new Object[] {1002, "Product B", 200.0}); + + schema.add("TIME_DATA1", new TimeDataTable(timeData1)); + schema.add("TIME_DATA2", new TimeDataTable(timeData2)); + schema.add("NON_TIME_DATA", new NonTimeDataTable(nonTimeData)); + + return Frameworks.newConfigBuilder() + .parserConfig(SqlParser.Config.DEFAULT) + .defaultSchema(schema) + .traitDefs((List) null) + .programs(Programs.heuristicJoinOrder(Programs.RULE_SET, true, 2)); + } + + @Test + public void testBasicUnionTwoDatasets() { + String ppl = + "| union " + + "[search source=EMP | where DEPTNO = 10] " + + "[search source=EMP | where DEPTNO = 20]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `DEPTNO` = 10\n" + + "UNION ALL\n" + + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `DEPTNO` = 20"; + verifyPPLToSparkSQL(root, expectedSparkSql); + verifyResultCount(root, 8); + } + + @Test + public void testUnionThreeDatasets() { + String ppl = + "| union " + + "[search source=EMP | where DEPTNO = 10] " + + "[search source=EMP | where DEPTNO = 20] " + + "[search source=EMP | where DEPTNO = 30]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7])\n" + + " LogicalFilter(condition=[=($7, 30)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `DEPTNO` = 10\n" + + "UNION ALL\n" + + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `DEPTNO` = 20\n" + + "UNION ALL\n" + + "SELECT *\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `DEPTNO` = 30"; + verifyPPLToSparkSQL(root, expectedSparkSql); + verifyResultCount(root, 14); + } + + @Test + public void testUnionCrossIndicesSchemaDifference() { + String ppl = + "| union [search source=EMP | where DEPTNO = 10 | fields EMPNO, ENAME," + + " JOB] [search source=DEPT | where DEPTNO = 10 | fields DEPTNO, DNAME, LOC]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], DEPTNO=[null:TINYINT]," + + " DNAME=[null:VARCHAR(14)], LOC=[null:VARCHAR(13)])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], DEPTNO=[$0], DNAME=[$1], LOC=[$2])\n" + + " LogicalFilter(condition=[=($0, 10)])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `EMPNO`, `ENAME`, `JOB`, CAST(NULL AS TINYINT) `DEPTNO`, CAST(NULL AS STRING)" + + " `DNAME`, CAST(NULL AS STRING) `LOC`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `DEPTNO` = 10\n" + + "UNION ALL\n" + + "SELECT CAST(NULL AS SMALLINT) `EMPNO`, CAST(NULL AS STRING) `ENAME`, CAST(NULL AS" + + " STRING) `JOB`, `DEPTNO`, `DNAME`, `LOC`\n" + + "FROM `scott`.`DEPT`\n" + + "WHERE `DEPTNO` = 10"; + verifyPPLToSparkSQL(root, expectedSparkSql); + verifyResultCount(root, 4); + } + + @Test + public void testUnionWithStats() { + String ppl = + "| union " + + "[search source=EMP | where DEPTNO = 10 | eval type = \"accounting\"] " + + "[search source=EMP | where DEPTNO = 20 | eval type = \"research\"] " + + "| stats count by type"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalProject(count=[$1], type=[$0])\n" + + " LogicalAggregate(group=[{0}], count=[COUNT()])\n" + + " LogicalProject(type=[$8])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], type=['accounting':VARCHAR])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], type=['research':VARCHAR])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT COUNT(*) `count`, `type`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " 'accounting' `type`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `DEPTNO` = 10\n" + + "UNION ALL\n" + + "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " 'research' `type`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `DEPTNO` = 20) `t3`\n" + + "GROUP BY `type`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + verifyResultCount(root, 2); + } + + @Test + public void testUnionDirectTableNames() { + String ppl = "| union EMP, DEPT"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], DNAME=[null:VARCHAR(14)]," + + " LOC=[null:VARCHAR(13)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)]," + + " JOB=[null:VARCHAR(9)], MGR=[null:SMALLINT], HIREDATE=[null:DATE]," + + " SAL=[null:DECIMAL(7, 2)], COMM=[null:DECIMAL(7, 2)], DEPTNO=[CAST($0):TINYINT]," + + " DNAME=[$1], LOC=[$2])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + } + + @Test + public void testUnionNonStreamingModeAppend() { + String ppl = + "| union " + + "[search source=EMP | where DEPTNO = 10 | fields EMPNO, ENAME] " + + "[search source=NON_TIME_DATA | fields id, name]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], id=[null:INTEGER], name=[null:VARCHAR])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)], id=[$0]," + + " name=[$1])\n" + + " LogicalTableScan(table=[[scott, NON_TIME_DATA]])\n"; + verifyLogical(root, expectedLogical); + } + + @Test + public void testUnionWithMaxout() { + String ppl = + "| union maxout=5 " + + "[search source=EMP | where DEPTNO = 10] " + + "[search source=EMP | where DEPTNO = 20]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalSystemLimit(fetch=[5], type=[SUBSEARCH_MAXOUT])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + } + + @Test + public void testUnionWithIdenticalSchemasAndFieldProjection() { + String ppl = + "| union " + + "[search source=EMP | where DEPTNO = 10 | fields EMPNO, ENAME], " + + "[search source=EMP | where DEPTNO = 20 | fields EMPNO, ENAME]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 8); + } + + @Test + public void testUnionAsFirstCommand() { + String ppl = + "| union " + + "[search source=EMP | where DEPTNO = 10 | fields EMPNO, ENAME] " + + "[search source=EMP | where DEPTNO = 20 | fields EMPNO, ENAME]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 8); + } + + @Test + public void testUnionWithCompletelyDifferentSchemas() { + String ppl = + "| union " + + "[search source=EMP | fields EMPNO, ENAME] " + + "[search source=DEPT | fields DEPTNO, DNAME]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], DEPTNO=[null:TINYINT]," + + " DNAME=[null:VARCHAR(14)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[null:SMALLINT], ENAME=[null:VARCHAR(10)], DEPTNO=[$0]," + + " DNAME=[$1])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 18); + } + + @Test + public void testUnionWithPartialSchemaOverlap() { + String ppl = + "| union " + + "[search source=EMP | fields EMPNO, ENAME, JOB] " + + "[search source=EMP | fields EMPNO, ENAME, SAL]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], SAL=[null:DECIMAL(7, 2)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[null:VARCHAR(9)], SAL=[$5])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 28); + } + + @Test + public void testUnionWithFilteredSubsearches() { + String ppl = + "| union " + + "[search source=EMP | where SAL > 2000 | fields EMPNO, ENAME] " + + "[search source=EMP | where DEPTNO = 10 | fields EMPNO, ENAME]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[>($5, 2000)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + } + + @Test + public void testUnionPreservesDuplicateRows() { + String ppl = + "| union " + + "[search source=EMP | where EMPNO = 7369 | fields EMPNO, ENAME] " + + "[search source=EMP | where EMPNO = 7369 | fields EMPNO, ENAME]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($0, 7369)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($0, 7369)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 2); + } + + @Test + public void testUnionWithEmptyDataset() { + String ppl = + "| union " + + "[search source=EMP | where DEPTNO = 10 | fields EMPNO, ENAME] " + + "[search source=EMP | where DEPTNO = 99 | fields EMPNO, ENAME]"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 99)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 3); + } + + @Test + public void testUnionFollowedByAggregation() { + String ppl = + "| union " + + "[search source=EMP | where DEPTNO = 10 | fields EMPNO, ENAME], " + + "[search source=EMP | where DEPTNO = 20 | fields EMPNO, ENAME] " + + "| stats count()"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalAggregate(group=[{}], count()=[COUNT()])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 1); + } + + @Test + public void testUnionFollowedBySort() { + String ppl = + "| union " + + "[search source=EMP | where DEPTNO = 10 | fields EMPNO, ENAME] " + + "[search source=EMP | where DEPTNO = 20 | fields EMPNO, ENAME] " + + "| sort ENAME"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalSort(sort0=[$1], dir0=[ASC-nulls-first])\n" + + " LogicalUnion(all=[true])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 10)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1])\n" + + " LogicalFilter(condition=[=($7, 20)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + verifyResultCount(root, 8); + } + + @RequiredArgsConstructor + static class TimeDataTable implements ScannableTable { + private final ImmutableList rows; + + protected final RelProtoDataType protoRowType = + factory -> + factory + .builder() + .add("timestamp", SqlTypeName.TIMESTAMP) + .nullable(true) + .add("value", SqlTypeName.INTEGER) + .nullable(true) + .add("category", SqlTypeName.VARCHAR) + .nullable(true) + .add("@timestamp", SqlTypeName.TIMESTAMP) + .nullable(true) + .build(); + + @Override + public Enumerable<@Nullable Object[]> scan(DataContext root) { + return Linq4j.asEnumerable(rows); + } + + @Override + public RelDataType getRowType(RelDataTypeFactory typeFactory) { + return protoRowType.apply(typeFactory); + } + + @Override + public Statistic getStatistic() { + return Statistics.of(0d, ImmutableList.of(), RelCollations.createSingleton(0)); + } + + @Override + public Schema.TableType getJdbcTableType() { + return Schema.TableType.TABLE; + } + + @Override + public boolean isRolledUp(String column) { + return false; + } + + @Override + public boolean rolledUpColumnValidInsideAgg( + String column, + SqlCall call, + @Nullable SqlNode parent, + @Nullable CalciteConnectionConfig config) { + return false; + } + } + + @RequiredArgsConstructor + static class NonTimeDataTable implements ScannableTable { + private final ImmutableList rows; + + protected final RelProtoDataType protoRowType = + factory -> + factory + .builder() + .add("id", SqlTypeName.INTEGER) + .nullable(true) + .add("name", SqlTypeName.VARCHAR) + .nullable(true) + .add("value", SqlTypeName.DOUBLE) + .nullable(true) + .build(); + + @Override + public Enumerable<@Nullable Object[]> scan(DataContext root) { + return Linq4j.asEnumerable(rows); + } + + @Override + public RelDataType getRowType(RelDataTypeFactory typeFactory) { + return protoRowType.apply(typeFactory); + } + + @Override + public Statistic getStatistic() { + return Statistics.of(0d, ImmutableList.of(), RelCollations.createSingleton(0)); + } + + @Override + public Schema.TableType getJdbcTableType() { + return Schema.TableType.TABLE; + } + + @Override + public boolean isRolledUp(String column) { + return false; + } + + @Override + public boolean rolledUpColumnValidInsideAgg( + String column, + SqlCall call, + @Nullable SqlNode parent, + @Nullable CalciteConnectionConfig config) { + return false; + } + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index a53e4a5d8dd..e7f3f986752 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -1828,4 +1828,39 @@ public void testEmptyPipeAndTrailingPipeTogether() { public void testMalformedPipeProducesSyntaxError() { plan("source=t | invalidCmd |"); } + + @Test + public void testUnionWithSubsearches() { + plan("| union [search source=t1 | where age > 30] " + "[search source=t2 | where age < 20]"); + } + + @Test + public void testUnionWithDirectTableNames() { + plan("| union t1, t2"); + } + + @Test + public void testUnionWithDateSuffixIndex() { + plan("| union logs-2024.01.01, logs-2024.01.02"); + } + + @Test + public void testUnionWithDottedCatalogPath() { + plan("| union catalog.my_index, catalog.other_index"); + } + + @Test + public void testUnionMidPipeline() { + plan("source=t1 | union t2, t3"); + } + + @Test + public void testUnionWithMaxoutOption() { + plan("| union maxout=500 t1, t2"); + } + + @Test + public void testMaxoutAsFieldName() { + plan("source=t | eval maxout = 1"); + } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index de230d208bb..ed32892b3b4 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -1160,4 +1160,28 @@ public void testMvexpandCommandWithLimit() { "source=table | mvexpand identifier limit=***", anonymize("source=t | mvexpand skills limit=5")); } + + @Test + public void testUnion() { + assertEquals( + "| union [search source=table | where identifier < ***] [search source=table |" + + " where identifier >= ***]", + anonymize( + "| union [search source=accounts | where age < 30] [search source=accounts" + + " | where age >= 30]")); + + assertEquals( + "| union [search source=table | where identifier > ***] [search source=table |" + + " where identifier = ***]", + anonymize( + "| union [search source=accounts | where balance > 20000] [search" + + " source=accounts | where state = 'CA']")); + + assertEquals( + "| union [search source=table | fields + identifier,identifier] [search" + + " source=table | where identifier = ***]", + anonymize( + "| union [search source=accounts | fields firstname, lastname] [search" + + " source=accounts | where age = 25]")); + } }