Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -6980,6 +6980,18 @@
],
"sqlState" : "428EK"
},
"THETA_FAMILY_MUST_BE_CONSTANT" : {
"message" : [
"Invalid call to <function>; the `family` value must be a constant value, but got a non-constant expression."
],
"sqlState" : "42K0E"
},
"THETA_INVALID_FAMILY" : {
"message" : [
"Invalid call to <function>; the `family` parameter must be one of: <validFamilies>, but got: <value>."
],
"sqlState" : "22546"
},
"THETA_INVALID_INPUT_SKETCH_BUFFER" : {
"message" : [
"Invalid call to <function>; only valid Theta sketch buffers are supported as inputs (such as those produced by the `theta_sketch_agg` function)."
Expand Down
8 changes: 4 additions & 4 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4594,12 +4594,12 @@ def hll_union(
def theta_sketch_agg(
col: "ColumnOrName",
lgNomEntries: Optional[Union[int, Column]] = None,
family: Optional[str] = None,
) -> Column:
fn = "theta_sketch_agg"
if lgNomEntries is None:
return _invoke_function_over_columns(fn, col)
else:
return _invoke_function_over_columns(fn, col, lit(lgNomEntries))
_lgNomEntries = lit(12) if lgNomEntries is None else lit(lgNomEntries)
_family = lit("QUICKSELECT") if family is None else lit(family)
return _invoke_function_over_columns(fn, col, _lgNomEntries, _family)


theta_sketch_agg.__doc__ = pysparkfuncs.theta_sketch_agg.__doc__
Expand Down
50 changes: 30 additions & 20 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26786,10 +26786,12 @@ def hll_union(
def theta_sketch_agg(
col: "ColumnOrName",
lgNomEntries: Optional[Union[int, Column]] = None,
family: Optional[str] = None,
) -> Column:
"""
Aggregate function: returns the compact binary representation of the Datasketches
ThetaSketch with the values in the input column configured with lgNomEntries nominal entries.
ThetaSketch with the values in the input column configured with lgNomEntries nominal entries
and the specified sketch family.

.. versionadded:: 4.1.0

Expand All @@ -26799,6 +26801,8 @@ def theta_sketch_agg(
lgNomEntries : :class:`~pyspark.sql.Column` or int, optional
The log-base-2 of nominal entries, where nominal entries is the size of the sketch
(must be between 4 and 26, defaults to 12)
family : str, optional
The sketch family: 'QUICKSELECT' or 'ALPHA' (defaults to 'QUICKSELECT').

Returns
-------
Expand All @@ -26819,24 +26823,30 @@ def theta_sketch_agg(
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([1,2,2,3], "INT")
>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value"))).show()
+--------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12))|
+--------------------------------------------------+
| 3|
+--------------------------------------------------+
+---------------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12, QUICKSELECT))|
+---------------------------------------------------------------+
| 3|
+---------------------------------------------------------------+

>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value", 15))).show()
+--------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 15))|
+--------------------------------------------------+
| 3|
+--------------------------------------------------+
+---------------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 15, QUICKSELECT))|
+---------------------------------------------------------------+
| 3|
+---------------------------------------------------------------+

>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value", 15, "ALPHA"))).show()
+---------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 15, ALPHA))|
+---------------------------------------------------------+
| 3|
+---------------------------------------------------------+
"""
fn = "theta_sketch_agg"
if lgNomEntries is None:
return _invoke_function_over_columns(fn, col)
else:
return _invoke_function_over_columns(fn, col, lit(lgNomEntries))
_lgNomEntries = lit(12) if lgNomEntries is None else lit(lgNomEntries)
_family = lit("QUICKSELECT") if family is None else lit(family)
return _invoke_function_over_columns(fn, col, _lgNomEntries, _family)


@_try_remote_functions
Expand Down Expand Up @@ -28027,11 +28037,11 @@ def theta_sketch_estimate(col: "ColumnOrName") -> Column:
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([1,2,2,3], "INT")
>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value"))).show()
+--------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12))|
+--------------------------------------------------+
| 3|
+--------------------------------------------------+
+---------------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12, QUICKSELECT))|
+---------------------------------------------------------------+
| 3|
+---------------------------------------------------------------+
"""

fn = "theta_sketch_estimate"
Expand Down
31 changes: 31 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,17 @@ object functions {
def theta_sketch_agg(e: Column, lgNomEntries: Column): Column =
Column.fn("theta_sketch_agg", e, lgNomEntries)

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column and configured with the `lgNomEntries` nominal
* entries and `family`.
*
* @group agg_funcs
* @since 4.1.0
*/
def theta_sketch_agg(e: Column, lgNomEntries: Column, family: Column): Column =
Column.fn("theta_sketch_agg", e, lgNomEntries, family)

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column and configured with the `lgNomEntries` nominal
Expand Down Expand Up @@ -1319,6 +1330,26 @@ object functions {
def theta_sketch_agg(columnName: String): Column =
theta_sketch_agg(Column(columnName))

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column, configured with `lgNomEntries` and `family`.
*
* @group agg_funcs
* @since 4.1.0
*/
def theta_sketch_agg(e: Column, lgNomEntries: Int, family: String): Column =
Column.fn("theta_sketch_agg", e, lit(lgNomEntries), lit(family))

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column, configured with `lgNomEntries` and `family`.
*
* @group agg_funcs
* @since 4.1.0
*/
def theta_sketch_agg(columnName: String, lgNomEntries: Int, family: String): Column =
theta_sketch_agg(Column(columnName), lgNomEntries, family)

/**
* Aggregate function: returns the compact binary representation of the Datasketches
* ThetaSketch, generated by the union of Datasketches ThetaSketch instances in the input column
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.datasketches.common.Family
import org.apache.datasketches.memory.Memory
import org.apache.datasketches.theta.{CompactSketch, Intersection, SetOperation, Sketch, Union, UpdateSketch, UpdateSketchBuilder}

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, ThetaSketchUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.types.StringTypeWithCollation
Expand Down Expand Up @@ -59,10 +59,12 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
*
* See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
*
* @param left
* @param first
* child expression against which unique counting will occur
* @param right
* @param second
* the log-base-2 of nomEntries decides the number of buckets for the sketch
* @param third
* the family of the sketch (QUICKSELECT or ALPHA)
* @param mutableAggBufferOffset
* offset for mutable aggregation buffer
* @param inputAggBufferOffset
Expand All @@ -71,49 +73,73 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(expr, lgNomEntries) - Returns the ThetaSketch compact binary representation.
_FUNC_(expr, lgNomEntries, family) - Returns the ThetaSketch compact binary representation.
`lgNomEntries` (optional) is the log-base-2 of nominal entries, with nominal entries deciding
the number buckets or slots for the ThetaSketch. """,
the number buckets or slots for the ThetaSketch.
`family` (optional) is the sketch family, either 'QUICKSELECT' or 'ALPHA' (defaults to
'QUICKSELECT').""",
examples = """
Examples:
> SELECT theta_sketch_estimate(_FUNC_(col)) FROM VALUES (1), (1), (2), (2), (3) tab(col);
3
> SELECT theta_sketch_estimate(_FUNC_(col, 12)) FROM VALUES (1), (1), (2), (2), (3) tab(col);
3
> SELECT theta_sketch_estimate(_FUNC_(col, 15, 'ALPHA')) FROM VALUES (1), (1), (2), (2), (3) tab(col);
3
""",
group = "agg_funcs",
since = "4.1.0")
// scalastyle:on line.size.limit
case class ThetaSketchAgg(
left: Expression,
right: Expression,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
first: Expression,
second: Expression,
third: Expression,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[ThetaSketchState]
with BinaryLike[Expression]
with TernaryLike[Expression]
with ExpectsInputTypes {

// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.

lazy val lgNomEntries: Int = {
if (!right.foldable) {

private lazy val lgNomEntries: Int = {
if (!second.foldable) {
throw QueryExecutionErrors.thetaLgNomEntriesMustBeConstantError(prettyName)
}
val lgNomEntriesInput = right.eval().asInstanceOf[Int]
val lgNomEntriesInput = second.eval().asInstanceOf[Int]
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName)
lgNomEntriesInput
}

// Constructors
private lazy val family: Family = {
if (!third.foldable) {
throw QueryExecutionErrors.thetaFamilyMustBeConstantError(prettyName)
Copy link
Copy Markdown
Contributor

@cboumalh cboumalh Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we can do this check at the analysis step in checkInputDataTypes to fail early, instead of at runtime. Same applies to the lgNomEntries. I can create a follow up, or we can make the modifications here if other reviewers see it fit.

}
val familyName = third.eval().asInstanceOf[UTF8String]
ThetaSketchUtils.parseFamily(familyName.toString, prettyName)
}

def this(child: Expression) = {
this(child, Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS), 0, 0)
this(child,
Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS),
Literal(UTF8String.fromString(ThetaSketchUtils.DEFAULT_FAMILY)),
0, 0)
}

def this(child: Expression, lgNomEntries: Expression) = {
this(child, lgNomEntries, 0, 0)
this(child,
lgNomEntries,
Literal(UTF8String.fromString(ThetaSketchUtils.DEFAULT_FAMILY)),
0, 0)
}

def this(child: Expression, lgNomEntries: Expression, family: Expression) = {
this(child, lgNomEntries, family, 0, 0)
}

def this(child: Expression, lgNomEntries: Int) = {
this(child, Literal(lgNomEntries), 0, 0)
this(child, Literal(lgNomEntries))
}

// Copy constructors required by ImperativeAggregate
Expand All @@ -124,16 +150,11 @@ case class ThetaSketchAgg(
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchAgg =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override protected def withNewChildrenInternal(
newLeft: Expression,
newRight: Expression): ThetaSketchAgg =
copy(left = newLeft, right = newRight)

// Overrides for TypedImperativeAggregate

override def prettyName: String = "theta_sketch_agg"

override def inputTypes: Seq[AbstractDataType] =
override def inputTypes: Seq[AbstractDataType] = {
Seq(
TypeCollection(
ArrayType(IntegerType),
Expand All @@ -144,21 +165,24 @@ case class ThetaSketchAgg(
IntegerType,
LongType,
StringTypeWithCollation(supportsTrimCollation = true)),
IntegerType)
IntegerType,
StringType)
}

override def dataType: DataType = BinaryType

override def nullable: Boolean = false

/**
* Instantiate an UpdateSketch instance using the lgNomEntries param.
* Instantiate an UpdateSketch instance using the lgNomEntries and family params.
*
* @return
* an UpdateSketch instance wrapped with UpdatableSketchBuffer
*/
override def createAggregationBuffer(): ThetaSketchState = {
val builder = new UpdateSketchBuilder
builder.setLogNominalEntries(lgNomEntries)
builder.setFamily(family)
UpdatableSketchBuffer(builder.build)
}

Expand All @@ -179,7 +203,7 @@ case class ThetaSketchAgg(
*/
override def update(updateBuffer: ThetaSketchState, input: InternalRow): ThetaSketchState = {
// Return early for null values.
val v = left.eval(input)
val v = first.eval(input)
if (v == null) return updateBuffer

// Initialized buffer should be UpdatableSketchBuffer, else error out.
Expand All @@ -189,7 +213,7 @@ case class ThetaSketchAgg(
}

// Handle the different data types for sketch updates.
left.dataType match {
first.dataType match {
case ArrayType(IntegerType, _) =>
val arr = v.asInstanceOf[ArrayData].toIntArray()
sketch.update(arr)
Expand All @@ -216,7 +240,7 @@ case class ThetaSketchAgg(
case _ =>
throw new SparkUnsupportedOperationException(
errorClass = "_LEGACY_ERROR_TEMP_3121",
messageParameters = Map("dataType" -> left.dataType.toString))
messageParameters = Map("dataType" -> first.dataType.toString))
}

updateBuffer
Expand Down Expand Up @@ -292,6 +316,9 @@ case class ThetaSketchAgg(
this.createAggregationBuffer()
}
}

override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression,
newThird: Expression): ThetaSketchAgg = copy(newFirst, newSecond, newThird)
}

/**
Expand Down Expand Up @@ -334,6 +361,7 @@ case class ThetaUnionAgg(

// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.


lazy val lgNomEntries: Int = {
if (!right.foldable) {
throw QueryExecutionErrors.thetaLgNomEntriesMustBeConstantError(prettyName)
Expand Down
Loading