Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,12 @@
],
"sqlState" : "0A000"
},
"CLUSTER_BY_EXPRESSION_INCORRECT_COLUMN_REFERENCE" : {
"message" : [
"CLUSTER BY expression <expressionType> has either no column reference, or a column reference in an unsupported argument position."
],
"sqlState" : "42000"
},
"CLUSTERING_COLUMNS_MISMATCH" : {
"message" : [
"Specified clustering does not match that of the existing table <tableName>.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,8 +536,17 @@ replaceTableHeader
: (CREATE OR)? REPLACE TABLE identifierReference
;

// A single CLUSTER BY item: either an arbitrary expression (e.g. upper(col))
// or a plain multi-part column identifier.
expressionOrMultipartIdentifier
: expression
| multipartIdentifier
;

// Comma-separated list of CLUSTER BY items, used by clusterBySpec.
expressionOrMultipartIdentifierList
: expressionOrMultipartIdentifier (COMMA expressionOrMultipartIdentifier)*
;

clusterBySpec
: CLUSTER BY LEFT_PAREN multipartIdentifierList RIGHT_PAREN
: CLUSTER BY LEFT_PAREN expressionOrMultipartIdentifierList RIGHT_PAREN
;

bucketSpec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.Nullable;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.catalog.constraints.Constraint;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.types.DataType;

/**
Expand Down Expand Up @@ -277,8 +279,10 @@ static TableChange deleteColumn(String[] fieldNames, Boolean ifExists) {
* field names.
* @return a TableChange for this assignment
*/
static TableChange clusterBy(NamedReference[] clusteringColumns) {
return new ClusterBy(clusteringColumns);
static TableChange clusterBy(
NamedReference[] clusteringColumns,
Optional<Transform>[] transforms) {
return new ClusterBy(clusteringColumns, transforms);
}

/**
Expand Down Expand Up @@ -873,24 +877,30 @@ public String toString() {
/** A TableChange to alter clustering columns for a table. */
final class ClusterBy implements TableChange {
private final NamedReference[] clusteringColumns;
private final Optional<Transform>[] transforms;

private ClusterBy(NamedReference[] clusteringColumns) {
private ClusterBy(
NamedReference[] clusteringColumns,
Optional<Transform>[] transforms) {
this.clusteringColumns = clusteringColumns;
this.transforms = transforms;
}

public NamedReference[] clusteringColumns() { return clusteringColumns; }
public Optional<Transform>[] transforms() { return transforms; }

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClusterBy that = (ClusterBy) o;
return Arrays.equals(clusteringColumns, that.clusteringColumns());
return Arrays.equals(clusteringColumns, that.clusteringColumns())
&& Arrays.equals(transforms, that.transforms);
}

@Override
public int hashCode() {
return Arrays.hashCode(clusteringColumns);
return Objects.hash(Arrays.hashCode(clusteringColumns), Arrays.hashCode(transforms));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,17 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CurrentUserContext, FunctionIdentifier, InternalRow, SQLConfHelper, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NormalizeableRelation, Resolver, SchemaBinding, SchemaCompensation, SchemaEvolution, SchemaTypeEvolution, SchemaUnsupported, UnresolvedLeafNode, ViewSchemaMode}
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NormalizeableRelation, Resolver, SchemaBinding, SchemaCompensation, SchemaEvolution, SchemaTypeEvolution, SchemaUnsupported, UnresolvedAttribute, UnresolvedFunction, UnresolvedLeafNode, ViewSchemaMode}
import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, Expression, ExprId, Literal}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, NamedReference, Transform}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.ClusterByHelper
import org.apache.spark.sql.connector.expressions.{ClusterByTransform, ClusteringColumnTransform, FieldReference, LiteralValue, NamedReference, Transform}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -270,11 +272,38 @@ case class CatalogTablePartition(
* A container for clustering information.
*
* @param columnNames the names of the columns used for clustering.
* @param clusteringColumnTransforms per-column transforms for expression-based clustering.
* When non-empty, each element corresponds to a column in
* columnNames: None means a plain column reference,
* Some(transform) means an expression like UPPER(col).
* An empty Seq means no transforms on any columns.
*/
case class ClusterBySpec(columnNames: Seq[NamedReference]) {
case class ClusterBySpec(
columnNames: Seq[NamedReference],
clusteringColumnTransforms: Seq[Option[Transform]] = Seq.empty) {
override def toString: String = toJson

def toJson: String = ClusterBySpec.mapper.writeValueAsString(columnNames.map(_.fieldNames))
def toJson: String = toColumnNames

/**
 * Serializes the clustering columns to the JSON string stored in table
 * properties. Each plain column contributes its multi-part name; each
 * column with a transform contributes a single-element list holding the
 * rendered expression, e.g. ["upper(`col`)"].
 */
def toColumnNames: String = {
  // Renders one transform as "name(arg1,arg2,...)": column references are
  // dot-joined with quoting where needed, literals are rendered as SQL.
  def renderExpression(transform: Transform): String = {
    val renderedArgs = transform.arguments().map {
      case ref: NamedReference =>
        ref.fieldNames().map(QuotingUtils.quoteIfNeeded).mkString(".")
      case LiteralValue(value, dataType) =>
        Literal(value, dataType).sql
      case unexpected => throw new IllegalStateException(
        s"Unexpected argument type in CLUSTER BY expression: ${unexpected.getClass}")
    }
    s"${transform.name()}(${renderedArgs.mkString(",")})"
  }

  val entries: Seq[Seq[String]] =
    if (clusteringColumnTransforms.isEmpty) {
      // No expression-based clustering: store plain multi-part names.
      columnNames.map(_.fieldNames().toSeq)
    } else {
      columnNames.zip(clusteringColumnTransforms).map {
        case (_, Some(transform)) => Seq(renderExpression(transform))
        case (name, None) => name.fieldNames().toSeq
      }
    }
  ClusterBySpec.mapper.writeValueAsString(entries)
}
}

object ClusterBySpec {
Expand All @@ -290,7 +319,104 @@ object ClusterBySpec {
* Converts the clustering column property to a ClusterBySpec.
*/
def fromProperty(columns: String): ClusterBySpec = {
ClusterBySpec(mapper.readValue[Seq[Seq[String]]](columns).map(FieldReference(_)))
ClusterBySpec.fromColumnEntries(mapper.readValue[Seq[Seq[String]]](columns))
}

/**
* Constructs a [[ClusterBySpec]] from the stored column entries (each a Seq[String]).
* An entry is either a multi-part column name or a single-element Seq containing an
* expression string like "variant_get(col,'$.foo','STRING')".
*/
/**
 * Constructs a [[ClusterBySpec]] from the stored column entries (each a Seq[String]).
 * An entry is either a multi-part column name or a single-element Seq containing an
 * expression string like "variant_get(col,'$.foo','STRING')".
 *
 * @throws IllegalStateException if a stored expression is malformed (wrong
 *                               argument types, or not exactly one column
 *                               reference) — stored properties are produced by
 *                               toColumnNames, so this indicates corruption.
 */
def fromColumnEntries(entries: Seq[Seq[String]]): ClusterBySpec = {
  val parsedCols: Seq[(NamedReference, Option[Transform])] = entries.map {
    case parts if parts.length == 1 && mightBeExpression(parts.head) =>
      // Expression form: "funcName(col, arg1, arg2, ...)"
      CatalystSqlParser.parseExpression(parts.head) match {
        case u: UnresolvedFunction =>
          val transform: Transform = new ClusteringColumnTransform(
            u.nameParts.mkString("."),
            u.children.map {
              case a: UnresolvedAttribute =>
                FieldReference(a.nameParts.map(QuotingUtils.quoteIfNeeded))
              case l: Literal => LiteralValue(l.value, l.dataType)
              case other => throw new IllegalStateException(
                s"Unexpected argument type in CLUSTER BY expression: ${other.getClass}")
            }.toArray)
          // Collect ALL column references (not just the first) so that an
          // entry with zero or multiple references fails loudly instead of
          // silently clustering on the first one found.
          val refs = transform.arguments().collect {
            case f: FieldReference => f
          }
          if (refs.length != 1) {
            throw new IllegalStateException(
              "CLUSTER BY expression must contain exactly one column reference")
          }
          (refs.head, Some(transform))
        case _ =>
          // Plain multi-part column name that had a false positive with
          // mightBeExpression.
          (FieldReference(parts), None)
      }
    case parts =>
      // Plain multi-part column name
      (FieldReference(parts), None)
  }
  val (colNames, transforms) = parsedCols.unzip
  // All-plain-columns degrades to an empty transforms Seq for backward compat.
  val transformsSeq = if (transforms.forall(_.isEmpty)) Seq.empty else transforms
  ClusterBySpec(colNames, transformsSeq)
}

/**
 * Heuristic check: a stored clustering entry that contains an opening
 * parenthesis is treated as a function-call expression; plain column
 * identifiers never contain parentheses.
 */
private def mightBeExpression(s: String): Boolean = s.indexOf('(') >= 0

/**
 * Builds a [[ClusterBySpec]] from parser output, where each clustering item is
 * either a parsed expression (Left) or a plain multi-part column name (Right).
 *
 * A Left expression must be either a bare column reference or a function call
 * (e.g. upper(col)) that contains exactly one column reference, positioned as
 * the first argument; otherwise an AnalysisException with error class
 * CLUSTER_BY_EXPRESSION_INCORRECT_COLUMN_REFERENCE is raised.
 */
def fromExpressions(
    parsedCols: Seq[Either[Expression, Seq[String]]]): ClusterBySpec = {
  val (clusteringColumnNames, clusteringColumnExpressions) = parsedCols.map {
    case Left(a: UnresolvedAttribute) =>
      // A bare column reference parsed as an expression - treat as plain column.
      (FieldReference(a.nameParts), None)
    case Left(u: UnresolvedFunction) =>
      // Hoisted: the dotted function name is needed for both the transform
      // and the error message.
      val transformName = u.nameParts.mkString(".")
      val transform = new ClusteringColumnTransform(
        transformName,
        u.children.map {
          case a: UnresolvedAttribute => FieldReference(a.nameParts)
          case l: Literal => LiteralValue(l.value, l.dataType)
          case _ => throw new IllegalStateException(
            "Unsupported expression argument in CLUSTER BY transform")
        }.toArray)
      val refs = transform.arguments().collect {
        case f: FieldReference => f
      }
      // Exactly one column reference is required, and it must be the first
      // argument. (A single check: refs.isEmpty is subsumed by
      // refs.length != 1, and the head access is safe once length == 1.)
      if (refs.length != 1 || !transform.arguments().head.isInstanceOf[FieldReference]) {
        throw new AnalysisException(
          errorClass = "CLUSTER_BY_EXPRESSION_INCORRECT_COLUMN_REFERENCE",
          messageParameters = Map("expressionType" -> transformName))
      }
      // refs.head is already a FieldReference; no need to rebuild it.
      (refs.head, Some(transform))
    case Left(_) =>
      throw new IllegalStateException(
        "Unsupported expression in CLUSTER BY: only function calls are supported")
    case Right(names) => (FieldReference(names), None)
  }.unzip
  // If there are no transforms at all (all plain columns), use empty Seq for backward compat
  val transformsSeq = if (clusteringColumnExpressions.forall(_.isEmpty)) {
    Seq.empty
  } else {
    clusteringColumnExpressions
  }
  ClusterBySpec(clusteringColumnNames, transformsSeq)
}

/**
Expand Down Expand Up @@ -339,12 +465,17 @@ object ClusterBySpec {
normalizedColumns.map(_.toString),
resolver)

ClusterBySpec(normalizedColumns)
ClusterBySpec(normalizedColumns, clusterBySpec.clusteringColumnTransforms)
}

def extractClusterBySpec(transforms: Seq[Transform]): Option[ClusterBySpec] = {
transforms.collectFirst {
case ClusterByTransform(columnNames) => ClusterBySpec(columnNames)
case ct: ClusterByTransform =>
if (ct.transforms.nonEmpty) {
ClusterBySpec(ct.columnNames, ct.toClusteringColumnTransforms)
} else {
ClusterBySpec(ct.columnNames)
}
}
}

Expand All @@ -353,7 +484,7 @@ object ClusterBySpec {
clusterBySpec: ClusterBySpec,
resolver: Resolver): ClusterByTransform = {
val normalizedClusterBySpec = normalizeClusterBySpec(schema, clusterBySpec, resolver)
ClusterByTransform(normalizedClusterBySpec.columnNames)
new ClusterByHelper(normalizedClusterBySpec).asTransform.asInstanceOf[ClusterByTransform]
}

def fromColumnNames(names: Seq[String]): ClusterBySpec = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4741,13 +4741,26 @@ class AstBuilder extends DataTypeAstBuilder
})
}

/**
 * Converts an expressionOrMultipartIdentifier parse-tree node into either a
 * parsed Expression (Left, for CLUSTER BY expressions such as upper(col)) or
 * the raw multi-part column name (Right).
 */
override def visitExpressionOrMultipartIdentifier(
    ctx: ExpressionOrMultipartIdentifierContext): Either[Expression, Seq[String]] =
  withOrigin(ctx) {
    // Left/Right are always in scope via the scala package; the fully
    // qualified scala.util prefix is unnecessary.
    if (ctx.expression() != null) {
      Left(expression(ctx.expression()))
    } else {
      Right(typedVisit[Seq[String]](ctx.multipartIdentifier()))
    }
  }

/**
* Create a [[ClusterBySpec]].
*/
override def visitClusterBySpec(ctx: ClusterBySpecContext): ClusterBySpec = withOrigin(ctx) {
val columnNames = ctx.multipartIdentifierList.multipartIdentifier.asScala
.map(typedVisit[Seq[String]]).map(FieldReference(_)).toSeq
ClusterBySpec(columnNames)
val columnReferences =
ctx.expressionOrMultipartIdentifierList.expressionOrMultipartIdentifier
.asScala
.map(visitExpressionOrMultipartIdentifier)
.toSeq
ClusterBySpec.fromExpressions(columnReferences)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{CheckConstraint, Expression, T
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.connector.catalog.{DefaultValue, TableCatalog, TableChange}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.ArrayImplicits._
Expand Down Expand Up @@ -309,9 +310,19 @@ case class AlterColumns(
case class AlterTableClusterBy(
table: LogicalPlan, clusterBySpec: Option[ClusterBySpec]) extends AlterTableCommand {
override def changes: Seq[TableChange] = {
val clusterByTransforms = clusterBySpec.map { spec =>
if (spec.clusteringColumnTransforms.nonEmpty) {
spec.clusteringColumnTransforms.map {
case None => java.util.Optional.empty[Transform]()
case Some(transform) => java.util.Optional.of[Transform](transform)
}.toArray
} else {
spec.columnNames.map(_ => java.util.Optional.empty[Transform]()).toArray
}
}.getOrElse(Array.empty[java.util.Optional[Transform]])
Seq(TableChange.clusterBy(clusterBySpec
.map(_.columnNames.toArray) // CLUSTER BY (col1, col2, ...)
.getOrElse(Array.empty)))
.getOrElse(Array.empty), clusterByTransforms))
}

protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = copy(table = newChild)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, quoteNameParts, QuotingUtils}
import org.apache.spark.sql.connector.expressions.{BucketTransform, ClusterByTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform}
import org.apache.spark.sql.connector.expressions.{BucketTransform, ClusterByColumnTransform, ClusterByTransform, FieldReference, IdentityTransform, LiteralValue, LogicalExpressions, NamedReference, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ArrayImplicits._
Expand Down Expand Up @@ -56,7 +56,17 @@ private[sql] object CatalogV2Implicits {
}

implicit class ClusterByHelper(spec: ClusterBySpec) {
def asTransform: Transform = clusterBy(spec.columnNames.toArray)
def asTransform: Transform = {
val transforms = spec.clusteringColumnTransforms.zipWithIndex.collect {
case (Some(t), colIdx) =>
ClusterByColumnTransform(
columnIndex = colIdx,
argumentIndex = t.arguments().indexWhere(_.isInstanceOf[NamedReference]),
function = t.name(),
arguments = t.arguments().collect { case l: LiteralValue[_] => l }.toSeq)
}
ClusterByTransform(spec.columnNames, transforms)
}
}

implicit class TransformHelper(transforms: Seq[Transform]) {
Expand All @@ -80,12 +90,16 @@ private[sql] object CatalogV2Implicits {
sortCol.map(_.fieldNames.mkString("."))))
}

case ClusterByTransform(columnNames) =>
case ct @ ClusterByTransform(columnNames) =>
if (clusterBySpec.nonEmpty) {
// AstBuilder guarantees that it only passes down one ClusterByTransform.
throw SparkException.internalError("Cannot have multiple cluster by transforms.")
}
clusterBySpec = Some(ClusterBySpec(columnNames))
clusterBySpec = Some(ct match {
case c: ClusterByTransform if c.transforms.nonEmpty =>
ClusterBySpec(columnNames, c.toClusteringColumnTransforms)
case _ => ClusterBySpec(columnNames)
})

case transform =>
throw QueryExecutionErrors.unsupportedPartitionTransformError(transform)
Expand Down
Loading