diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index ee96d6d83f90e..4e866c57fd35a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -702,6 +702,12 @@ ], "sqlState" : "0A000" }, + "CLUSTER_BY_EXPRESSION_INCORRECT_COLUMN_REFERENCE": { + "message" : [ + "CLUSTER BY expression <expressionType> has either no column reference, or a column reference in an unsupported argument position." + ], + "sqlState" : "42000" + }, "CLUSTERING_COLUMNS_MISMATCH" : { "message" : [ "Specified clustering does not match that of the existing table <tableName>.", diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index a905905c098e8..faeb01c26277f 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -536,8 +536,17 @@ replaceTableHeader : (CREATE OR)? 
REPLACE TABLE identifierReference ; +expressionOrMultipartIdentifier + : expression + | multipartIdentifier + ; + +expressionOrMultipartIdentifierList + : expressionOrMultipartIdentifier (COMMA expressionOrMultipartIdentifier)* + ; + clusterBySpec - : CLUSTER BY LEFT_PAREN multipartIdentifierList RIGHT_PAREN + : CLUSTER BY LEFT_PAREN expressionOrMultipartIdentifierList RIGHT_PAREN ; bucketSpec diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java index 2cf2414f8052c..1f2f7acc18481 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableChange.java @@ -19,11 +19,13 @@ import java.util.Arrays; import java.util.Objects; +import java.util.Optional; import javax.annotation.Nullable; import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.catalog.constraints.Constraint; import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.DataType; /** @@ -277,8 +279,10 @@ static TableChange deleteColumn(String[] fieldNames, Boolean ifExists) { * field names. * @return a TableChange for this assignment */ - static TableChange clusterBy(NamedReference[] clusteringColumns) { - return new ClusterBy(clusteringColumns); + static TableChange clusterBy( + NamedReference[] clusteringColumns, + Optional<Transform>[] transforms) { + return new ClusterBy(clusteringColumns, transforms); } /** @@ -873,24 +877,30 @@ public String toString() { /** A TableChange to alter clustering columns for a table. 
*/ final class ClusterBy implements TableChange { private final NamedReference[] clusteringColumns; + private final Optional<Transform>[] transforms; - private ClusterBy(NamedReference[] clusteringColumns) { + private ClusterBy( + NamedReference[] clusteringColumns, + Optional<Transform>[] transforms) { this.clusteringColumns = clusteringColumns; + this.transforms = transforms; } public NamedReference[] clusteringColumns() { return clusteringColumns; } + public Optional<Transform>[] transforms() { return transforms; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ClusterBy that = (ClusterBy) o; - return Arrays.equals(clusteringColumns, that.clusteringColumns()); + return Arrays.equals(clusteringColumns, that.clusteringColumns()) + && Arrays.equals(transforms, that.transforms); } @Override public int hashCode() { - return Arrays.hashCode(clusteringColumns); + return Objects.hash(Arrays.hashCode(clusteringColumns), Arrays.hashCode(transforms)); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index eaee334a01cbd..7f42c9bbe7b7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -36,15 +36,17 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CurrentUserContext, FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NormalizeableRelation, Resolver, SchemaBinding, SchemaCompensation, SchemaEvolution, SchemaTypeEvolution, SchemaUnsupported, UnresolvedLeafNode, ViewSchemaMode} +import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, 
NormalizeableRelation, Resolver, SchemaBinding, SchemaCompensation, SchemaEvolution, SchemaTypeEvolution, SchemaUnsupported, UnresolvedAttribute, UnresolvedFunction, UnresolvedLeafNode, ViewSchemaMode} import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, Expression, ExprId, Literal} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.connector.catalog.CatalogManager -import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, NamedReference, Transform} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.ClusterByHelper +import org.apache.spark.sql.connector.expressions.{ClusterByTransform, ClusteringColumnTransform, FieldReference, LiteralValue, NamedReference, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -270,11 +272,38 @@ case class CatalogTablePartition( * A container for clustering information. * * @param columnNames the names of the columns used for clustering. + * @param clusteringColumnTransforms per-column transforms for expression-based clustering. + * When non-empty, each element corresponds to a column in + * columnNames: None means a plain column reference, + * Some(transform) means an expression like UPPER(col). + * An empty Seq means no transforms on any columns. 
*/ -case class ClusterBySpec(columnNames: Seq[NamedReference]) { +case class ClusterBySpec( + columnNames: Seq[NamedReference], + clusteringColumnTransforms: Seq[Option[Transform]] = Seq.empty) { override def toString: String = toJson - def toJson: String = ClusterBySpec.mapper.writeValueAsString(columnNames.map(_.fieldNames)) + def toJson: String = toColumnNames + + def toColumnNames: String = { + val entries: Seq[Seq[String]] = if (clusteringColumnTransforms.isEmpty) { + columnNames.map(_.fieldNames().toSeq) + } else { + columnNames.zip(clusteringColumnTransforms).map { + case (colName, None) => colName.fieldNames().toSeq + case (colName, Some(transform)) => + val args = transform.arguments().map { + case n: NamedReference => n.fieldNames().map(QuotingUtils.quoteIfNeeded).mkString(".") + case LiteralValue(value, dataType) => + Literal(value, dataType).sql + case other => throw new IllegalStateException( + s"Unexpected argument type in CLUSTER BY expression: ${other.getClass}") + } + Seq(s"${transform.name()}(${args.mkString(",")})") + } + } + ClusterBySpec.mapper.writeValueAsString(entries) + } } object ClusterBySpec { @@ -290,7 +319,104 @@ object ClusterBySpec { * Converts the clustering column property to a ClusterBySpec. */ def fromProperty(columns: String): ClusterBySpec = { - ClusterBySpec(mapper.readValue[Seq[Seq[String]]](columns).map(FieldReference(_))) + ClusterBySpec.fromColumnEntries(mapper.readValue[Seq[Seq[String]]](columns)) + } + + /** + * Constructs a [[ClusterBySpec]] from the stored column entries (each a Seq[String]). + * An entry is either a multi-part column name or a single-element Seq containing an + * expression string like "variant_get(col,'$.foo','STRING')". 
+ */ + def fromColumnEntries(entries: Seq[Seq[String]]): ClusterBySpec = { + val parsedCols: Seq[(NamedReference, Option[Transform])] = entries.map { + case parts if parts.length == 1 && mightBeExpression(parts.head) => + // Expression form: "funcName(col, arg1, arg2, ...)" + CatalystSqlParser.parseExpression(parts.head) match { + case u: UnresolvedFunction => + val transform: Transform = new ClusteringColumnTransform( + u.nameParts.mkString("."), + u.children.map { + case a: UnresolvedAttribute => + FieldReference(a.nameParts.map(QuotingUtils.quoteIfNeeded)) + case l: Literal => LiteralValue(l.value, l.dataType) + case other => throw new IllegalStateException( + s"Unexpected argument type in CLUSTER BY expression: ${other.getClass}") + }.toArray) + val colRef = transform.arguments().collectFirst { + case f: FieldReference => f + }.getOrElse(throw new IllegalStateException( + "CLUSTER BY expression must contain exactly one column reference")) + (colRef, Some(transform)) + case other => + // Plain multi-part column name that had a false positive with + // mightBeExpression. + (FieldReference(parts), None) + } + case parts => + // Plain multi-part column name + (FieldReference(parts), None) + } + val (colNames, transforms) = parsedCols.unzip + val transformsSeq = if (transforms.forall(_.isEmpty)) Seq.empty else transforms + ClusterBySpec(colNames, transformsSeq) + } + + /** + * Returns true if the string looks like a function call expression (contains parentheses), + * rather than a plain column identifier. + */ + private def mightBeExpression(s: String): Boolean = s.contains("(") + + def fromExpressions( + parsedCols: Seq[Either[Expression, Seq[String]]]): ClusterBySpec = { + val (clusteringColumnNames, clusteringColumnExpressions) = parsedCols.map { + case Left(e) => + e match { + // A bare column reference parsed as an expression - treat as plain column. 
+ case a: UnresolvedAttribute => + (FieldReference(a.nameParts), None) + case u: UnresolvedFunction => + val transform = new ClusteringColumnTransform( + u.nameParts.mkString("."), + u.children.map { + case a: UnresolvedAttribute => FieldReference(a.nameParts) + case l: Literal => LiteralValue(l.value, l.dataType) + case _ => throw new IllegalStateException( + "Unsupported expression argument in CLUSTER BY transform") + }.toArray) + val transformName = u.nameParts.mkString(".") + val refs = transform.arguments().collect { + case f: FieldReference => f + } + if (refs.isEmpty) { + throw new AnalysisException( + errorClass = "CLUSTER_BY_EXPRESSION_INCORRECT_COLUMN_REFERENCE", + messageParameters = Map("expressionType" -> transformName)) + } + if (refs.length != 1) { + throw new AnalysisException( + errorClass = "CLUSTER_BY_EXPRESSION_INCORRECT_COLUMN_REFERENCE", + messageParameters = Map("expressionType" -> transformName)) + } + if (!transform.arguments().head.isInstanceOf[FieldReference]) { + throw new AnalysisException( + errorClass = "CLUSTER_BY_EXPRESSION_INCORRECT_COLUMN_REFERENCE", + messageParameters = Map( + "expressionType" -> transformName)) + } + (FieldReference(refs.head.fieldNames.toIndexedSeq), Some(transform)) + case _ => throw new IllegalStateException( + "Unsupported expression in CLUSTER BY: only function calls are supported") + } + case Right(names) => (FieldReference(names), None) + }.unzip + // If there are no transforms at all (all plain columns), use empty Seq for backward compat + val transformsSeq = if (clusteringColumnExpressions.forall(_.isEmpty)) { + Seq.empty + } else { + clusteringColumnExpressions + } + ClusterBySpec(clusteringColumnNames, transformsSeq) } /** @@ -339,12 +465,17 @@ object ClusterBySpec { normalizedColumns.map(_.toString), resolver) - ClusterBySpec(normalizedColumns) + ClusterBySpec(normalizedColumns, clusterBySpec.clusteringColumnTransforms) } def extractClusterBySpec(transforms: Seq[Transform]): Option[ClusterBySpec] = 
{ transforms.collectFirst { - case ClusterByTransform(columnNames) => ClusterBySpec(columnNames) + case ct: ClusterByTransform => + if (ct.transforms.nonEmpty) { + ClusterBySpec(ct.columnNames, ct.toClusteringColumnTransforms) + } else { + ClusterBySpec(ct.columnNames) + } } } @@ -353,7 +484,7 @@ object ClusterBySpec { clusterBySpec: ClusterBySpec, resolver: Resolver): ClusterByTransform = { val normalizedClusterBySpec = normalizeClusterBySpec(schema, clusterBySpec, resolver) - ClusterByTransform(normalizedClusterBySpec.columnNames) + new ClusterByHelper(normalizedClusterBySpec).asTransform.asInstanceOf[ClusterByTransform] } def fromColumnNames(names: Seq[String]): ClusterBySpec = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 753ed76fe16b5..6f5061c3fc524 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -4741,13 +4741,26 @@ class AstBuilder extends DataTypeAstBuilder }) } + override def visitExpressionOrMultipartIdentifier( + ctx: ExpressionOrMultipartIdentifierContext): Either[Expression, Seq[String]] = + withOrigin(ctx) { + if (ctx.expression() != null) { + scala.util.Left(expression(ctx.expression())) + } else { + scala.util.Right(typedVisit[Seq[String]](ctx.multipartIdentifier())) + } + } + /** * Create a [[ClusterBySpec]]. 
*/ override def visitClusterBySpec(ctx: ClusterBySpecContext): ClusterBySpec = withOrigin(ctx) { - val columnNames = ctx.multipartIdentifierList.multipartIdentifier.asScala - .map(typedVisit[Seq[String]]).map(FieldReference(_)).toSeq - ClusterBySpec(columnNames) + val columnReferences = + ctx.expressionOrMultipartIdentifierList.expressionOrMultipartIdentifier + .asScala + .map(visitExpressionOrMultipartIdentifier) + .toSeq + ClusterBySpec.fromExpressions(columnReferences) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala index 68b3573ce5ce7..d1650d30ad124 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, Expression, T import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.connector.catalog.{DefaultValue, TableCatalog, TableChange} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper +import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.DataType import org.apache.spark.util.ArrayImplicits._ @@ -309,9 +310,19 @@ case class AlterColumns( case class AlterTableClusterBy( table: LogicalPlan, clusterBySpec: Option[ClusterBySpec]) extends AlterTableCommand { override def changes: Seq[TableChange] = { + val clusterByTransforms = clusterBySpec.map { spec => + if (spec.clusteringColumnTransforms.nonEmpty) { + spec.clusteringColumnTransforms.map { + case None => java.util.Optional.empty[Transform]() + case Some(transform) => java.util.Optional.of[Transform](transform) + }.toArray + } else { + 
spec.columnNames.map(_ => java.util.Optional.empty[Transform]()).toArray + } + }.getOrElse(Array.empty[java.util.Optional[Transform]]) Seq(TableChange.clusterBy(clusterBySpec .map(_.columnNames.toArray) // CLUSTER BY (col1, col2, ...) - .getOrElse(Array.empty))) + .getOrElse(Array.empty), clusterByTransforms)) } protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index cf6052009c927..3a5778644499c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, quoteNameParts, QuotingUtils} -import org.apache.spark.sql.connector.expressions.{BucketTransform, ClusterByTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform} +import org.apache.spark.sql.connector.expressions.{BucketTransform, ClusterByColumnTransform, ClusterByTransform, FieldReference, IdentityTransform, LiteralValue, LogicalExpressions, NamedReference, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.StructType import org.apache.spark.util.ArrayImplicits._ @@ -56,7 +56,17 @@ private[sql] object CatalogV2Implicits { } implicit class ClusterByHelper(spec: ClusterBySpec) { - def asTransform: Transform = clusterBy(spec.columnNames.toArray) + def asTransform: Transform = { + val transforms = spec.clusteringColumnTransforms.zipWithIndex.collect { + case (Some(t), colIdx) => + 
ClusterByColumnTransform( + columnIndex = colIdx, + argumentIndex = t.arguments().indexWhere(_.isInstanceOf[NamedReference]), + function = t.name(), + arguments = t.arguments().collect { case l: LiteralValue[_] => l }.toSeq) + } + ClusterByTransform(spec.columnNames, transforms) + } } implicit class TransformHelper(transforms: Seq[Transform]) { @@ -80,12 +90,16 @@ private[sql] object CatalogV2Implicits { sortCol.map(_.fieldNames.mkString(".")))) } - case ClusterByTransform(columnNames) => + case ct @ ClusterByTransform(columnNames) => if (clusterBySpec.nonEmpty) { // AstBuilder guarantees that it only passes down one ClusterByTransform. throw SparkException.internalError("Cannot have multiple cluster by transforms.") } - clusterBySpec = Some(ClusterBySpec(columnNames)) + clusterBySpec = Some(ct match { + case c: ClusterByTransform if c.transforms.nonEmpty => + ClusterBySpec(columnNames, c.toClusteringColumnTransforms) + case _ => ClusterBySpec(columnNames) + }) case transform => throw QueryExecutionErrors.unsupportedPartitionTransformError(transform) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index b29d0b3eabe56..debc91dc52755 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -164,7 +164,7 @@ private[sql] object CatalogV2Util { val clusterByProp = ClusterBySpec.toProperty( schema, - ClusterBySpec(clusterBy.clusteringColumns.toIndexedSeq), + clusterBySpecFromChange(clusterBy), conf.resolver) newProperties.put(clusterByProp._1, clusterByProp._2) @@ -190,7 +190,7 @@ private[sql] object CatalogV2Util { clusterByOpt.foreach { clusterBy => newPartitioning = partitioning.map { case _: ClusterByTransform => ClusterBySpec.extractClusterByTransform( - schema, 
ClusterBySpec(clusterBy.clusteringColumns.toIndexedSeq), conf.resolver) + schema, clusterBySpecFromChange(clusterBy), conf.resolver) case other => other } } @@ -198,6 +198,15 @@ private[sql] object CatalogV2Util { newPartitioning } + /** Construct a [[ClusterBySpec]] from a [[ClusterBy]] table change, including transforms. */ + private def clusterBySpecFromChange(clusterBy: ClusterBy): ClusterBySpec = { + val transforms = clusterBy.transforms().map { tOpt => + if (tOpt.isPresent) Some(tOpt.get()) else None + }.toIndexedSeq + val transformsSeq = if (transforms.forall(_.isEmpty)) Seq.empty else transforms + ClusterBySpec(clusterBy.clusteringColumns.toIndexedSeq, transformsSeq) + } + /** * Apply schema changes to a schema and return the result. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index 18d94969aa27e..216b96727fdbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -157,12 +157,30 @@ private[sql] object BucketTransform { } } +/** + * Minimal description of a per-column transform applied within a CLUSTER BY expression. + * + * @param columnIndex index into [[ClusterByTransform.columnNames]] identifying the column + * being transformed. + * @param argumentIndex the index in the argument list where the bound clustering column + * should be substituted. Zero-indexed, and any arguments at or after + * this index in `arguments` should be shifted to the right by one. + * @param function canonical SQL function name (e.g. "variant_get"). + * @param arguments the non-column literal arguments to the function. 
+ */ +case class ClusterByColumnTransform( + columnIndex: Int, + argumentIndex: Int, + function: String, + arguments: Seq[LiteralValue[_]]) + /** * This class represents a transform for `ClusterBySpec`. This is used to bundle * ClusterBySpec in CreateTable's partitioning transforms to pass it down to analyzer. */ final case class ClusterByTransform( - columnNames: Seq[NamedReference]) extends RewritableTransform { + columnNames: Seq[NamedReference], + transforms: Seq[ClusterByColumnTransform] = Seq.empty) extends RewritableTransform { override val name: String = "cluster_by" @@ -175,6 +193,20 @@ final case class ClusterByTransform( override def withReferences(newReferences: Seq[NamedReference]): Transform = { this.copy(columnNames = newReferences) } + + /** Converts [[transforms]] to a per-column `Seq[Option[Transform]]` for [[ClusterBySpec]]. */ + def toClusteringColumnTransforms: Seq[Option[Transform]] = { + columnNames.indices.map { idx => + transforms.find(_.columnIndex == idx).map { t => + val args = t.arguments.toArray + + new ClusteringColumnTransform( + t.function, + args.slice(0, t.argumentIndex) ++ + Array[Expression](columnNames(idx)) ++ args.slice(t.argumentIndex, args.length)) + } + } + } } /** @@ -190,6 +222,42 @@ object ClusterByTransform { } } +/** + * A Transform implementation that wraps a column transform expression for CLUSTER BY. + * For example, `CLUSTER BY (UPPER(col))` produces a ClusteringColumnTransform + * wrapping the UPPER function with `col` as a NamedReference argument. Used with ClusterBySpec. 
+ */ +final class ClusteringColumnTransform( + override val name: String, + private val args: Array[Expression]) extends Transform { + + override def arguments(): Array[Expression] = args + + override def toString: String = { + s"$name(${arguments().map(_.toString).mkString(", ")})" + } + + override def describe: String = toString + + override def equals(obj: Any): Boolean = obj match { + case other: Transform => + other.name() == this.name && + java.util.Arrays.equals( + other.arguments().asInstanceOf[Array[Object]], + this.arguments().asInstanceOf[Array[Object]]) + case _ => false + } + + override def hashCode(): Int = + java.util.Objects.hash( + name, + Integer.valueOf(java.util.Arrays.hashCode(arguments().asInstanceOf[Array[Object]]))) + + override def references(): Array[NamedReference] = { + args.collect { case n: NamedReference => n } + } +} + private[sql] final case class SortedBucketTransform( numBuckets: Literal[Int], columns: Seq[NamedReference], @@ -387,7 +455,7 @@ private[sql] object HoursTransform { } } -private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { +final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { override def toString: String = dataType match { case StringType => s"'${s"$value".replace("'", "''")}'" case BinaryType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ClusterBySpecSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ClusterBySpecSuite.scala new file mode 100644 index 0000000000000..ceff792d41701 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ClusterBySpecSuite.scala @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.connector.expressions.{ClusteringColumnTransform, FieldReference, LiteralValue} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, StringType, TimestampNTZType, TimestampType} +import org.apache.spark.unsafe.types.UTF8String + +class ClusterBySpecSuite extends SparkFunSuite with SQLHelper { + + // -- column name roundtrip tests + test("fromColumnEntries with plain columns roundtrips through toColumnNames") { + val spec = ClusterBySpec(Seq(FieldReference(Seq("a")), FieldReference(Seq("b", "c")))) + val json = spec.toColumnNames + val roundtripped = ClusterBySpec.fromProperty(json) + assert(roundtripped.columnNames === spec.columnNames) + assert(roundtripped.clusteringColumnTransforms.isEmpty) + } + + // -- from expressions tests + test("fromExpressions with a plain column reference produces no transform") { + val spec = ClusterBySpec.fromExpressions( + Seq(scala.util.Right(Seq("col1")))) + assert(spec.columnNames === Seq(FieldReference(Seq("col1")))) + assert(spec.clusteringColumnTransforms.isEmpty) + } + + test("fromExpressions 
with a function call produces a ClusteringColumnTransform") { + val funcExpr = UnresolvedFunction( + "variant_get", + Seq( + UnresolvedAttribute(Seq("col1")), + Literal(UTF8String.fromString("$.foo"), StringType), + Literal(UTF8String.fromString("STRING"), StringType)), + isDistinct = false) + val spec = ClusterBySpec.fromExpressions(Seq(scala.util.Left(funcExpr))) + assert(spec.columnNames === Seq(FieldReference(Seq("col1")))) + assert(spec.clusteringColumnTransforms.length === 1) + val transformOpt = spec.clusteringColumnTransforms.head + assert(transformOpt.isDefined) + val transform = transformOpt.get.asInstanceOf[ClusteringColumnTransform] + assert(transform.name === "variant_get") + } + + // -- toColumnNames / fromProperty roundtrip with expressions tests + test("expression column roundtrips through toColumnNames and fromProperty") { + val transform = new ClusteringColumnTransform( + "variant_get", + Array( + FieldReference(Seq("col1")), + LiteralValue(UTF8String.fromString("$.foo"), StringType), + LiteralValue(UTF8String.fromString("STRING"), StringType))) + val spec = ClusterBySpec( + columnNames = Seq(FieldReference(Seq("col1"))), + clusteringColumnTransforms = Seq(Some(transform))) + + val json = spec.toColumnNames + + // The serialized form should be a plain JSON array of arrays (no nested objects) + assert(!json.contains("{"), s"Expected flat array format but got: $json") + + val roundtripped = ClusterBySpec.fromProperty(json) + assert(roundtripped.columnNames === spec.columnNames) + assert(roundtripped.clusteringColumnTransforms.length === 1) + val rt = roundtripped.clusteringColumnTransforms.head.get + .asInstanceOf[ClusteringColumnTransform] + assert(rt.name === "variant_get") + assert(rt.arguments().length === 3) + assert(rt.arguments()(0).asInstanceOf[FieldReference].fieldNames().toSeq === Seq("col1")) + } + + test("mixed plain and expression columns roundtrip through toColumnNames and fromProperty") { + val transform = new 
ClusteringColumnTransform( + "upper", + Array(FieldReference(Seq("name")))) + val spec = ClusterBySpec( + columnNames = Seq(FieldReference(Seq("id")), FieldReference(Seq("name"))), + clusteringColumnTransforms = Seq(None, Some(transform))) + + val roundtripped = ClusterBySpec.fromProperty(spec.toColumnNames) + + assert(roundtripped.columnNames === spec.columnNames) + assert(roundtripped.clusteringColumnTransforms.length === 2) + assert(roundtripped.clusteringColumnTransforms(0).isEmpty) // plain column + val rt = roundtripped.clusteringColumnTransforms(1).get.asInstanceOf[ClusteringColumnTransform] + assert(rt.name === "upper") + } + + test("fromProperty is backward compatible with the old plain Seq[Seq[String]] format") { + // Old format: just column name arrays + val oldFormat = """[["id"],["data"]]""" + val spec = ClusterBySpec.fromProperty(oldFormat) + assert(spec.columnNames === Seq(FieldReference(Seq("id")), FieldReference(Seq("data")))) + assert(spec.clusteringColumnTransforms.isEmpty) + } + + test("integer literal in expression roundtrips correctly") { + val transform = new ClusteringColumnTransform( + "some_func", + Array( + FieldReference(Seq("col")), + LiteralValue(42, IntegerType))) + val spec = ClusterBySpec( + columnNames = Seq(FieldReference(Seq("col"))), + clusteringColumnTransforms = Seq(Some(transform))) + + val roundtripped = ClusterBySpec.fromProperty(spec.toColumnNames) + assert(roundtripped.columnNames === spec.columnNames) + val rt = roundtripped.clusteringColumnTransforms.head.get + .asInstanceOf[ClusteringColumnTransform] + assert(rt.name === "some_func") + assert(rt.arguments().length === 2) + } + + // -- string literal round-trip tests + test("string literal values are preserved through toColumnNames/fromProperty round-trip") { + val transform = new ClusteringColumnTransform( + "variant_get", + Array( + FieldReference(Seq("col")), + LiteralValue(UTF8String.fromString("$.foo"), StringType), + LiteralValue(UTF8String.fromString("STRING"), 
StringType))) + val spec = ClusterBySpec( + columnNames = Seq(FieldReference(Seq("col"))), + clusteringColumnTransforms = Seq(Some(transform))) + + val json = spec.toColumnNames + val roundtripped = ClusterBySpec.fromProperty(json) + + val rt = roundtripped.clusteringColumnTransforms.head.get + .asInstanceOf[ClusteringColumnTransform] + assert(rt.name === "variant_get") + assert(rt.arguments().length === 3) + + // Verify the string literal values and types survive the round-trip + val lit1 = rt.arguments()(1).asInstanceOf[LiteralValue[_]] + assert(lit1.dataType === StringType) + assert(lit1.value.toString === "$.foo") + + val lit2 = rt.arguments()(2).asInstanceOf[LiteralValue[_]] + assert(lit2.dataType === StringType) + assert(lit2.value.toString === "STRING") + } + + test("string literal with embedded single quotes round-trips correctly") { + val transform = new ClusteringColumnTransform( + "some_func", + Array( + FieldReference(Seq("col")), + LiteralValue(UTF8String.fromString("it's a test"), StringType))) + val spec = ClusterBySpec( + columnNames = Seq(FieldReference(Seq("col"))), + clusteringColumnTransforms = Seq(Some(transform))) + + val json = spec.toColumnNames + val roundtripped = ClusterBySpec.fromProperty(json) + + val rt = roundtripped.clusteringColumnTransforms.head.get + .asInstanceOf[ClusteringColumnTransform] + val lit = rt.arguments()(1).asInstanceOf[LiteralValue[_]] + assert(lit.dataType === StringType) + assert(lit.value.toString === "it's a test") + } + + test("string literal that looks like a column name round-trips as string, not column ref") { + // A string value like "col_name" should survive as a StringType literal, + // not be re-parsed as an UnresolvedAttribute (column reference). 
+ val transform = new ClusteringColumnTransform( + "some_func", + Array( + FieldReference(Seq("col")), + LiteralValue(UTF8String.fromString("my_column"), StringType))) + val spec = ClusterBySpec( + columnNames = Seq(FieldReference(Seq("col"))), + clusteringColumnTransforms = Seq(Some(transform))) + + val json = spec.toColumnNames + val roundtripped = ClusterBySpec.fromProperty(json) + + val rt = roundtripped.clusteringColumnTransforms.head.get + .asInstanceOf[ClusteringColumnTransform] + // The second argument should still be a LiteralValue with StringType, + // not a FieldReference (which would indicate it was parsed as a column name) + assert(rt.arguments()(1).isInstanceOf[LiteralValue[_]]) + val lit = rt.arguments()(1).asInstanceOf[LiteralValue[_]] + assert(lit.dataType === StringType) + assert(lit.value.toString === "my_column") + } + + // -- timestamp literal round-trip tests + private def testTimestampRoundTrip(microseconds: Long): Unit = { + val transform = new ClusteringColumnTransform( + "date_trunc", + Array( + FieldReference(Seq("ts_col")), + LiteralValue(microseconds, TimestampType))) + val spec = ClusterBySpec( + columnNames = Seq(FieldReference(Seq("ts_col"))), + clusteringColumnTransforms = Seq(Some(transform))) + + val json = spec.toColumnNames + val roundtripped = ClusterBySpec.fromProperty(json) + + val rt = roundtripped.clusteringColumnTransforms.head.get + .asInstanceOf[ClusteringColumnTransform] + assert(rt.name === "date_trunc") + assert(rt.arguments().length === 2) + val lit = rt.arguments()(1).asInstanceOf[LiteralValue[_]] + assert(lit.dataType === TimestampType) + assert(lit.value == microseconds) + } + + test("timestamp literal round-trips correctly") { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + // 2023-01-01T00:00:00Z in microseconds + testTimestampRoundTrip(1672531200000000L) + } + } + + test("timestamp literal round-trip is consistent with Java8 Date API enabled") { + withSQLConf( + 
SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + testTimestampRoundTrip(1672531200000000L) + } + } + + test("timestamp literal round-trip is consistent with Java8 Date API disabled") { + withSQLConf( + SQLConf.DATETIME_JAVA8API_ENABLED.key -> "false", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") { + testTimestampRoundTrip(1672531200000000L) + } + } + + test("timestamp literal round-trip with non-UTC timezone") { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Toronto") { + testTimestampRoundTrip(1672531200000000L) + } + } + + test("timestamp_ntz literal round-trips correctly") { + val microseconds = 1672531200000000L + val transform = new ClusteringColumnTransform( + "date_trunc", + Array( + FieldReference(Seq("ts_col")), + LiteralValue(microseconds, TimestampNTZType))) + val spec = ClusterBySpec( + columnNames = Seq(FieldReference(Seq("ts_col"))), + clusteringColumnTransforms = Seq(Some(transform))) + + val json = spec.toColumnNames + val roundtripped = ClusterBySpec.fromProperty(json) + + val rt = roundtripped.clusteringColumnTransforms.head.get + .asInstanceOf[ClusteringColumnTransform] + assert(rt.arguments().length === 2) + val lit = rt.arguments()(1).asInstanceOf[LiteralValue[_]] + assert(lit.dataType === TimestampNTZType) + assert(lit.value == microseconds) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 8791677999810..eae4265c9dafb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -279,7 +279,7 @@ abstract class InMemoryBaseTable( } case ClusterByTransform(columnNames) => columnNames.map { colName => - extractor(colName.fieldNames, cleanedSchema, row)._1 + 
extractor(colName.asInstanceOf[NamedReference].fieldNames, cleanedSchema, row)._1 } }.toImmutableArraySeq } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableClusterBySuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableClusterBySuiteBase.scala index c0fd0a67d06aa..d960c4a9182f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableClusterBySuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableClusterBySuiteBase.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.connector.expressions.Transform /** * This base suite contains unified tests for the `ALTER TABLE ... CLUSTER BY` command @@ -43,6 +44,17 @@ trait AlterTableClusterBySuiteBase extends QueryTest with DDLCommandTestUtils { def validateClusterBy(tableName: String, clusteringColumns: Seq[String]): Unit + /** + * Validates clustering columns and their associated transforms. + * @param expectedTransforms per-column transforms, where None means a plain column reference + * and Some(transform) means an expression-based clustering column. + * Must have the same length as clusteringColumns. 
+ */ + def validateClusterBy( + tableName: String, + clusteringColumns: Seq[String], + expectedTransforms: Seq[Option[Transform]]): Unit + test("test basic ALTER TABLE with clustering columns") { withNamespaceAndTable("ns", "table") { tbl => sql(s"CREATE TABLE $tbl (id INT, data STRING) $defaultUsing CLUSTER BY (id, data)") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableClusterBySuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableClusterBySuiteBase.scala index cb56d11b665db..86cb09c9b225c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableClusterBySuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateTableClusterBySuiteBase.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.connector.expressions.Transform /** * This base suite contains unified tests for the `CREATE/REPLACE TABLE ... CLUSTER BY` command @@ -41,6 +42,17 @@ trait CreateTableClusterBySuiteBase extends QueryTest with DDLCommandTestUtils { def validateClusterBy(tableName: String, clusteringColumns: Seq[String]): Unit + /** + * Validates clustering columns and their associated transforms. + * @param expectedTransforms per-column transforms, where None means a plain column reference + * and Some(transform) means an expression-based clustering column. + * Must have the same length as clusteringColumns. 
+ */ + def validateClusterBy( + tableName: String, + clusteringColumns: Seq[String], + expectedTransforms: Seq[Option[Transform]]): Unit + test("test basic CREATE TABLE with clustering columns") { withNamespaceAndTable("ns", "table") { tbl => spark.sql(s"CREATE TABLE $tbl (id INT, data STRING) $defaultUsing CLUSTER BY (id, data)") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableClusterBySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableClusterBySuite.scala index f5f2b7517b484..f3eb3568a51e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableClusterBySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterTableClusterBySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v1 import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.ClusterBySpec -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, Transform} import org.apache.spark.sql.execution.command /** @@ -39,6 +39,28 @@ trait AlterTableClusterBySuiteBase extends command.AlterTableClusterBySuiteBase val table = catalog.getTableMetadata(TableIdentifier(t, Some(db))) assert(table.clusterBySpec === Some(ClusterBySpec(clusteringColumns.map(FieldReference(_))))) } + + override def validateClusterBy( + tableName: String, + clusteringColumns: Seq[String], + expectedTransforms: Seq[Option[Transform]]): Unit = { + val catalog = spark.sessionState.catalog + val (_, db, t) = parseTableName(tableName) + val table = catalog.getTableMetadata(TableIdentifier(t, Some(db))) + val spec = table.clusterBySpec.get + assert(spec.columnNames === clusteringColumns.map(FieldReference(_))) + val actualTransforms = if (spec.clusteringColumnTransforms.nonEmpty) { + spec.clusteringColumnTransforms + } else { + clusteringColumns.map(_ => None) + } + 
assert(actualTransforms.length === expectedTransforms.length, + s"Expected ${expectedTransforms.length} transforms but got ${actualTransforms.length}") + actualTransforms.zip(expectedTransforms).foreach { case (actual, expected) => + assert(actual === expected, + s"Transform mismatch: actual=$actual, expected=$expected") + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableClusterBySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableClusterBySuite.scala index 2444fe062e283..5274596b24723 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableClusterBySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/CreateTableClusterBySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v1 import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.ClusterBySpec -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, Transform} import org.apache.spark.sql.execution.command /** @@ -39,6 +39,28 @@ trait CreateTableClusterBySuiteBase extends command.CreateTableClusterBySuiteBas val table = catalog.getTableMetadata(TableIdentifier.apply(t, Some(db))) assert(table.clusterBySpec === Some(ClusterBySpec(clusteringColumns.map(FieldReference(_))))) } + + override def validateClusterBy( + tableName: String, + clusteringColumns: Seq[String], + expectedTransforms: Seq[Option[Transform]]): Unit = { + val catalog = spark.sessionState.catalog + val (_, db, t) = parseTableName(tableName) + val table = catalog.getTableMetadata(TableIdentifier.apply(t, Some(db))) + val spec = table.clusterBySpec.get + assert(spec.columnNames === clusteringColumns.map(FieldReference(_))) + val actualTransforms = if (spec.clusteringColumnTransforms.nonEmpty) { + spec.clusteringColumnTransforms + } else { + 
clusteringColumns.map(_ => None) + } + assert(actualTransforms.length === expectedTransforms.length, + s"Expected ${expectedTransforms.length} transforms but got ${actualTransforms.length}") + actualTransforms.zip(expectedTransforms).foreach { case (actual, expected) => + assert(actual === expected, + s"Transform mismatch: actual=$actual, expected=$expected") + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableClusterBySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableClusterBySuite.scala index 8a74d9c3572bb..1a366af6d13bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableClusterBySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterTableClusterBySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper -import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference} +import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, Transform} import org.apache.spark.sql.execution.command /** @@ -38,6 +38,45 @@ class AlterTableClusterBySuite extends command.AlterTableClusterBySuiteBase Array(ClusterByTransform(clusteringColumns.map(FieldReference(_))))) } + // The V2 in-memory test catalog does not apply ClusterBy changes via alterTable, + // so we cannot validate transforms after ALTER TABLE in this catalog. 
+ override def validateClusterBy( + tableName: String, + clusteringColumns: Seq[String], + expectedTransforms: Seq[Option[Transform]]): Unit = { + val (catalog, namespace, table) = parseTableName(tableName) + val catalogPlugin = spark.sessionState.catalogManager.catalog(catalog) + val partTable = catalogPlugin.asTableCatalog + .loadTable(Identifier.of(Array(namespace), table)) + .asInstanceOf[InMemoryTable] + partTable.partitioning.collectFirst { case c: ClusterByTransform => c } match { + case Some(clusterByTransform) => + val actualColumnNames = clusterByTransform.columnNames + assert(actualColumnNames.length === clusteringColumns.length) + actualColumnNames.zip(clusteringColumns).foreach { + case (actual, expectedColName) => + assert(actual.fieldNames().toSeq === Seq(expectedColName)) + } + clusteringColumns.zip(expectedTransforms).zipWithIndex.foreach { + case ((_, expectedTransform), idx) => + expectedTransform match { + case None => + assert(clusterByTransform.transforms.forall(_.columnIndex != idx), + s"Expected no transform for column at index $idx") + case Some(transform) => + val actual = clusterByTransform.transforms.find(_.columnIndex == idx) + assert(actual.isDefined, + s"Expected transform for column at index $idx") + assert(actual.get.function === transform.name(), + s"Transform name mismatch for column at index $idx") + } + } + case None => + // After ALTER TABLE CLUSTER BY, the in-memory V2 catalog may not have updated + // partitioning. Just verify the SQL executed without error. 
+ } + } + test("test REPLACE TABLE with clustering columns") { withNamespaceAndTable("ns", "table") { tbl => sql(s"CREATE TABLE $tbl (id INT) $defaultUsing CLUSTER BY (id)") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableClusterBySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableClusterBySuite.scala index 86b14d6680388..951874f2d7293 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableClusterBySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateTableClusterBySuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryPartitionTable} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper -import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference} +import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, Transform} import org.apache.spark.sql.execution.command /** @@ -38,6 +38,40 @@ class CreateTableClusterBySuite extends command.CreateTableClusterBySuiteBase Array(ClusterByTransform(clusteringColumns.map(FieldReference(_))))) } + override def validateClusterBy( + tableName: String, + clusteringColumns: Seq[String], + expectedTransforms: Seq[Option[Transform]]): Unit = { + val (catalog, namespace, table) = parseTableName(tableName) + val catalogPlugin = spark.sessionState.catalogManager.catalog(catalog) + val partTable = catalogPlugin.asTableCatalog + .loadTable(Identifier.of(Array(namespace), table)) + .asInstanceOf[InMemoryPartitionTable] + val clusterByTransform = partTable.partitioning + .collectFirst { case c: ClusterByTransform => c } + .getOrElse(fail("No ClusterByTransform found in partitioning")) + val actualColumnNames = clusterByTransform.columnNames + assert(actualColumnNames.length === clusteringColumns.length) + 
actualColumnNames.zip(clusteringColumns).foreach { + case (actual, expectedColName) => + assert(actual.fieldNames().toSeq === Seq(expectedColName)) + } + clusteringColumns.zip(expectedTransforms).zipWithIndex.foreach { + case ((_, expectedTransform), idx) => + expectedTransform match { + case None => + assert(clusterByTransform.transforms.forall(_.columnIndex != idx), + s"Expected no transform for column at index $idx") + case Some(transform) => + val actual = clusterByTransform.transforms.find(_.columnIndex == idx) + assert(actual.isDefined, + s"Expected transform for column at index $idx") + assert(actual.get.function === transform.name(), + s"Transform name mismatch for column at index $idx") + } + } + } + test("test REPLACE TABLE with clustering columns") { withNamespaceAndTable("ns", "table") { tbl => spark.sql(s"CREATE TABLE $tbl (id INT) $defaultUsing CLUSTER BY (id)")