Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Transpose;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Union;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.ast.tree.Window;
Expand Down Expand Up @@ -897,6 +898,11 @@ public LogicalPlan visitMultisearch(Multisearch node, AnalysisContext context) {
throw getOnlyForCalciteException("Multisearch");
}

@Override
public LogicalPlan visitUnion(Union node, AnalysisContext context) {
throw getOnlyForCalciteException("Union");
}

private LogicalSort buildSort(
LogicalPlan child, AnalysisContext context, Integer count, List<Field> sortFields) {
ExpressionReferenceOptimizer optimizer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Transpose;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Union;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.ast.tree.Window;

Expand Down Expand Up @@ -472,6 +473,10 @@ public T visitMultisearch(Multisearch node, C context) {
return visitChildren(node, context);
}

public T visitUnion(Union node, C context) {
return visitChildren(node, context);
}

public T visitAddTotals(AddTotals node, C context) {
return visitChildren(node, context);
}
Expand Down
44 changes: 44 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Union.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/** Logical plan node for Union operation. Combines results from multiple datasets (UNION ALL). */
@Getter
@ToString
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
@AllArgsConstructor
public class Union extends UnresolvedPlan {
private final List<UnresolvedPlan> datasets;

private Integer maxout;

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
List<UnresolvedPlan> newDatasets =
ImmutableList.<UnresolvedPlan>builder().add(child).addAll(datasets).build();
return new Union(newDatasets, maxout);
}

@Override
public List<? extends UnresolvedPlan> getChild() {
return datasets;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitUnion(this, context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.Trendline.TrendlineType;
import org.opensearch.sql.ast.tree.Union;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.ast.tree.Window;
Expand Down Expand Up @@ -2607,6 +2608,40 @@ private String findTimestampField(RelDataType rowType) {
return null;
}

@Override
public RelNode visitUnion(Union node, CalcitePlanContext context) {
List<RelNode> inputNodes = new ArrayList<>();

for (UnresolvedPlan dataset : node.getDatasets()) {
UnresolvedPlan prunedDataset = dataset.accept(new EmptySourcePropagateVisitor(), null);
prunedDataset.accept(this, context);
inputNodes.add(context.relBuilder.build());
}

if (inputNodes.size() < 2) {
throw new IllegalArgumentException(
"Union command requires at least two datasets. Provided: " + inputNodes.size());
}

List<RelNode> unifiedInputs =
SchemaUnifier.buildUnifiedSchemaWithTypeCoercion(inputNodes, context);

for (RelNode input : unifiedInputs) {
context.relBuilder.push(input);
}
context.relBuilder.union(true, unifiedInputs.size()); // true = UNION ALL

if (node.getMaxout() != null) {
context.relBuilder.push(
LogicalSystemLimit.create(
LogicalSystemLimit.SystemLimitType.SUBSEARCH_MAXOUT,
context.relBuilder.build(),
context.relBuilder.literal(node.getMaxout())));
}

return context.relBuilder.peek();
}

/*
* Unsupported Commands of PPL with Calcite for OpenSearch 3.0.0-beta
*/
Expand Down
242 changes: 240 additions & 2 deletions core/src/main/java/org/opensearch/sql/calcite/SchemaUnifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.type.SqlTypeName;

/**
* Utility class for unifying schemas across multiple RelNodes. Throws an exception when type
* conflicts are detected.
* Utility class for unifying schemas across multiple RelNodes. Supports two strategies:
*
* <ul>
* <li>Conflict resolution (multisearch): throws on type mismatch, fills missing fields with NULL
* <li>Type coercion (union): widens compatible types (e.g. INTEGER→BIGINT), falls back to VARCHAR
* for incompatible types, fills missing fields with NULL
* </ul>
*/
public class SchemaUnifier {

Expand Down Expand Up @@ -147,4 +153,236 @@ RelDataType getType() {
return type;
}
}

/**
* Builds unified schema with type coercion for UNION command. Coerces compatible types to a
* common supertype (e.g. int+float→float), falls back to VARCHAR for incompatible types, and
* fills missing fields with NULL.
*/
public static List<RelNode> buildUnifiedSchemaWithTypeCoercion(
List<RelNode> inputs, CalcitePlanContext context) {
if (inputs.isEmpty() || inputs.size() == 1) {
return inputs;
}

List<RelNode> coercedInputs = coerceUnionTypes(inputs, context);
return unifySchemasForUnion(coercedInputs, context);
}

/**
* Aligns schemas by projecting NULL for missing fields and CAST for type mismatches. Uses
* force=true to clear collation traits and prevent EnumerableMergeUnion cast exception.
*/
private static List<RelNode> unifySchemasForUnion(
List<RelNode> inputs, CalcitePlanContext context) {
List<SchemaField> unifiedSchema = buildUnifiedSchemaForUnion(inputs);
List<String> fieldNames =
unifiedSchema.stream().map(SchemaField::getName).collect(Collectors.toList());

List<RelNode> projectedNodes = new ArrayList<>();
for (RelNode node : inputs) {
List<RexNode> projection = buildProjectionForUnion(node, unifiedSchema, context);
RelNode projectedNode =
context.relBuilder.push(node).project(projection, fieldNames, true).build();
projectedNodes.add(projectedNode);
}
return projectedNodes;
}

private static List<SchemaField> buildUnifiedSchemaForUnion(List<RelNode> nodes) {
List<SchemaField> schema = new ArrayList<>();
Map<String, RelDataType> seenFields = new HashMap<>();

for (RelNode node : nodes) {
for (RelDataTypeField field : node.getRowType().getFieldList()) {
if (!seenFields.containsKey(field.getName())) {
schema.add(new SchemaField(field.getName(), field.getType()));
seenFields.put(field.getName(), field.getType());
}
}
}
return schema;
}

private static List<RexNode> buildProjectionForUnion(
RelNode node, List<SchemaField> unifiedSchema, CalcitePlanContext context) {
Map<String, RelDataTypeField> nodeFieldMap =
node.getRowType().getFieldList().stream()
.collect(Collectors.toMap(RelDataTypeField::getName, field -> field));

List<RexNode> projection = new ArrayList<>();
for (SchemaField schemaField : unifiedSchema) {
RelDataTypeField nodeField = nodeFieldMap.get(schemaField.getName());

if (nodeField != null) {
RexNode fieldRef = context.rexBuilder.makeInputRef(node, nodeField.getIndex());
if (!nodeField.getType().equals(schemaField.getType())) {
projection.add(context.rexBuilder.makeCast(schemaField.getType(), fieldRef));
} else {
projection.add(fieldRef);
}
} else {
projection.add(context.rexBuilder.makeNullLiteral(schemaField.getType()));
}
}
return projection;
}

/** Casts fields to their common supertypes across all inputs when types differ. */
private static List<RelNode> coerceUnionTypes(List<RelNode> inputs, CalcitePlanContext context) {
Map<String, List<SqlTypeName>> fieldTypeMap = new HashMap<>();
for (RelNode input : inputs) {
for (RelDataTypeField field : input.getRowType().getFieldList()) {
String fieldName = field.getName();
SqlTypeName typeName = field.getType().getSqlTypeName();
if (typeName != null) {
fieldTypeMap.computeIfAbsent(fieldName, k -> new ArrayList<>()).add(typeName);
}
}
}

Map<String, SqlTypeName> targetTypeMap = new HashMap<>();
for (Map.Entry<String, List<SqlTypeName>> entry : fieldTypeMap.entrySet()) {
String fieldName = entry.getKey();
List<SqlTypeName> types = entry.getValue();

SqlTypeName commonType = types.getFirst();
for (int i = 1; i < types.size(); i++) {
commonType = findCommonTypeForUnion(commonType, types.get(i));
}
targetTypeMap.put(fieldName, commonType);
}

boolean needsCoercion = false;
for (RelNode input : inputs) {
for (RelDataTypeField field : input.getRowType().getFieldList()) {
SqlTypeName targetType = targetTypeMap.get(field.getName());
if (targetType != null && field.getType().getSqlTypeName() != targetType) {
needsCoercion = true;
break;
}
}
if (needsCoercion) break;
}

if (!needsCoercion) {
return inputs;
}

List<RelNode> coercedInputs = new ArrayList<>();
for (RelNode input : inputs) {
List<RexNode> projections = new ArrayList<>();
List<String> projectionNames = new ArrayList<>();
boolean needsProjection = false;

for (RelDataTypeField field : input.getRowType().getFieldList()) {
String fieldName = field.getName();
SqlTypeName currentType = field.getType().getSqlTypeName();
SqlTypeName targetType = targetTypeMap.get(fieldName);

RexNode fieldRef = context.rexBuilder.makeInputRef(input, field.getIndex());

if (currentType != targetType && targetType != null) {
projections.add(context.relBuilder.cast(fieldRef, targetType));
needsProjection = true;
} else {
projections.add(fieldRef);
}
projectionNames.add(fieldName);
}

if (needsProjection) {
context.relBuilder.push(input);
context.relBuilder.project(projections, projectionNames, true);
coercedInputs.add(context.relBuilder.build());
} else {
coercedInputs.add(input);
}
}

return coercedInputs;
}

/**
* Returns the wider type for two SqlTypeNames. Within the same family, returns the wider type
* (e.g. INTEGER+BIGINT-->BIGINT). Across families, falls back to VARCHAR.
*/
private static SqlTypeName findCommonTypeForUnion(SqlTypeName type1, SqlTypeName type2) {
if (type1 == type2) {
return type1;
}

if (type1 == SqlTypeName.NULL) {
return type2;
}
if (type2 == SqlTypeName.NULL) {
return type1;
}

if (isNumericTypeForUnion(type1) && isNumericTypeForUnion(type2)) {
return getWiderNumericTypeForUnion(type1, type2);
}

if (isStringTypeForUnion(type1) && isStringTypeForUnion(type2)) {
return SqlTypeName.VARCHAR;
}

if (isTemporalTypeForUnion(type1) && isTemporalTypeForUnion(type2)) {
return getWiderTemporalTypeForUnion(type1, type2);
}

return SqlTypeName.VARCHAR;
}

private static boolean isNumericTypeForUnion(SqlTypeName typeName) {
return typeName == SqlTypeName.TINYINT
|| typeName == SqlTypeName.SMALLINT
|| typeName == SqlTypeName.INTEGER
|| typeName == SqlTypeName.BIGINT
|| typeName == SqlTypeName.FLOAT
|| typeName == SqlTypeName.REAL
|| typeName == SqlTypeName.DOUBLE
|| typeName == SqlTypeName.DECIMAL;
}

private static boolean isStringTypeForUnion(SqlTypeName typeName) {
return typeName == SqlTypeName.CHAR || typeName == SqlTypeName.VARCHAR;
}

private static boolean isTemporalTypeForUnion(SqlTypeName typeName) {
return typeName == SqlTypeName.DATE
|| typeName == SqlTypeName.TIMESTAMP
|| typeName == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
}

private static SqlTypeName getWiderNumericTypeForUnion(SqlTypeName type1, SqlTypeName type2) {
int rank1 = getNumericTypeRankForUnion(type1);
int rank2 = getNumericTypeRankForUnion(type2);
return rank1 >= rank2 ? type1 : type2;
}

private static int getNumericTypeRankForUnion(SqlTypeName typeName) {
return switch (typeName) {
case TINYINT -> 1;
case SMALLINT -> 2;
case INTEGER -> 3;
case BIGINT -> 4;
case DECIMAL -> 5;
case REAL -> 6;
case FLOAT -> 7;
case DOUBLE -> 8;
default -> 0;
};
}

private static SqlTypeName getWiderTemporalTypeForUnion(SqlTypeName type1, SqlTypeName type2) {
if (type1 == SqlTypeName.TIMESTAMP || type2 == SqlTypeName.TIMESTAMP) {
return SqlTypeName.TIMESTAMP;
}
if (type1 == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE
|| type2 == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE) {
return SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
}
return SqlTypeName.DATE;
}
}
1 change: 1 addition & 0 deletions docs/category.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"user/ppl/cmd/top.md",
"user/ppl/cmd/trendline.md",
"user/ppl/cmd/transpose.md",
"user/ppl/cmd/union.md",
"user/ppl/cmd/where.md",
"user/ppl/functions/aggregations.md",
"user/ppl/functions/collection.md",
Expand Down
Loading
Loading