From c7d79bc08b97a1ebf5d6f39db6a4e4a3a811c1f9 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 3 Apr 2026 20:22:53 +0000 Subject: [PATCH 1/5] feat(bigframes): Add substrait-datafusion engine --- .../core/compile/substrait/__init__.py | 19 + .../core/compile/substrait/compiler.py | 334 ++ .../bigframes/session/substrait_executor.py | 132 + .../bigframes/testing/substrait_session.py | 112 + .../unit/session/test_substrait_executor.py | 135 + .../tests/unit/test_dataframe_substrait.py | 4492 +++++++++++++++++ 6 files changed, 5224 insertions(+) create mode 100644 packages/bigframes/bigframes/core/compile/substrait/__init__.py create mode 100644 packages/bigframes/bigframes/core/compile/substrait/compiler.py create mode 100644 packages/bigframes/bigframes/session/substrait_executor.py create mode 100644 packages/bigframes/bigframes/testing/substrait_session.py create mode 100644 packages/bigframes/tests/unit/session/test_substrait_executor.py create mode 100644 packages/bigframes/tests/unit/test_dataframe_substrait.py diff --git a/packages/bigframes/bigframes/core/compile/substrait/__init__.py b/packages/bigframes/bigframes/core/compile/substrait/__init__.py new file mode 100644 index 000000000000..13021b2fec3f --- /dev/null +++ b/packages/bigframes/bigframes/core/compile/substrait/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from .compiler import SubstraitCompiler + +__all__ = ["SubstraitCompiler"] diff --git a/packages/bigframes/bigframes/core/compile/substrait/compiler.py b/packages/bigframes/bigframes/core/compile/substrait/compiler.py new file mode 100644 index 000000000000..755c232d92c4 --- /dev/null +++ b/packages/bigframes/bigframes/core/compile/substrait/compiler.py @@ -0,0 +1,334 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import Any, Dict, Optional + +import substrait.algebra_pb2 as algebra_pb2 +import substrait.plan_pb2 as plan_pb2 +from google.protobuf import json_format + +from bigframes.core import bigframe_node, nodes +import bigframes.core.expression as ex +import pandas as pd + + +class SubstraitCompiler: + """ + Compiles BigFrameNode plans to Substrait schema (JSON representation). + """ + + def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: + """ + Compiles a BigFrameNode to Substrait bytes (JSON encoded via protobuf). + """ + if not self.can_compile(plan): + return None + + plan_dict = self._compile_node(plan) + + pb_plan = plan_pb2.Plan() + pb_plan.version.minor_number = 42 + + plan_rel = pb_plan.relations.add() + json_format.ParseDict(plan_dict, plan_rel.root.input) + + plan_rel.root.names.extend([item.column for item in plan.schema.items]) + + extensions = [ + ("add", 1), + ("sub", 2), + ("mul", 3), + ("div", 4), + ("eq", 5), + ("ne", 6), + ("lt", 7), + ("gt", 8), + ("le", 9), + ("ge", 10), + ] + for name, anchor in extensions: + ext = pb_plan.extensions.add() + ext.extension_function.function_anchor = anchor + ext.extension_function.name = name + + return json_format.MessageToJson(pb_plan).encode('utf-8') + + def can_compile(self, plan: bigframe_node.BigFrameNode) -> bool: + """ + Checks if the plan can be compiled to Substrait. + For the skeleton, we support ReadLocalNode, SelectionNode, and FilterNode. + """ + supported_nodes = ( + nodes.ReadLocalNode, + nodes.SelectionNode, + nodes.FilterNode, + nodes.SliceNode, + nodes.ProjectionNode, + nodes.JoinNode, + nodes.AggregateNode, + ) + return all(isinstance(node, supported_nodes) for node in plan.unique_nodes()) + + def _compile_node(self, node: bigframe_node.BigFrameNode) -> Dict[str, Any]: + if isinstance(node, nodes.ReadLocalNode): + return self._compile_read(node) + elif isinstance(node, nodes.SelectionNode): + return self._compile_selection(node) + elif isinstance(node, nodes.FilterNode): + return self._compile_filter(node) + elif isinstance(node, nodes.SliceNode): + return self._compile_slice(node) + elif isinstance(node, nodes.ProjectionNode): + return self._compile_projection(node) + elif isinstance(node, nodes.JoinNode): + return self._compile_join(node) + elif isinstance(node, nodes.AggregateNode): + return self._compile_aggregate(node) + else: + raise NotImplementedError(f"Node type {type(node)} not supported in Substrait compiler yet") + + def _compile_read(self, node: nodes.ReadLocalNode) -> Dict[str, Any]: + table_name = f"table_{node.local_data_source.id.hex}" + + rel = algebra_pb2.Rel() + read_rel = rel.read + read_rel.named_table.names.append(table_name) + + schema_dict = self._convert_schema(node.local_data_source.schema) + json_format.ParseDict(schema_dict, read_rel.base_schema) + + return json_format.MessageToDict(rel, preserving_proto_field_name=True) + + def _compile_selection(self, node: nodes.SelectionNode) -> Dict[str, Any]: + # Selection usually maps to ProjectRel or FilterRel depending on if it filters or just selects columns. + # If it's just column selection (Projection), it's a ProjectRel. + # Let's assume it's a ProjectRel for now. + input_rel = self._compile_node(node.child) + return { + "project": { + "input": input_rel, + "expressions": [ + # Skeletal expression mapping + {"selection": {"direct_reference": {"struct_field": {"field": i}}}} + for i in range(len(node.schema)) + ] + } + } + + def _compile_filter(self, node: nodes.FilterNode) -> Dict[str, Any]: + input_rel = self._compile_node(node.child) + condition_rel = self._compile_expression(node.condition, node.child) + return { + "filter": { + "input": input_rel, + "condition": condition_rel + } + } + + def _compile_slice(self, node: nodes.SliceNode) -> Dict[str, Any]: + input_rel = self._compile_node(node.child) + count = node.stop if node.stop is not None else -1 + offset = node.start if node.start is not None else 0 + + return { + "fetch": { + "input": input_rel, + "offset": offset, + "count": count + } + } + + def _compile_projection(self, node: nodes.ProjectionNode) -> Dict[str, Any]: + input_rel_dict = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + project_rel = rel.project + + json_format.ParseDict(input_rel_dict, project_rel.input) + + # DataFusion ProjectRel seems to be additive (appends to input). + # So we don't need to add passthrough expressions for input fields. + + # Add new assignments + for expr, _ in node.assignments: + expr_dict = self._compile_expression(expr, node.child) + expr_pb = project_rel.expressions.add() + json_format.ParseDict(expr_dict, expr_pb) + + return json_format.MessageToDict(rel, preserving_proto_field_name=True) + + def _compile_join(self, node: nodes.JoinNode) -> Dict[str, Any]: + left_rel = self._compile_node(node.left_child) + right_rel = self._compile_node(node.right_child) + + type_map = { + "inner": "JOIN_TYPE_INNER", + "left": "JOIN_TYPE_LEFT", + "right": "JOIN_TYPE_RIGHT", + "outer": "JOIN_TYPE_OUTER", + "cross": "JOIN_TYPE_CROSS", + } + join_type = type_map.get(node.type, "JOIN_TYPE_UNSPECIFIED") + + left_len = len(node.left_child.schema) + + eq_expressions = [] + for left_deref, right_deref in node.conditions: + left_idx = list(node.left_child.ids).index(left_deref.id) + right_idx = list(node.right_child.ids).index(right_deref.id) + left_len + + eq_expressions.append({ + "scalar_function": { + "function_reference": 0, + "arguments": [ + {"value": {"selection": {"direct_reference": {"struct_field": {"field": left_idx}}}}}, + {"value": {"selection": {"direct_reference": {"struct_field": {"field": right_idx}}}}} + ] + } + }) + + if len(eq_expressions) > 1: + expr = eq_expressions[0] + elif len(eq_expressions) == 1: + expr = eq_expressions[0] + else: + expr = {"literal": {"boolean": True}} + + return { + "join": { + "left": left_rel, + "right": right_rel, + "expression": expr, + "type": join_type + } + } + + def _compile_aggregate(self, node: nodes.AggregateNode) -> Dict[str, Any]: + input_rel = self._compile_node(node.child) + + groupings = [] + grouping_expressions = [] + for deref in node.by_column_ids: + idx = list(node.child.ids).index(deref.id) + grouping_expressions.append({"selection": {"direct_reference": {"struct_field": {"field": idx}}}}) + if grouping_expressions: + groupings.append({"grouping_expressions": grouping_expressions}) + + measures = [] + for agg, _ in node.aggregations: + func_ref = 1 if "Sum" in type(agg).__name__ else 2 + args = [] + if hasattr(agg, "column_references"): + for col_id in agg.column_references: + try: + idx = list(node.child.ids).index(col_id) + args.append({"value": {"selection": {"direct_reference": {"struct_field": {"field": idx}}}}}) + except ValueError: + pass + measures.append({ + "measure": { + "function_reference": func_ref, + "arguments": args + } + }) + + return { + "aggregate": { + "input": input_rel, + "groupings": groupings, + "measures": measures + } + } + + def _compile_expression(self, expr: ex.Expression, child: nodes.BigFrameNode) -> Dict[str, Any]: + if isinstance(expr, ex.ScalarConstantExpression): + val = expr.value + if isinstance(val, int): + return {"literal": {"i64": val}} + elif isinstance(val, float): + return {"literal": {"fp64": val}} + elif isinstance(val, str): + return {"literal": {"string": val}} + elif pd.isna(val): + return {"literal": {"null": {"varchar": {"length": 0}}}} + else: + return {"literal": {"string": str(val)}} + + elif isinstance(expr, ex.DerefOp): + try: + # print(f"DerefOp: id={expr.id}, child.ids={list(child.ids)}") # Debug + idx = list(child.ids).index(expr.id) + return {"selection": {"direct_reference": {"struct_field": {"field": idx}}}} + except ValueError: + raise ValueError(f"Column {expr.id} not found in child schema") + + elif isinstance(expr, ex.OpExpression): + op_name = expr.op.name + op_mapping = { + "add": 1, + "sub": 2, + "mul": 3, + "div": 4, + "eq": 5, + "ne": 6, + "lt": 7, + "gt": 8, + "le": 9, + "ge": 10, + } + if op_name not in op_mapping: + raise NotImplementedError(f"Operation {op_name} not supported in Substrait compiler yet") + func_ref = op_mapping[op_name] + + args = [self._compile_expression(arg, child) for arg in expr.inputs] + return { + "scalar_function": { + "function_reference": func_ref, + "arguments": [{"value": arg} for arg in args] + } + } + else: + raise NotImplementedError(f"Expression type {type(expr)} not supported in Substrait compiler yet") + + def _convert_schema(self, schema: Any) -> Dict[str, Any]: + # Convert bigframes schema to Substrait Type.NamedStruct + fields = [] + types = [] + for item in schema.items: + col = item.column + name = col.name if hasattr(col, "name") else str(col) + fields.append(name) + types.append(self._convert_type(item.dtype)) + + return { + "names": fields, + "struct": {"types": types} + } + + def _convert_type(self, dtype: Any) -> Dict[str, Any]: + import bigframes.dtypes + if dtype == bigframes.dtypes.INT_DTYPE: + return {"i64": {}} + elif dtype == bigframes.dtypes.FLOAT_DTYPE: + return {"fp64": {}} + elif dtype == bigframes.dtypes.BOOL_DTYPE: + return {"bool": {}} + elif dtype == bigframes.dtypes.STRING_DTYPE: + return {"string": {}} + else: + # Fallback to string for now + return {"string": {}} diff --git a/packages/bigframes/bigframes/session/substrait_executor.py b/packages/bigframes/bigframes/session/substrait_executor.py new file mode 100644 index 000000000000..243c51d8c634 --- /dev/null +++ b/packages/bigframes/bigframes/session/substrait_executor.py @@ -0,0 +1,132 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING, Optional + +from bigframes.core import bigframe_node +from bigframes.session import executor, semi_executor +import bigframes.core.rewrite.slices as slices_rewrite +from bigframes.core import nodes + +if TYPE_CHECKING: + import pyarrow as pa + + +class SubstraitConsumer(abc.ABC): + """ + Interface for consuming Substrait plans and executing them. + This acts as a plugin interface for different Substrait execution engines. + """ + + @abc.abstractmethod + def consume(self, plan: bytes, tables: dict[str, pa.Table]) -> pa.Table: + """ + Executes a Substrait plan and returns a PyArrow Table. + + Args: + plan: The Substrait plan as bytes (usually a serialized Protobuf). + tables: A dictionary of table names to PyArrow Tables for local data. + + Returns: + A PyArrow Table containing the results. + """ + pass + + +class DataFusionSubstraitConsumer(SubstraitConsumer): + """ + Executes Substrait plans using Apache DataFusion. + """ + + def consume(self, plan: bytes, tables: dict[str, pa.Table]) -> pa.Table: + # Import datafusion lazily to avoid hard dependency + try: + import datafusion + except ImportError: + raise ImportError( + "The datafusion package is required to use DataFusionSubstraitConsumer. " + "Install it with `pip install datafusion`." + ) + + # Create a DataFusion context + ctx = datafusion.SessionContext() + + for name, table in tables.items(): + df = ctx.from_arrow_table(table) + ctx.register_table(name, df) + + # NOTE: The actual API for running Substrait in DataFusion python bindings may vary. + # Assuming something like ctx.from_substrait(plan) or ctx.execute_substrait(plan). + # We will need to verify this with the actual datafusion python package if available. + # For now, we raise NotImplementedError if we cannot find the method, or try a likely one. + + import datafusion.substrait + + json_str = plan.decode('utf-8') + plan_obj = datafusion.substrait.Plan.from_json(json_str) + print("DEBUG RE-SERIALIZED JSON SUBSTRAIT PLAN:") + print(plan_obj.to_json()) + logical_plan = datafusion.substrait.Consumer.from_substrait_plan(ctx, plan_obj) + df = ctx.create_dataframe_from_logical_plan(logical_plan) + return df.to_arrow_table() + + +class SubstraitExecutor(semi_executor.SemiExecutor): + """ + Executes plans by compiling them to Substrait and running them via a consumer. + """ + + def __init__(self, consumer: SubstraitConsumer): + self._consumer = consumer + # Lazy import to avoid circular dependencies + from bigframes.core.compile.substrait.compiler import SubstraitCompiler + self._compiler = SubstraitCompiler() + + def execute( + self, + plan: bigframe_node.BigFrameNode, + ordered: bool, + peek: Optional[int] = None, + ) -> Optional[executor.ExecuteResult]: + rewritten_plan = plan.bottom_up(slices_rewrite.rewrite_slice) + + if not self._can_execute(rewritten_plan): + return None + + substrait_plan = self._compiler.compile(rewritten_plan) + + if substrait_plan is None: + return None + + tables = {} + for node in rewritten_plan.unique_nodes(): + if isinstance(node, nodes.ReadLocalNode): + table_name = f"table_{node.local_data_source.id.hex}" + tables[table_name] = node.local_data_source.data + + pa_table = self._consumer.consume(substrait_plan, tables) + + if peek is not None: + pa_table = pa_table.slice(0, peek) + + return executor.LocalExecuteResult( + data=pa_table, + bf_schema=rewritten_plan.schema, + ) + + def _can_execute(self, plan: bigframe_node.BigFrameNode) -> bool: + return self._compiler.can_compile(plan) diff --git a/packages/bigframes/bigframes/testing/substrait_session.py b/packages/bigframes/bigframes/testing/substrait_session.py new file mode 100644 index 000000000000..cd77ab888adb --- /dev/null +++ b/packages/bigframes/bigframes/testing/substrait_session.py @@ -0,0 +1,112 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import dataclasses +import weakref +from typing import TYPE_CHECKING, Union + +import pandas + +import bigframes +import bigframes.core.blocks +import bigframes.dataframe +import bigframes.session.execution_spec +import bigframes.session.executor +import bigframes.session.metrics + +if TYPE_CHECKING: + import bigframes.core + + +@dataclasses.dataclass +class SubstraitTestExecutor(bigframes.session.executor.Executor): + def __init__(self): + from bigframes.session.substrait_executor import DataFusionSubstraitConsumer, SubstraitExecutor + self.executor = SubstraitExecutor(DataFusionSubstraitConsumer()) + + def execute( + self, + array_value: bigframes.core.ArrayValue, + execution_spec: bigframes.session.execution_spec.ExecutionSpec, + ): + if execution_spec.destination_spec is not None: + raise ValueError( + f"SubstraitTestExecutor does not support destination spec: {execution_spec.destination_spec}" + ) + + result = self.executor.execute(array_value.node, ordered=True, peek=execution_spec.peek) + if result is None: + raise NotImplementedError("SubstraitExecutor cannot execute this plan") + + return result + + def cached( + self, + array_value: bigframes.core.ArrayValue, + *, + config, + ) -> None: + return + + +class TestSession(bigframes.session.Session): + def __init__(self): + self._location = None # type: ignore + self._bq_kms_key_name = None # type: ignore + self._clients_provider = None # type: ignore + self._bq_connection = None # type: ignore + self._skip_bq_connection_check = True + self._session_id: str = "substrait_test_session" + self._objects: list[ + weakref.ReferenceType[ + Union[ + bigframes.core.indexes.Index, + bigframes.series.Series, + bigframes.dataframe.DataFrame, + ] + ] + ] = [] + self._strictly_ordered: bool = True + self._allow_ambiguity = False # type: ignore + self._default_index_type = bigframes.enums.DefaultIndexKind.SEQUENTIAL_INT64 + self._metrics = bigframes.session.metrics.ExecutionMetrics() + self._function_session = None # type: ignore + self._temp_storage_manager = None # type: ignore + self._executor = SubstraitTestExecutor() + self._loader = None # type: ignore + + def read_pandas(self, pandas_dataframe, write_engine="default"): + original_input = pandas_dataframe + + if isinstance(pandas_dataframe, (pandas.Series, pandas.Index)): + pandas_dataframe = pandas_dataframe.to_frame() + + local_block = bigframes.core.blocks.Block.from_local(pandas_dataframe, self) + bf_df = bigframes.dataframe.DataFrame(local_block) + + if isinstance(original_input, pandas.Series): + series = bf_df[bf_df.columns[0]] + series.name = original_input.name + return series + + if isinstance(original_input, pandas.Index): + return bf_df.index + + return bf_df + + @property + def bqclient(self): + return None diff --git a/packages/bigframes/tests/unit/session/test_substrait_executor.py b/packages/bigframes/tests/unit/session/test_substrait_executor.py new file mode 100644 index 000000000000..bdd21bed2b79 --- /dev/null +++ b/packages/bigframes/tests/unit/session/test_substrait_executor.py @@ -0,0 +1,135 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from bigframes.core import identifiers, local_data, nodes +from bigframes.session import substrait_executor +from bigframes.testing import mocks +import bigframes.core.expression as ex + + +class MockConsumer(substrait_executor.SubstraitConsumer): + def consume(self, plan: bytes, tables: dict[str, pa.Table]) -> pa.Table: + # Return a simple table regardless of the plan + return pa.Table.from_pydict({"a": [1, 2, 3]}) + + +@pytest.fixture +def object_under_test(): + return substrait_executor.SubstraitExecutor(MockConsumer()) + + +def create_read_local_node(): + session = mocks.create_bigquery_session() + arrow_table = pa.Table.from_pydict({"a": [1, 2, 3]}) + local_data_source = local_data.ManagedArrowTable.from_pyarrow(arrow_table) + return nodes.ReadLocalNode( + local_data_source=local_data_source, + session=session, + scan_list=nodes.ScanList( + items=( + nodes.ScanItem( + id=identifiers.ColumnId("a"), + source_id="a", + ), + ) + ), + ) + + +def test_substrait_executor_execute(object_under_test): + plan = create_read_local_node() + + result = object_under_test.execute(plan, ordered=True) + assert result is not None + + # Verify the result table + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 3 + assert result_table.column_names == ["a"] + assert result_table.column("a").to_pylist() == [1, 2, 3] + + +def test_substrait_executor_unsupported_node(object_under_test): + # ConcatNode is not supported by our skeletal compiler + session = mocks.create_bigquery_session() + read_node = create_read_local_node() + plan = nodes.ConcatNode( + children=(read_node, read_node), + output_ids=(identifiers.ColumnId("concat"),), + ) + + result = object_under_test.execute(plan, ordered=True) + assert result is None + + +def test_execute_projection_literal_with_datafusion(): + from bigframes.session.substrait_executor import DataFusionSubstraitConsumer + consumer = DataFusionSubstraitConsumer() + executor = substrait_executor.SubstraitExecutor(consumer) + + read_node = create_read_local_node() + + assignment_expr = ex.ScalarConstantExpression(42) + plan = nodes.ProjectionNode( + child=read_node, + assignments=((assignment_expr, identifiers.ColumnId("b")),), + ) + + result = executor.execute(plan, ordered=True) + assert result is not None + + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 3 + assert "b" in result_table.column_names + # Depending on our passthrough implementation, "a" should also be there. + # Our _compile_projection passes through child fields! + assert "a" in result_table.column_names + assert result_table.column("b").to_pylist() == [42, 42, 42] + + +def test_execute_projection_add_with_datafusion(): + from bigframes.session.substrait_executor import DataFusionSubstraitConsumer + from bigframes.operations.numeric_ops import add_op + + consumer = DataFusionSubstraitConsumer() + executor = substrait_executor.SubstraitExecutor(consumer) + + read_node = create_read_local_node() + + # a + 42 + add_expr = ex.OpExpression( + op=add_op, + inputs=( + ex.DerefOp(identifiers.ColumnId("a")), + ex.ScalarConstantExpression(42), + ), + ) + plan = nodes.ProjectionNode( + child=read_node, + assignments=((add_expr, identifiers.ColumnId("b")),), + ) + + result = executor.execute(plan, ordered=True) + assert result is not None + + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 3 + assert "b" in result_table.column_names + assert "a" in result_table.column_names + assert result_table.column("b").to_pylist() == [43, 44, 45] diff --git a/packages/bigframes/tests/unit/test_dataframe_substrait.py b/packages/bigframes/tests/unit/test_dataframe_substrait.py new file mode 100644 index 000000000000..701e3f413f84 --- /dev/null +++ b/packages/bigframes/tests/unit/test_dataframe_substrait.py @@ -0,0 +1,4492 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import operator +import pathlib +import tempfile +import typing +from typing import Generator, List, Tuple + +import numpy as np +import pandas as pd +import pandas.testing +import pytest + +import bigframes +import bigframes._config.display_options as display_options +import bigframes.core.indexes as bf_indexes +import bigframes.dataframe as dataframe +import bigframes.pandas as bpd +import bigframes.series as series +from bigframes.testing.utils import ( + assert_dfs_equivalent, + assert_frame_equal, + assert_series_equal, + assert_series_equivalent, + convert_pandas_dtypes, +) + +pytest.importorskip("datafusion") +pytest.importorskip("pandas", minversion="2.0.0") + +CURRENT_DIR = pathlib.Path(__file__).parent +DATA_DIR = CURRENT_DIR.parent / "data" + + +@pytest.fixture(scope="module", autouse=True) +def session() -> Generator[bigframes.Session, None, None]: + import bigframes.core.global_session + from bigframes.testing import substrait_session + + session = substrait_session.TestSession() + with bigframes.core.global_session._GlobalSessionContext(session): + yield session + + +@pytest.fixture(scope="module") +def scalars_pandas_df_index() -> pd.DataFrame: + """pd.DataFrame pointing at test data.""" + + df = pd.read_json( + DATA_DIR / "scalars.jsonl", + lines=True, + ) + convert_pandas_dtypes(df, bytes_col=True) + + df = df.set_index("rowindex", drop=False) + df.index.name = None + return df.set_index("rowindex").sort_index() + + +@pytest.fixture(scope="module") +def scalars_df_index( + session: bigframes.Session, scalars_pandas_df_index +) -> bpd.DataFrame: + return session.read_pandas(scalars_pandas_df_index) + + +@pytest.fixture(scope="module") +def scalars_df_2_index( + session: bigframes.Session, scalars_pandas_df_index +) -> bpd.DataFrame: + return session.read_pandas(scalars_pandas_df_index) + + +@pytest.fixture(scope="module") +def scalars_dfs( + scalars_df_index, + scalars_pandas_df_index, +): + return scalars_df_index, scalars_pandas_df_index + + +def test_df_construct_copy(scalars_dfs): + columns = ["int64_col", "string_col", "float64_col"] + scalars_df, scalars_pandas_df = scalars_dfs + # Make the mapping from label to col_id non-trivial + bf_df = scalars_df.copy() + bf_df["int64_col"] = bf_df["int64_col"] / 2 + pd_df = scalars_pandas_df.copy() + pd_df["int64_col"] = pd_df["int64_col"] / 2 + + bf_result = dataframe.DataFrame(bf_df, columns=columns).to_pandas() + + pd_result = pd.DataFrame(pd_df, columns=columns) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_df_construct_pandas_default(scalars_dfs): + # This should trigger the inlined codepath + columns = [ + "int64_too", + "int64_col", + "float64_col", + "bool_col", + "string_col", + "date_col", + "datetime_col", + "numeric_col", + "float64_col", + "time_col", + "timestamp_col", + ] + _, scalars_pandas_df = scalars_dfs + bf_result = dataframe.DataFrame(scalars_pandas_df, columns=columns).to_pandas() + pd_result = pd.DataFrame(scalars_pandas_df, columns=columns) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_df_construct_structs(session): + pd_frame = pd.Series( + [ + {"version": 1, "project": "pandas"}, + {"version": 2, "project": "pandas"}, + {"version": 1, "project": "numpy"}, + ] + ).to_frame() + bf_series = session.read_pandas(pd_frame) + pd.testing.assert_frame_equal( + bf_series.to_pandas(), pd_frame, check_index_type=False, check_dtype=False + ) + + +def test_df_construct_pandas_set_dtype(scalars_dfs): + columns = [ + "int64_too", + "int64_col", + "float64_col", + "bool_col", + ] + _, scalars_pandas_df = scalars_dfs + bf_result = dataframe.DataFrame( + scalars_pandas_df, columns=columns, dtype="Float64" + ).to_pandas() + pd_result = pd.DataFrame(scalars_pandas_df, columns=columns, dtype="Float64") + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_df_construct_from_series(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = dataframe.DataFrame( + {"a": scalars_df["int64_col"], "b": scalars_df["string_col"]}, + dtype="string[pyarrow]", + ) + pd_result = pd.DataFrame( + {"a": scalars_pandas_df["int64_col"], "b": scalars_pandas_df["string_col"]}, + dtype="string[pyarrow]", + ) + assert_dfs_equivalent(pd_result, bf_result) + + +def test_df_construct_from_dict(): + input_dict = { + "Animal": ["Falcon", "Falcon", "Parrot", "Parrot"], + # With a space in column name. We use standardized SQL schema ids to solve the problem that BQ schema doesn't support column names with spaces. b/296751058 + "Max Speed": [380.0, 370.0, 24.0, 26.0], + } + bf_result = dataframe.DataFrame(input_dict).to_pandas() + pd_result = pd.DataFrame(input_dict) + + pandas.testing.assert_frame_equal( + bf_result, pd_result, check_dtype=False, check_index_type=False + ) + + +def test_df_construct_dtype(): + data = { + "int_col": [1, 2, 3], + "string_col": ["1.1", "2.0", "3.5"], + "float_col": [1.0, 2.0, 3.0], + } + dtype = pd.StringDtype(storage="pyarrow") + bf_result = dataframe.DataFrame(data, dtype=dtype) + pd_result = pd.DataFrame(data, dtype=dtype) + pd_result.index = pd_result.index.astype("Int64") + pandas.testing.assert_frame_equal(bf_result.to_pandas(), pd_result) + + +def test_get_column(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_name = "int64_col" + series = scalars_df[col_name] + bf_result = series.to_pandas() + pd_result = scalars_pandas_df[col_name] + assert_series_equal(bf_result, pd_result) + + +def test_get_column_nonstring(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + series = scalars_df.rename(columns={"int64_col": 123.1})[123.1] + bf_result = series.to_pandas() + pd_result = scalars_pandas_df.rename(columns={"int64_col": 123.1})[123.1] + assert_series_equal(bf_result, pd_result) + + +@pytest.mark.parametrize( + "row_slice", + [ + (slice(1, 7, 2)), + (slice(1, 7, None)), + (slice(None, -3, None)), + ], +) +def test_get_rows_with_slice(scalars_dfs, row_slice): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[row_slice].to_pandas() + pd_result = scalars_pandas_df[row_slice] + assert_frame_equal(bf_result, pd_result) + + +def test_hasattr(scalars_dfs): + scalars_df, _ = scalars_dfs + assert hasattr(scalars_df, "int64_col") + assert hasattr(scalars_df, "head") + assert not hasattr(scalars_df, "not_exist") + + +@pytest.mark.parametrize( + ("ordered"), + [ + (True), + (False), + ], +) +def test_head_with_custom_column_labels( + scalars_df_index, scalars_pandas_df_index, ordered +): + rename_mapping = { + "int64_col": "Integer Column", + "string_col": "言語列", + } + bf_df = scalars_df_index.rename(columns=rename_mapping).head(3) + bf_result = bf_df.to_pandas(ordered=ordered) + pd_result = scalars_pandas_df_index.rename(columns=rename_mapping).head(3) + assert_frame_equal(bf_result, pd_result, ignore_order=not ordered) + + +def test_tail_with_custom_column_labels(scalars_df_index, scalars_pandas_df_index): + rename_mapping = { + "int64_col": "Integer Column", + "string_col": "言語列", + } + bf_df = scalars_df_index.rename(columns=rename_mapping).tail(3) + bf_result = bf_df.to_pandas() + pd_result = scalars_pandas_df_index.rename(columns=rename_mapping).tail(3) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_get_column_by_attr(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + series = scalars_df.int64_col + bf_result = series.to_pandas() + pd_result = scalars_pandas_df.int64_col + assert_series_equal(bf_result, pd_result) + + +def test_get_columns(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_names = ["bool_col", "float64_col", "int64_col"] + df_subset = scalars_df.get(col_names) + df_pandas = df_subset.to_pandas() + pd.testing.assert_index_equal( + df_pandas.columns, scalars_pandas_df[col_names].columns + ) + + +def test_get_columns_default(scalars_dfs): + scalars_df, _ = scalars_dfs + col_names = ["not", "column", "names"] + result = scalars_df.get(col_names, "default_val") + assert result == "default_val" + + +@pytest.mark.parametrize( + ("loc", "column", "value", "allow_duplicates"), + [ + (0, 666, 2, False), + (5, "float64_col", 2.2, True), + (13, "rowindex_2", [8, 7, 6, 5, 4, 3, 2, 1, 0], True), + pytest.param( + 14, + "test", + 2, + False, + marks=pytest.mark.xfail( + raises=IndexError, + ), + ), + pytest.param( + 12, + "int64_col", + 2, + False, + marks=pytest.mark.xfail( + raises=ValueError, + ), + ), + ], +) +def test_insert(scalars_dfs, loc, column, value, allow_duplicates): + scalars_df, scalars_pandas_df = scalars_dfs + # insert works inplace, so will influence other tests. + # make a copy to avoid inplace changes. + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + bf_df.insert(loc, column, value, allow_duplicates) + pd_df.insert(loc, column, value, allow_duplicates) + + pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df, check_dtype=False) + + +def test_where_series_cond(scalars_df_index, scalars_pandas_df_index): + # Condition is dataframe, other is None (as default). + cond_bf = scalars_df_index["int64_col"] > 0 + cond_pd = scalars_pandas_df_index["int64_col"] > 0 + bf_result = scalars_df_index.where(cond_bf).to_pandas() + pd_result = scalars_pandas_df_index.where(cond_pd) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index): + cond_bf = scalars_df_index["int64_col"] > 0 + cond_pd = scalars_pandas_df_index["int64_col"] > 0 + + bf_df = scalars_df_index[["int64_too", "int64_col", "float64_col"]] + pd_df = scalars_pandas_df_index[["int64_too", "int64_col", "float64_col"]] + bf_result = bf_df.mask(cond_bf, bf_df + 1).to_pandas() + pd_result = pd_df.mask(cond_pd, pd_df + 1) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index): + # Test when a dataframe has multi-index or multi-columns. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + + dataframe_bf.columns = pd.MultiIndex.from_tuples( + [("str1", 1), ("str2", 2)], names=["STR", "INT"] + ) + cond_bf = dataframe_bf["str1"] > 0 + + with pytest.raises(NotImplementedError) as context: + dataframe_bf.where(cond_bf).to_pandas() + assert ( + str(context.value) + == "The dataframe.where() method does not support multi-column." + ) + + +def test_where_series_cond_const_other(scalars_df_index, scalars_pandas_df_index): + # Condition is a series, other is a constant. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + dataframe_bf.columns.name = "test_name" + dataframe_pd.columns.name = "test_name" + + cond_bf = dataframe_bf["int64_col"] > 0 + cond_pd = dataframe_pd["int64_col"] > 0 + other = 0 + + bf_result = dataframe_bf.where(cond_bf, other).to_pandas() + pd_result = dataframe_pd.where(cond_pd, other) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_series_cond_dataframe_other(scalars_df_index, scalars_pandas_df_index): + # Condition is a series, other is a dataframe. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + cond_bf = dataframe_bf["int64_col"] > 0 + cond_pd = dataframe_pd["int64_col"] > 0 + other_bf = -dataframe_bf + other_pd = -dataframe_pd + + bf_result = dataframe_bf.where(cond_bf, other_bf).to_pandas() + pd_result = dataframe_pd.where(cond_pd, other_pd) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_dataframe_cond(scalars_df_index, scalars_pandas_df_index): + # Condition is a dataframe, other is None. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + cond_bf = dataframe_bf > 0 + cond_pd = dataframe_pd > 0 + + bf_result = dataframe_bf.where(cond_bf, None).to_pandas() + pd_result = dataframe_pd.where(cond_pd, None) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_dataframe_cond_const_other(scalars_df_index, scalars_pandas_df_index): + # Condition is a dataframe, other is a constant. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + cond_bf = dataframe_bf > 0 + cond_pd = dataframe_pd > 0 + other_bf = 10 + other_pd = 10 + + bf_result = dataframe_bf.where(cond_bf, other_bf).to_pandas() + pd_result = dataframe_pd.where(cond_pd, other_pd) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_where_dataframe_cond_dataframe_other( + scalars_df_index, scalars_pandas_df_index +): + # Condition is a dataframe, other is a dataframe. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + dataframe_pd = scalars_pandas_df_index[columns] + + cond_bf = dataframe_bf > 0 + cond_pd = dataframe_pd > 0 + other_bf = dataframe_bf * 2 + other_pd = dataframe_pd * 2 + + bf_result = dataframe_bf.where(cond_bf, other_bf).to_pandas() + pd_result = dataframe_pd.where(cond_pd, other_pd) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_drop_column(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_name = "int64_col" + df_pandas = scalars_df.drop(columns=col_name).to_pandas() + pd.testing.assert_index_equal( + df_pandas.columns, scalars_pandas_df.drop(columns=col_name).columns + ) + + +def test_drop_columns(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_names = ["int64_col", "geography_col", "time_col"] + df_pandas = scalars_df.drop(columns=col_names).to_pandas() + pd.testing.assert_index_equal( + df_pandas.columns, scalars_pandas_df.drop(columns=col_names).columns + ) + + +def test_drop_labels_axis_1(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + labels = ["int64_col", "geography_col", "time_col"] + + pd_result = scalars_pandas_df.drop(labels=labels, axis=1) + bf_result = scalars_df.drop(labels=labels, axis=1).to_pandas() + + pd.testing.assert_frame_equal(pd_result, bf_result) + + +def test_drop_with_custom_column_labels(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + rename_mapping = { + "int64_col": "Integer Column", + "string_col": "言語列", + } + dropped_columns = [ + "言語列", + "timestamp_col", + ] + bf_df = scalars_df.rename(columns=rename_mapping).drop(columns=dropped_columns) + bf_result = bf_df.to_pandas() + pd_result = scalars_pandas_df.rename(columns=rename_mapping).drop( + columns=dropped_columns + ) + assert_frame_equal(bf_result, pd_result) + + +def test_df_memory_usage(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + pd_result = scalars_pandas_df.memory_usage() + bf_result = scalars_df.memory_usage() + + pd.testing.assert_series_equal(pd_result, bf_result, rtol=1.5) + + +def test_df_info(scalars_dfs): + expected = ( + "\n" + "Index: 9 entries, 0 to 8\n" + "Data columns (total 14 columns):\n" + " # Column Non-Null Count Dtype\n" + "--- ------------- ---------------- ------------------------------\n" + " 0 bool_col 8 non-null boolean\n" + " 1 bytes_col 6 non-null binary[pyarrow]\n" + " 2 date_col 7 non-null date32[day][pyarrow]\n" + " 3 datetime_col 6 non-null timestamp[us][pyarrow]\n" + " 4 geography_col 4 non-null geometry\n" + " 5 int64_col 8 non-null Int64\n" + " 6 int64_too 9 non-null Int64\n" + " 7 numeric_col 6 non-null decimal128(38, 9)[pyarrow]\n" + " 8 float64_col 7 non-null Float64\n" + " 9 rowindex_2 9 non-null Int64\n" + " 10 string_col 8 non-null string\n" + " 11 time_col 6 non-null time64[us][pyarrow]\n" + " 12 timestamp_col 6 non-null timestamp[us, tz=UTC][pyarrow]\n" + " 13 duration_col 7 non-null duration[us][pyarrow]\n" + "dtypes: Float64(1), Int64(3), binary[pyarrow](1), boolean(1), date32[day][pyarrow](1), decimal128(38, 9)[pyarrow](1), duration[us][pyarrow](1), geometry(1), string(1), time64[us][pyarrow](1), timestamp[us, tz=UTC][pyarrow](1), timestamp[us][pyarrow](1)\n" + "memory usage: 1341 bytes\n" + ) + + scalars_df, _ = scalars_dfs + bf_result = io.StringIO() + + scalars_df.info(buf=bf_result) + + assert expected == bf_result.getvalue() + + +@pytest.mark.parametrize( + ("include", "exclude"), + [ + ("Int64", None), + (["int"], None), + ("number", None), + ([pd.Int64Dtype(), pd.BooleanDtype()], None), + (None, [pd.Int64Dtype(), pd.BooleanDtype()]), + ("Int64", ["boolean"]), + ], +) +def test_select_dtypes(scalars_dfs, include, exclude): + scalars_df, scalars_pandas_df = scalars_dfs + + pd_result = scalars_pandas_df.select_dtypes(include=include, exclude=exclude) + bf_result = scalars_df.select_dtypes(include=include, exclude=exclude).to_pandas() + + pd.testing.assert_frame_equal(pd_result, bf_result) + + +def test_drop_index(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + pd_result = scalars_pandas_df.drop(index=[4, 1, 2]) + bf_result = scalars_df.drop(index=[4, 1, 2]).to_pandas() + + pd.testing.assert_frame_equal(pd_result, bf_result) + + +def test_drop_pandas_index(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + drop_index = scalars_pandas_df.iloc[[4, 1, 2]].index + + pd_result = scalars_pandas_df.drop(index=drop_index) + bf_result = scalars_df.drop(index=drop_index).to_pandas() + + pd.testing.assert_frame_equal(pd_result, bf_result) + + +def test_drop_bigframes_index(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + drop_index = scalars_df.loc[[4, 1, 2]].index + drop_pandas_index = scalars_pandas_df.loc[[4, 1, 2]].index + + pd_result = scalars_pandas_df.drop(index=drop_pandas_index) + bf_result = scalars_df.drop(index=drop_index).to_pandas() + + pd.testing.assert_frame_equal(pd_result, bf_result) + + +def test_drop_bigframes_index_with_na(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + scalars_df = scalars_df.copy() + scalars_pandas_df = scalars_pandas_df.copy() + scalars_df = scalars_df.set_index("bytes_col") + scalars_pandas_df = scalars_pandas_df.set_index("bytes_col") + drop_index = scalars_df.iloc[[2, 5]].index + drop_pandas_index = scalars_pandas_df.iloc[[2, 5]].index + + pd_result = scalars_pandas_df.drop(index=drop_pandas_index) # drop_pandas_index) + bf_result = scalars_df.drop(index=drop_index).to_pandas() + + pd.testing.assert_frame_equal(pd_result, bf_result) + + +def test_drop_bigframes_multiindex(scalars_dfs): + # TODO: supply a reason why this isn't compatible with pandas 1.x + pytest.importorskip("pandas", minversion="2.0.0") + scalars_df, scalars_pandas_df = scalars_dfs + scalars_df = scalars_df.copy() + scalars_pandas_df = scalars_pandas_df.copy() + sub_df = scalars_df.iloc[[4, 1, 2]] + sub_pandas_df = scalars_pandas_df.iloc[[4, 1, 2]] + sub_df = sub_df.set_index(["bytes_col", "numeric_col"]) + sub_pandas_df = sub_pandas_df.set_index(["bytes_col", "numeric_col"]) + drop_index = sub_df.index + drop_pandas_index = sub_pandas_df.index + + scalars_df = scalars_df.set_index(["bytes_col", "numeric_col"]) + scalars_pandas_df = scalars_pandas_df.set_index(["bytes_col", "numeric_col"]) + bf_result = scalars_df.drop(index=drop_index).to_pandas() + pd_result = scalars_pandas_df.drop(index=drop_pandas_index) + + pd.testing.assert_frame_equal(pd_result, bf_result) + + +def test_drop_labels_axis_0(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + pd_result = scalars_pandas_df.drop(labels=[4, 1, 2], axis=0) + bf_result = scalars_df.drop(labels=[4, 1, 2], axis=0).to_pandas() + + pd.testing.assert_frame_equal(pd_result, bf_result) + + +def test_drop_index_and_columns(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + pd_result = scalars_pandas_df.drop(index=[4, 1, 2], columns="int64_col") + bf_result = scalars_df.drop(index=[4, 1, 2], columns="int64_col").to_pandas() + + pd.testing.assert_frame_equal(pd_result, bf_result) + + +def test_rename(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_name_dict = {"bool_col": 1.2345} + df_pandas = scalars_df.rename(columns=col_name_dict).to_pandas() + pd.testing.assert_index_equal( + df_pandas.columns, scalars_pandas_df.rename(columns=col_name_dict).columns + ) + + +def test_df_peek(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + peek_result = scalars_df.peek(n=3, force=False, allow_large_results=True) + + pd.testing.assert_index_equal(scalars_pandas_df.columns, peek_result.columns) + assert len(peek_result) == 3 + + +def test_df_peek_with_large_results_not_allowed(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + peek_result = scalars_df.peek(n=3, force=False, allow_large_results=False) + + pd.testing.assert_index_equal(scalars_pandas_df.columns, peek_result.columns) + assert len(peek_result) == 3 + + +def test_df_peek_filtered(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + peek_result = scalars_df[scalars_df.int64_col != 0].peek(n=3, force=False) + pd.testing.assert_index_equal(scalars_pandas_df.columns, peek_result.columns) + assert len(peek_result) == 3 + + +def test_df_peek_force_default(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + peek_result = scalars_df[["int64_col", "int64_too"]].cumsum().peek(n=3) + pd.testing.assert_index_equal( + scalars_pandas_df[["int64_col", "int64_too"]].columns, peek_result.columns + ) + assert len(peek_result) == 3 + + +def test_df_peek_reset_index(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + peek_result = ( + scalars_df[["int64_col", "int64_too"]].reset_index(drop=True).peek(n=3) + ) + pd.testing.assert_index_equal( + scalars_pandas_df[["int64_col", "int64_too"]].columns, peek_result.columns + ) + assert len(peek_result) == 3 + + +def test_repr_w_all_rows(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + # Remove columns with flaky formatting, like NUMERIC columns (which use the + # object dtype). Also makes a copy so that mutating the index name doesn't + # break other tests. + scalars_df = scalars_df.drop(columns=["numeric_col"]) + scalars_pandas_df = scalars_pandas_df.drop(columns=["numeric_col"]) + + # When there are 10 or fewer rows, the outputs should be identical. + actual = repr(scalars_df.head(10)) + + with display_options.pandas_repr(bigframes.options.display): + expected = repr(scalars_pandas_df.head(10)) + + assert actual == expected + + +def test_join_repr(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + scalars_df = ( + scalars_df[["int64_col"]] + .join(scalars_df.set_index("int64_col")[["int64_too"]]) + .sort_index() + ) + scalars_pandas_df = ( + scalars_pandas_df[["int64_col"]] + .join(scalars_pandas_df.set_index("int64_col")[["int64_too"]]) + .sort_index() + ) + # Pandas join result index name seems to depend on the index values in a way that bigframes can't match exactly + scalars_pandas_df.index.name = None + + actual = repr(scalars_df) + + with display_options.pandas_repr(bigframes.options.display): + expected = repr(scalars_pandas_df) + + assert actual == expected + + +def test_mimebundle_html_repr_w_all_rows(scalars_dfs, session): + scalars_df, _ = scalars_dfs + # get a pandas df of the expected format + df, _ = scalars_df._block.to_pandas() + pandas_df = df.set_axis(scalars_df._block.column_labels, axis=1) + pandas_df.index.name = scalars_df.index.name + + # When there are 10 or fewer rows, the outputs should be identical except for the extra note. + bundle = scalars_df.head(10)._repr_mimebundle_() + actual = bundle["text/html"] + + with display_options.pandas_repr(bigframes.options.display): + pandas_repr = pandas_df.head(10)._repr_html_() + + expected = ( + pandas_repr + + f"[{len(pandas_df.index)} rows x {len(pandas_df.columns)} columns in total]" + ) + assert actual == expected + + +def test_df_column_name_with_space(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_name_dict = {"bool_col": "bool col"} + df_pandas = scalars_df.rename(columns=col_name_dict).to_pandas() + pd.testing.assert_index_equal( + df_pandas.columns, scalars_pandas_df.rename(columns=col_name_dict).columns + ) + + +def test_df_column_name_duplicate(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_name_dict = {"int64_too": "int64_col"} + df_pandas = scalars_df.rename(columns=col_name_dict).to_pandas() + pd.testing.assert_index_equal( + df_pandas.columns, scalars_pandas_df.rename(columns=col_name_dict).columns + ) + + +def test_get_df_column_name_duplicate(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_name_dict = {"int64_too": "int64_col"} + + bf_result = scalars_df.rename(columns=col_name_dict)["int64_col"].to_pandas() + pd_result = scalars_pandas_df.rename(columns=col_name_dict)["int64_col"] + pd.testing.assert_index_equal(bf_result.columns, pd_result.columns) + + +@pytest.mark.parametrize( + ("indices", "axis"), + [ + ([1, 3, 5], 0), + ([2, 4, 6], 1), + ([1, -3, -5, -6], "index"), + ([-2, -4, -6], "columns"), + ], +) +def test_take_df(scalars_dfs, indices, axis): + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = scalars_df.take(indices, axis=axis).to_pandas() + pd_result = scalars_pandas_df.take(indices, axis=axis) + + assert_frame_equal(bf_result, pd_result) + + +def test_filter_df(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + bf_bool_series = scalars_df["bool_col"] + bf_result = scalars_df[bf_bool_series].to_pandas() + + pd_bool_series = scalars_pandas_df["bool_col"] + pd_result = scalars_pandas_df[pd_bool_series] + + assert_frame_equal(bf_result, pd_result) + + +def test_assign_new_column(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + kwargs = {"new_col": 2} + df = scalars_df.assign(**kwargs) + bf_result = df.to_pandas() + pd_result = scalars_pandas_df.assign(**kwargs) + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result["new_col"] = pd_result["new_col"].astype("Int64") + + assert_frame_equal(bf_result, pd_result) + + +def test_assign_using_pd_col(scalars_dfs): + if pd.__version__.startswith("1.") or pd.__version__.startswith("2."): + pytest.skip("col expression interface only supported for pandas 3+") + scalars_df, scalars_pandas_df = scalars_dfs + bf_kwargs = { + "new_col_1": 4 - bpd.col("int64_col"), + "new_col_2": bpd.col("int64_col") / (bpd.col("float64_col") * 0.5), + } + pd_kwargs = { + "new_col_1": 4 - pd.col("int64_col"), # type: ignore + "new_col_2": pd.col("int64_col") / (pd.col("float64_col") * 0.5), # type: ignore + } + + df = scalars_df.assign(**bf_kwargs) + bf_result = df.to_pandas() + pd_result = scalars_pandas_df.assign(**pd_kwargs) + + assert_frame_equal(bf_result, pd_result) + + +def test_assign_new_column_w_loc(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + bf_df.loc[:, "new_col"] = 2 + pd_df.loc[:, "new_col"] = 2 + bf_result = bf_df.to_pandas() + pd_result = pd_df + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result["new_col"] = pd_result["new_col"].astype("Int64") + + pd.testing.assert_frame_equal(bf_result, pd_result) + + +@pytest.mark.parametrize( + ("scalar",), + [ + (2.1,), + (None,), + ], +) +def test_assign_new_column_w_setitem(scalars_dfs, scalar): + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + bf_df["new_col"] = scalar + pd_df["new_col"] = scalar + bf_result = bf_df.to_pandas() + pd_result = pd_df + + # Convert default pandas dtypes `float64` to match BigQuery DataFrames dtypes. + pd_result["new_col"] = pd_result["new_col"].astype("Float64") + + pd.testing.assert_frame_equal(bf_result, pd_result) + + +def test_assign_new_column_w_setitem_dataframe(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + bf_df["int64_col"] = bf_df["int64_too"].to_frame() + pd_df["int64_col"] = pd_df["int64_too"].to_frame() + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_df["int64_col"] = pd_df["int64_col"].astype("Int64") + + pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df) + + +def test_assign_new_column_w_setitem_dataframe_error(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + + with pytest.raises(ValueError): + bf_df["impossible_col"] = bf_df[["int64_too", "string_col"]] + with pytest.raises(ValueError): + pd_df["impossible_col"] = pd_df[["int64_too", "string_col"]] + + +def test_assign_new_column_w_setitem_list(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + bf_df["new_col"] = [9, 8, 7, 6, 5, 4, 3, 2, 1] + pd_df["new_col"] = [9, 8, 7, 6, 5, 4, 3, 2, 1] + bf_result = bf_df.to_pandas() + pd_result = pd_df + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result["new_col"] = pd_result["new_col"].astype("Int64") + + pd.testing.assert_frame_equal(bf_result, pd_result) + + +def test_assign_new_column_w_setitem_list_repeated(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + bf_df["new_col"] = [9, 8, 7, 6, 5, 4, 3, 2, 1] + pd_df["new_col"] = [9, 8, 7, 6, 5, 4, 3, 2, 1] + bf_df["new_col_2"] = [1, 3, 2, 5, 4, 7, 6, 9, 8] + pd_df["new_col_2"] = [1, 3, 2, 5, 4, 7, 6, 9, 8] + bf_result = bf_df.to_pandas() + pd_result = pd_df + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result["new_col"] = pd_result["new_col"].astype("Int64") + pd_result["new_col_2"] = pd_result["new_col_2"].astype("Int64") + + pd.testing.assert_frame_equal(bf_result, pd_result) + + +def test_assign_new_column_w_setitem_list_custom_index(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + + # set the custom index + pd_df = pd_df.set_index(["string_col", "int64_col"]) + bf_df = bf_df.set_index(["string_col", "int64_col"]) + + bf_df["new_col"] = [9, 8, 7, 6, 5, 4, 3, 2, 1] + pd_df["new_col"] = [9, 8, 7, 6, 5, 4, 3, 2, 1] + bf_result = bf_df.to_pandas() + pd_result = pd_df + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result["new_col"] = pd_result["new_col"].astype("Int64") + + pd.testing.assert_frame_equal(bf_result, pd_result) + + +def test_assign_new_column_w_setitem_list_error(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + + with pytest.raises(ValueError): + pd_df["new_col"] = [1, 2, 3] # should be len 9, is 3 + with pytest.raises(ValueError): + bf_df["new_col"] = [1, 2, 3] + + +def test_assign_existing_column(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + kwargs = {"int64_col": 2} + df = scalars_df.assign(**kwargs) + bf_result = df.to_pandas() + pd_result = scalars_pandas_df.assign(**kwargs) + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result["int64_col"] = pd_result["int64_col"].astype("Int64") + + assert_frame_equal(bf_result, pd_result) + + +def test_assign_listlike_to_empty_df(session): + empty_df = dataframe.DataFrame(session=session) + empty_pandas_df = pd.DataFrame() + + bf_result = empty_df.assign(new_col=[1, 2, 3]) + pd_result = empty_pandas_df.assign(new_col=[1, 2, 3]) + + pd_result["new_col"] = pd_result["new_col"].astype("Int64") + pd_result.index = pd_result.index.astype("Int64") + assert_frame_equal(bf_result.to_pandas(), pd_result) + + +def test_assign_to_empty_df_multiindex_error(session): + empty_df = dataframe.DataFrame(session=session) + empty_pandas_df = pd.DataFrame() + + empty_df["empty_col_1"] = typing.cast(series.Series, []) + empty_df["empty_col_2"] = typing.cast(series.Series, []) + empty_pandas_df["empty_col_1"] = [] + empty_pandas_df["empty_col_2"] = [] + empty_df = empty_df.set_index(["empty_col_1", "empty_col_2"]) + empty_pandas_df = empty_pandas_df.set_index(["empty_col_1", "empty_col_2"]) + + with pytest.raises(ValueError): + empty_df.assign(new_col=[1, 2, 3, 4, 5, 6, 7, 8, 9]) + with pytest.raises(ValueError): + empty_pandas_df.assign(new_col=[1, 2, 3, 4, 5, 6, 7, 8, 9]) + + +@pytest.mark.parametrize( + ("ordered"), + [ + (True), + (False), + ], +) +def test_assign_series(scalars_dfs, ordered): + scalars_df, scalars_pandas_df = scalars_dfs + column_name = "int64_col" + df = scalars_df.assign(new_col=scalars_df[column_name]) + bf_result = df.to_pandas(ordered=ordered) + pd_result = scalars_pandas_df.assign(new_col=scalars_pandas_df[column_name]) + + assert_frame_equal(bf_result, pd_result, ignore_order=not ordered) + + +def test_assign_series_overwrite(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + column_name = "int64_col" + df = scalars_df.assign(**{column_name: scalars_df[column_name] + 3}) + bf_result = df.to_pandas() + pd_result = scalars_pandas_df.assign( + **{column_name: scalars_pandas_df[column_name] + 3} + ) + + assert_frame_equal(bf_result, pd_result) + + +def test_assign_sequential(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + kwargs = {"int64_col": 2, "new_col": 3, "new_col2": 4} + df = scalars_df.assign(**kwargs) + bf_result = df.to_pandas() + pd_result = scalars_pandas_df.assign(**kwargs) + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result["int64_col"] = pd_result["int64_col"].astype("Int64") + pd_result["new_col"] = pd_result["new_col"].astype("Int64") + pd_result["new_col2"] = pd_result["new_col2"].astype("Int64") + + assert_frame_equal(bf_result, pd_result) + + +# Require an index so that the self-join is consistent each time. +def test_assign_same_table_different_index_performs_self_join( + scalars_df_index, scalars_pandas_df_index +): + column_name = "int64_col" + bf_df = scalars_df_index.assign( + alternative_index=scalars_df_index["rowindex_2"] + 2 + ) + pd_df = scalars_pandas_df_index.assign( + alternative_index=scalars_pandas_df_index["rowindex_2"] + 2 + ) + bf_df_2 = bf_df.set_index("alternative_index") + pd_df_2 = pd_df.set_index("alternative_index") + bf_result = bf_df.assign(new_col=bf_df_2[column_name] * 10).to_pandas() + pd_result = pd_df.assign(new_col=pd_df_2[column_name] * 10) + + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +# Different table expression must have Index +def test_assign_different_df( + scalars_df_index, scalars_df_2_index, scalars_pandas_df_index +): + column_name = "int64_col" + df = scalars_df_index.assign(new_col=scalars_df_2_index[column_name]) + bf_result = df.to_pandas() + # Doesn't matter to pandas if it comes from the same DF or a different DF. + pd_result = scalars_pandas_df_index.assign( + new_col=scalars_pandas_df_index[column_name] + ) + + assert_frame_equal(bf_result, pd_result) + + +def test_assign_different_df_w_loc( + scalars_df_index, scalars_df_2_index, scalars_pandas_df_index +): + bf_df = scalars_df_index.copy() + bf_df2 = scalars_df_2_index.copy() + pd_df = scalars_pandas_df_index.copy() + assert "int64_col" in bf_df.columns + assert "int64_col" in pd_df.columns + bf_df.loc[:, "int64_col"] = bf_df2.loc[:, "int64_col"] + 1 + pd_df.loc[:, "int64_col"] = pd_df.loc[:, "int64_col"] + 1 + bf_result = bf_df.to_pandas() + pd_result = pd_df + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result["int64_col"] = pd_result["int64_col"].astype("Int64") + + pd.testing.assert_frame_equal(bf_result, pd_result) + + +def test_assign_different_df_w_setitem( + scalars_df_index, scalars_df_2_index, scalars_pandas_df_index +): + bf_df = scalars_df_index.copy() + bf_df2 = scalars_df_2_index.copy() + pd_df = scalars_pandas_df_index.copy() + assert "int64_col" in bf_df.columns + assert "int64_col" in pd_df.columns + bf_df["int64_col"] = bf_df2["int64_col"] + 1 + pd_df["int64_col"] = pd_df["int64_col"] + 1 + bf_result = bf_df.to_pandas() + pd_result = pd_df + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result["int64_col"] = pd_result["int64_col"].astype("Int64") + + pd.testing.assert_frame_equal(bf_result, pd_result) + + +def test_assign_callable_lambda(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + kwargs = {"new_col": lambda x: x["int64_col"] + x["int64_too"]} + df = scalars_df.assign(**kwargs) + bf_result = df.to_pandas() + pd_result = scalars_pandas_df.assign(**kwargs) + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result["new_col"] = pd_result["new_col"].astype("Int64") + + assert_frame_equal(bf_result, pd_result) + + +@pytest.mark.parametrize( + ("axis", "how", "ignore_index", "subset"), + [ + (0, "any", False, None), + (0, "any", True, None), + (0, "all", False, ["bool_col", "time_col"]), + (0, "any", False, ["bool_col", "time_col"]), + (0, "all", False, "time_col"), + (1, "any", False, None), + (1, "all", False, None), + ], +) +def test_df_dropna(scalars_dfs, axis, how, ignore_index, subset): + # TODO: supply a reason why this isn't compatible with pandas 1.x + pytest.importorskip("pandas", minversion="2.0.0") + scalars_df, scalars_pandas_df = scalars_dfs + df = scalars_df.dropna(axis=axis, how=how, ignore_index=ignore_index, subset=subset) + bf_result = df.to_pandas() + pd_result = scalars_pandas_df.dropna( + axis=axis, how=how, ignore_index=ignore_index, subset=subset + ) + + # Pandas uses int64 instead of Int64 (nullable) dtype. + pd_result.index = pd_result.index.astype(pd.Int64Dtype()) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_df_dropna_range_columns(scalars_dfs): + # TODO: supply a reason why this isn't compatible with pandas 1.x + pytest.importorskip("pandas", minversion="2.0.0") + scalars_df, scalars_pandas_df = scalars_dfs + scalars_df = scalars_df.copy() + scalars_pandas_df = scalars_pandas_df.copy() + scalars_df.columns = pandas.RangeIndex(0, len(scalars_df.columns)) + scalars_pandas_df.columns = pandas.RangeIndex(0, len(scalars_pandas_df.columns)) + + df = scalars_df.dropna() + bf_result = df.to_pandas() + pd_result = scalars_pandas_df.dropna() + + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_df_interpolate(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + columns = ["int64_col", "int64_too", "float64_col"] + bf_result = scalars_df[columns].interpolate().to_pandas() + # Pandas can only interpolate on "float64" columns + # https://github.com/pandas-dev/pandas/issues/40252 + pd_result = scalars_pandas_df[columns].astype("float64").interpolate() + + pandas.testing.assert_frame_equal( + bf_result, + pd_result, + check_index_type=False, + check_dtype=False, + ) + + +@pytest.mark.parametrize( + "col, fill_value", + [ + (["int64_col", "float64_col"], 3), + (["string_col"], "A"), + (["datetime_col"], pd.Timestamp("2023-01-01")), + ], +) +def test_df_fillna(scalars_dfs, col, fill_value): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[col].fillna(fill_value).to_pandas() + pd_result = scalars_pandas_df[col].fillna(fill_value) + + pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + + +@pytest.mark.skip("b/436316698 unit test failed for python 3.12") +def test_df_ffill(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[["int64_col", "float64_col"]].ffill(limit=1).to_pandas() + pd_result = scalars_pandas_df[["int64_col", "float64_col"]].ffill(limit=1) + + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_df_bfill(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[["int64_col", "float64_col"]].bfill().to_pandas() + pd_result = scalars_pandas_df[["int64_col", "float64_col"]].bfill() + + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_apply_series_series_callable( + scalars_df_index, + scalars_pandas_df_index, +): + columns = ["int64_too", "int64_col"] + + def foo(series, arg1, arg2, *, kwarg1=0, kwarg2=0): + return series**2 + (arg1 * arg2 % 4) + (kwarg1 * kwarg2 % 7) + + bf_result = ( + scalars_df_index[columns] + .apply(foo, args=(33, 61), kwarg1=52, kwarg2=21) + .to_pandas() + ) + + pd_result = scalars_pandas_df_index[columns].apply( + foo, args=(33, 61), kwarg1=52, kwarg2=21 + ) + + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_apply_series_listlike_callable( + scalars_df_index, + scalars_pandas_df_index, +): + columns = ["int64_too", "int64_col"] + bf_result = ( + scalars_df_index[columns].apply(lambda x: [len(x), x.min(), 24]).to_pandas() + ) + + pd_result = scalars_pandas_df_index[columns].apply(lambda x: [len(x), x.min(), 24]) + + # Convert default pandas dtypes `int64` to match BigQuery DataFrames dtypes. + pd_result.index = pd_result.index.astype("Int64") + pd_result = pd_result.astype("Int64") + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_apply_series_scalar_callable( + scalars_df_index, + scalars_pandas_df_index, +): + columns = ["int64_too", "int64_col"] + bf_result = scalars_df_index[columns].apply(lambda x: x.sum()) + + pd_result = scalars_pandas_df_index[columns].apply(lambda x: x.sum()) + + pandas.testing.assert_series_equal(bf_result, pd_result) + + +def test_df_pipe( + scalars_df_index, + scalars_pandas_df_index, +): + columns = ["int64_too", "int64_col"] + + def foo(x: int, y: int, df): + return (df + x) % y + + bf_result = ( + scalars_df_index[columns] + .pipe((foo, "df"), x=7, y=9) + .pipe(lambda x: x**2) + .to_pandas() + ) + + pd_result = ( + scalars_pandas_df_index[columns] + .pipe((foo, "df"), x=7, y=9) + .pipe(lambda x: x**2) + ) + + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_df_keys( + scalars_df_index, + scalars_pandas_df_index, +): + pandas.testing.assert_index_equal( + scalars_df_index.keys(), scalars_pandas_df_index.keys() + ) + + +def test_df_iter( + scalars_df_index, + scalars_pandas_df_index, +): + for bf_i, df_i in zip(scalars_df_index, scalars_pandas_df_index): + assert bf_i == df_i + + +def test_iterrows( + scalars_df_index, + scalars_pandas_df_index, +): + # TODO: supply a reason why this isn't compatible with pandas 1.x + pytest.importorskip("pandas", minversion="2.0.0") + scalars_df_index = scalars_df_index.add_suffix("_suffix", axis=1) + scalars_pandas_df_index = scalars_pandas_df_index.add_suffix("_suffix", axis=1) + for (bf_index, bf_series), (pd_index, pd_series) in zip( + scalars_df_index.iterrows(), scalars_pandas_df_index.iterrows() + ): + assert bf_index == pd_index + pandas.testing.assert_series_equal(bf_series, pd_series) + + +@pytest.mark.parametrize( + ( + "index", + "name", + ), + [ + ( + True, + "my_df", + ), + (False, None), + ], +) +def test_itertuples(scalars_df_index, index, name): + # Numeric has slightly different representation as a result of conversions. + bf_tuples = scalars_df_index.itertuples(index, name) + pd_tuples = scalars_df_index.to_pandas().itertuples(index, name) + for bf_tuple, pd_tuple in zip(bf_tuples, pd_tuples): + assert bf_tuple == pd_tuple + + +def test_df_cross_merge(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + left_columns = ["int64_col", "float64_col", "rowindex_2"] + right_columns = ["int64_col", "bool_col", "string_col", "rowindex_2"] + + left = scalars_df[left_columns] + # Offset the rows somewhat so that outer join can have an effect. + right = scalars_df[right_columns].assign(rowindex_2=scalars_df["rowindex_2"] + 2) + + bf_result = left.merge(right, "cross").to_pandas() + + pd_result = scalars_pandas_df[left_columns].merge( + scalars_pandas_df[right_columns].assign( + rowindex_2=scalars_pandas_df["rowindex_2"] + 2 + ), + "cross", + ) + pd.testing.assert_frame_equal(bf_result, pd_result, check_index_type=False) + + +@pytest.mark.parametrize( + ("merge_how",), + [ + ("inner",), + ("outer",), + ("left",), + ("right",), + ], +) +def test_df_merge(scalars_dfs, merge_how): + scalars_df, scalars_pandas_df = scalars_dfs + on = "rowindex_2" + left_columns = ["int64_col", "float64_col", "rowindex_2"] + right_columns = ["int64_col", "bool_col", "string_col", "rowindex_2"] + + left = scalars_df[left_columns] + # Offset the rows somewhat so that outer join can have an effect. + right = scalars_df[right_columns].assign(rowindex_2=scalars_df["rowindex_2"] + 2) + + df = left.merge(right, merge_how, on, sort=True) + bf_result = df.to_pandas() + + pd_result = scalars_pandas_df[left_columns].merge( + scalars_pandas_df[right_columns].assign( + rowindex_2=scalars_pandas_df["rowindex_2"] + 2 + ), + merge_how, + on, + sort=True, + ) + + assert_frame_equal(bf_result, pd_result, ignore_order=True, check_index_type=False) + + +@pytest.mark.parametrize( + ("left_on", "right_on"), + [ + (["int64_col", "rowindex_2"], ["int64_col", "rowindex_2"]), + (["rowindex_2", "int64_col"], ["int64_col", "rowindex_2"]), + # Polars engine is currently strict on join key types + # (["rowindex_2", "float64_col"], ["int64_col", "rowindex_2"]), + ], +) +def test_df_merge_multi_key(scalars_dfs, left_on, right_on): + scalars_df, scalars_pandas_df = scalars_dfs + left_columns = ["int64_col", "float64_col", "rowindex_2"] + right_columns = ["int64_col", "bool_col", "string_col", "rowindex_2"] + + left = scalars_df[left_columns] + # Offset the rows somewhat so that outer join can have an effect. + right = scalars_df[right_columns].assign(rowindex_2=scalars_df["rowindex_2"] + 2) + + df = left.merge(right, "outer", left_on=left_on, right_on=right_on, sort=True) + bf_result = df.to_pandas() + + pd_result = scalars_pandas_df[left_columns].merge( + scalars_pandas_df[right_columns].assign( + rowindex_2=scalars_pandas_df["rowindex_2"] + 2 + ), + "outer", + left_on=left_on, + right_on=right_on, + sort=True, + ) + + assert_frame_equal(bf_result, pd_result, ignore_order=True, check_index_type=False) + + +@pytest.mark.parametrize( + ("merge_how",), + [ + ("inner",), + ("outer",), + ("left",), + ("right",), + ], +) +def test_merge_custom_col_name(scalars_dfs, merge_how): + scalars_df, scalars_pandas_df = scalars_dfs + left_columns = ["int64_col", "float64_col"] + right_columns = ["int64_col", "bool_col", "string_col"] + on = "int64_col" + rename_columns = {"float64_col": "f64_col"} + + left = scalars_df[left_columns] + left = left.rename(columns=rename_columns) + right = scalars_df[right_columns] + df = left.merge(right, merge_how, on, sort=True) + bf_result = df.to_pandas() + + pandas_left_df = scalars_pandas_df[left_columns] + pandas_left_df = pandas_left_df.rename(columns=rename_columns) + pandas_right_df = scalars_pandas_df[right_columns] + pd_result = pandas_left_df.merge(pandas_right_df, merge_how, on, sort=True) + + assert_frame_equal(bf_result, pd_result, ignore_order=True, check_index_type=False) + + +@pytest.mark.parametrize( + ("merge_how",), + [ + ("inner",), + ("outer",), + ("left",), + ("right",), + ], +) +def test_merge_left_on_right_on(scalars_dfs, merge_how): + scalars_df, scalars_pandas_df = scalars_dfs + left_columns = ["int64_col", "float64_col", "int64_too"] + right_columns = ["int64_col", "bool_col", "string_col", "rowindex_2"] + + left = scalars_df[left_columns] + right = scalars_df[right_columns] + + df = left.merge( + right, merge_how, left_on="int64_too", right_on="rowindex_2", sort=True + ) + bf_result = df.to_pandas() + + pd_result = scalars_pandas_df[left_columns].merge( + scalars_pandas_df[right_columns], + merge_how, + left_on="int64_too", + right_on="rowindex_2", + sort=True, + ) + + assert_frame_equal(bf_result, pd_result, ignore_order=True, check_index_type=False) + + +def test_shape(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df.shape + pd_result = scalars_pandas_df.shape + + assert bf_result == pd_result + + +def test_len(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = len(scalars_df) + pd_result = len(scalars_pandas_df) + + assert bf_result == pd_result + + +@pytest.mark.parametrize( + ("n_rows",), + [ + (50,), + (10000,), + ], +) +def test_df_len_local(session, n_rows): + assert ( + len( + session.read_pandas( + pd.DataFrame(np.random.randint(1, 7, n_rows), columns=["one"]), + ) + ) + == n_rows + ) + + +def test_size(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df.size + pd_result = scalars_pandas_df.size + + assert bf_result == pd_result + + +def test_ndim(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df.ndim + pd_result = scalars_pandas_df.ndim + + assert bf_result == pd_result + + +def test_empty_false(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = scalars_df.empty + pd_result = scalars_pandas_df.empty + + assert bf_result == pd_result + + +def test_empty_true_column_filter(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = scalars_df[[]].empty + pd_result = scalars_pandas_df[[]].empty + + assert bf_result == pd_result + + +def test_empty_true_row_filter(scalars_dfs: Tuple[dataframe.DataFrame, pd.DataFrame]): + scalars_df, scalars_pandas_df = scalars_dfs + bf_bool: series.Series = typing.cast(series.Series, scalars_df["bool_col"]) + pd_bool: pd.Series = scalars_pandas_df["bool_col"] + bf_false = bf_bool.notna() & (bf_bool != bf_bool) + pd_false = pd_bool.notna() & (pd_bool != pd_bool) + + bf_result = scalars_df[bf_false].empty + pd_result = scalars_pandas_df[pd_false].empty + + assert pd_result + assert bf_result == pd_result + + +def test_empty_true_memtable(session: bigframes.Session): + bf_df = dataframe.DataFrame(session=session) + pd_df = pd.DataFrame() + + bf_result = bf_df.empty + pd_result = pd_df.empty + + assert pd_result + assert bf_result == pd_result + + +@pytest.mark.parametrize( + ("drop",), + ((True,), (False,)), +) +def test_reset_index(scalars_df_index, scalars_pandas_df_index, drop): + df = scalars_df_index.reset_index(drop=drop) + assert df.index.name is None + + bf_result = df.to_pandas() + pd_result = scalars_pandas_df_index.reset_index(drop=drop) + + # Pandas uses int64 instead of Int64 (nullable) dtype. + pd_result.index = pd_result.index.astype(pd.Int64Dtype()) + + # reset_index should maintain the original ordering. + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_reset_index_then_filter( + scalars_df_index, + scalars_pandas_df_index, +): + bf_filter = scalars_df_index["bool_col"].fillna(True) + bf_df = scalars_df_index.reset_index()[bf_filter] + bf_result = bf_df.to_pandas() + pd_filter = scalars_pandas_df_index["bool_col"].fillna(True) + pd_result = scalars_pandas_df_index.reset_index()[pd_filter] + + # Pandas uses int64 instead of Int64 (nullable) dtype. + pd_result.index = pd_result.index.astype(pd.Int64Dtype()) + + # reset_index should maintain the original ordering and index keys + # post-filter will have gaps. + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_reset_index_with_unnamed_index( + scalars_df_index, + scalars_pandas_df_index, +): + scalars_df_index = scalars_df_index.copy() + scalars_pandas_df_index = scalars_pandas_df_index.copy() + + scalars_df_index.index.name = None + scalars_pandas_df_index.index.name = None + df = scalars_df_index.reset_index(drop=False) + assert df.index.name is None + + # reset_index(drop=False) creates a new column "index". + assert df.columns[0] == "index" + + bf_result = df.to_pandas() + pd_result = scalars_pandas_df_index.reset_index(drop=False) + + # Pandas uses int64 instead of Int64 (nullable) dtype. + pd_result.index = pd_result.index.astype(pd.Int64Dtype()) + + # reset_index should maintain the original ordering. + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_reset_index_with_unnamed_multiindex(session): + bf_df = dataframe.DataFrame( + ([1, 2, 3], [2, 5, 7]), + index=pd.MultiIndex.from_tuples([("a", "aa"), ("a", "aa")]), + session=session, + ) + pd_df = pd.DataFrame( + ([1, 2, 3], [2, 5, 7]), + index=pd.MultiIndex.from_tuples([("a", "aa"), ("a", "aa")]), + ) + + bf_df = bf_df.reset_index() + pd_df = pd_df.reset_index() + + assert pd_df.columns[0] == "level_0" + assert bf_df.columns[0] == "level_0" + assert pd_df.columns[1] == "level_1" + assert bf_df.columns[1] == "level_1" + + +def test_reset_index_with_unnamed_index_and_index_column( + scalars_df_index, + scalars_pandas_df_index, +): + scalars_df_index = scalars_df_index.copy() + scalars_pandas_df_index = scalars_pandas_df_index.copy() + + scalars_df_index.index.name = None + scalars_pandas_df_index.index.name = None + df = scalars_df_index.assign(index=scalars_df_index["int64_col"]).reset_index( + drop=False + ) + assert df.index.name is None + + # reset_index(drop=False) creates a new column "level_0" if the "index" column already exists. + assert df.columns[0] == "level_0" + + bf_result = df.to_pandas() + pd_result = scalars_pandas_df_index.assign( + index=scalars_pandas_df_index["int64_col"] + ).reset_index(drop=False) + + # Pandas uses int64 instead of Int64 (nullable) dtype. + pd_result.index = pd_result.index.astype(pd.Int64Dtype()) + + # reset_index should maintain the original ordering. + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +@pytest.mark.parametrize( + ("drop",), + ( + (True,), + (False,), + ), +) +@pytest.mark.parametrize( + ("append",), + ( + (True,), + (False,), + ), +) +@pytest.mark.parametrize( + ("index_column",), + (("int64_too",), ("string_col",), ("timestamp_col",)), +) +def test_set_index(scalars_dfs, index_column, drop, append): + scalars_df, scalars_pandas_df = scalars_dfs + df = scalars_df.set_index(index_column, append=append, drop=drop) + bf_result = df.to_pandas() + pd_result = scalars_pandas_df.set_index(index_column, append=append, drop=drop) + + # Sort to disambiguate when there are duplicate index labels. + # Note: Doesn't use assert_pandas_df_equal_ignore_ordering because we get + # "ValueError: 'timestamp_col' is both an index level and a column label, + # which is ambiguous" when trying to sort by a column with the same name as + # the index. + bf_result = bf_result.sort_values("rowindex_2") + pd_result = pd_result.sort_values("rowindex_2") + + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_set_index_key_error(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + with pytest.raises(KeyError): + scalars_pandas_df.set_index(["not_a_col"]) + with pytest.raises(KeyError): + scalars_df.set_index(["not_a_col"]) + + +@pytest.mark.parametrize( + ("ascending",), + ((True,), (False,)), +) +@pytest.mark.parametrize( + ("na_position",), + (("first",), ("last",)), +) +@pytest.mark.parametrize( + ("axis",), + ((0,), ("columns",)), +) +def test_sort_index(scalars_dfs, ascending, na_position, axis): + index_column = "int64_col" + scalars_df, scalars_pandas_df = scalars_dfs + df = scalars_df.set_index(index_column) + bf_result = df.sort_index( + ascending=ascending, na_position=na_position, axis=axis + ).to_pandas() + pd_result = scalars_pandas_df.set_index(index_column).sort_index( + ascending=ascending, na_position=na_position, axis=axis + ) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_dataframe_sort_index_inplace(scalars_dfs): + index_column = "int64_col" + scalars_df, scalars_pandas_df = scalars_dfs + df = scalars_df.copy().set_index(index_column) + df.sort_index(ascending=False, inplace=True) + bf_result = df.to_pandas() + + pd_result = scalars_pandas_df.set_index(index_column).sort_index(ascending=False) + pandas.testing.assert_frame_equal(bf_result, pd_result) + + +def test_df_abs(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + columns = ["int64_col", "int64_too", "float64_col"] + + bf_result = scalars_df[columns].abs() + pd_result = scalars_pandas_df[columns].abs() + + assert_dfs_equivalent(pd_result, bf_result) + + +def test_df_pos(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = (+scalars_df[["int64_col", "numeric_col"]]).to_pandas() + pd_result = +scalars_pandas_df[["int64_col", "numeric_col"]] + + assert_frame_equal(pd_result, bf_result) + + +def test_df_neg(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = (-scalars_df[["int64_col", "numeric_col"]]).to_pandas() + pd_result = -scalars_pandas_df[["int64_col", "numeric_col"]] + + assert_frame_equal(pd_result, bf_result) + + +def test_df_invert(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + columns = ["int64_col", "bool_col"] + + bf_result = (~scalars_df[columns]).to_pandas() + pd_result = ~scalars_pandas_df[columns] + + assert_frame_equal(bf_result, pd_result) + + +def test_df_isnull(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + columns = ["int64_col", "int64_too", "string_col", "bool_col"] + bf_result = scalars_df[columns].isnull().to_pandas() + pd_result = scalars_pandas_df[columns].isnull() + + # One of dtype mismatches to be documented. Here, the `bf_result.dtype` is + # `BooleanDtype` but the `pd_result.dtype` is `bool`. + pd_result["int64_col"] = pd_result["int64_col"].astype(pd.BooleanDtype()) + pd_result["int64_too"] = pd_result["int64_too"].astype(pd.BooleanDtype()) + pd_result["string_col"] = pd_result["string_col"].astype(pd.BooleanDtype()) + pd_result["bool_col"] = pd_result["bool_col"].astype(pd.BooleanDtype()) + + assert_frame_equal(bf_result, pd_result) + + +def test_df_notnull(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + columns = ["int64_col", "int64_too", "string_col", "bool_col"] + bf_result = scalars_df[columns].notnull().to_pandas() + pd_result = scalars_pandas_df[columns].notnull() + + # One of dtype mismatches to be documented. Here, the `bf_result.dtype` is + # `BooleanDtype` but the `pd_result.dtype` is `bool`. + pd_result["int64_col"] = pd_result["int64_col"].astype(pd.BooleanDtype()) + pd_result["int64_too"] = pd_result["int64_too"].astype(pd.BooleanDtype()) + pd_result["string_col"] = pd_result["string_col"].astype(pd.BooleanDtype()) + pd_result["bool_col"] = pd_result["bool_col"].astype(pd.BooleanDtype()) + + assert_frame_equal(bf_result, pd_result) + + +@pytest.mark.parametrize( + ("left_labels", "right_labels", "overwrite", "fill_value"), + [ + (["a", "b", "c"], ["c", "a", "b"], True, None), + (["a", "b", "c"], ["c", "a", "b"], False, None), + (["a", "b", "c"], ["a", "b", "c"], False, 2), + ], + ids=[ + "one_one_match_overwrite", + "one_one_match_no_overwrite", + "exact_match", + ], +) +def test_combine( + scalars_df_index, + scalars_df_2_index, + scalars_pandas_df_index, + left_labels, + right_labels, + overwrite, + fill_value, +): + if pd.__version__.startswith("1."): + pytest.skip("pd.NA vs NaN not handled well in pandas 1.x.") + columns = ["int64_too", "int64_col", "float64_col"] + + bf_df_a = scalars_df_index[columns] + bf_df_a.columns = left_labels + bf_df_b = scalars_df_2_index[columns] + bf_df_b.columns = right_labels + bf_result = bf_df_a.combine( + bf_df_b, + lambda x, y: x**2 + 2 * x * y + y**2, + overwrite=overwrite, + fill_value=fill_value, + ).to_pandas() + + pd_df_a = scalars_pandas_df_index[columns] + pd_df_a.columns = left_labels + pd_df_b = scalars_pandas_df_index[columns] + pd_df_b.columns = right_labels + pd_result = pd_df_a.combine( + pd_df_b, + lambda x, y: x**2 + 2 * x * y + y**2, + overwrite=overwrite, + fill_value=fill_value, + ) + + # Some dtype inconsistency for all-NULL columns + pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + + +@pytest.mark.parametrize( + ("overwrite", "filter_func"), + [ + (True, None), + (False, None), + (True, lambda x: x.isna() | (x % 2 == 0)), + ], + ids=[ + "default", + "overwritefalse", + "customfilter", + ], +) +def test_df_update(overwrite, filter_func): + if pd.__version__.startswith("1."): + pytest.skip("dtype handled differently in pandas 1.x.") + + index1: pandas.Index = pandas.Index([1, 2, 3, 4], dtype="Int64") + + index2: pandas.Index = pandas.Index([1, 2, 4, 5], dtype="Int64") + pd_df1 = pandas.DataFrame( + {"a": [1, None, 3, 4], "b": [5, 6, None, 8]}, dtype="Int64", index=index1 + ) + pd_df2 = pandas.DataFrame( + {"a": [None, 20, 30, 40], "c": [90, None, 110, 120]}, + dtype="Int64", + index=index2, + ) + + bf_df1 = dataframe.DataFrame(pd_df1) + bf_df2 = dataframe.DataFrame(pd_df2) + + bf_df1.update(bf_df2, overwrite=overwrite, filter_func=filter_func) + pd_df1.update(pd_df2, overwrite=overwrite, filter_func=filter_func) + + pd.testing.assert_frame_equal(bf_df1.to_pandas(), pd_df1) + + +def test_df_idxmin(): + pd_df = pd.DataFrame( + {"a": [1, 2, 3], "b": [7, None, 3], "c": [4, 4, 4]}, index=["x", "y", "z"] + ) + bf_df = dataframe.DataFrame(pd_df) + + bf_result = bf_df.idxmin().to_pandas() + pd_result = pd_df.idxmin() + + pd.testing.assert_series_equal( + bf_result, pd_result, check_index_type=False, check_dtype=False + ) + + +def test_df_idxmax(): + pd_df = pd.DataFrame( + {"a": [1, 2, 3], "b": [7, None, 3], "c": [4, 4, 4]}, index=["x", "y", "z"] + ) + bf_df = dataframe.DataFrame(pd_df) + + bf_result = bf_df.idxmax().to_pandas() + pd_result = pd_df.idxmax() + + pd.testing.assert_series_equal( + bf_result, pd_result, check_index_type=False, check_dtype=False + ) + + +@pytest.mark.parametrize( + ("join", "axis"), + [ + ("outer", None), + ("outer", 0), + ("outer", 1), + ("left", 0), + ("right", 1), + ("inner", None), + ("inner", 1), + ], +) +def test_df_align(join, axis): + index1: pandas.Index = pandas.Index([1, 2, 3, 4], dtype="Int64") + + index2: pandas.Index = pandas.Index([1, 2, 4, 5], dtype="Int64") + pd_df1 = pandas.DataFrame( + {"a": [1, None, 3, 4], "b": [5, 6, None, 8]}, dtype="Int64", index=index1 + ) + pd_df2 = pandas.DataFrame( + {"a": [None, 20, 30, 40], "c": [90, None, 110, 120]}, + dtype="Int64", + index=index2, + ) + + bf_df1 = dataframe.DataFrame(pd_df1) + bf_df2 = dataframe.DataFrame(pd_df2) + + bf_result1, bf_result2 = bf_df1.align(bf_df2, join=join, axis=axis) + pd_result1, pd_result2 = pd_df1.align(pd_df2, join=join, axis=axis) + + # Don't check dtype as pandas does unnecessary float conversion + assert isinstance(bf_result1, dataframe.DataFrame) and isinstance( + bf_result2, dataframe.DataFrame + ) + pd.testing.assert_frame_equal(bf_result1.to_pandas(), pd_result1, check_dtype=False) + pd.testing.assert_frame_equal(bf_result2.to_pandas(), pd_result2, check_dtype=False) + + +def test_combine_first( + scalars_df_index, + scalars_df_2_index, + scalars_pandas_df_index, +): + if pd.__version__.startswith("1."): + pytest.skip("pd.NA vs NaN not handled well in pandas 1.x.") + columns = ["int64_too", "int64_col", "float64_col"] + + bf_df_a = scalars_df_index[columns].iloc[0:6] + bf_df_a.columns = ["a", "b", "c"] + bf_df_b = scalars_df_2_index[columns].iloc[2:8] + bf_df_b.columns = ["b", "a", "d"] + bf_result = bf_df_a.combine_first(bf_df_b).to_pandas() + + pd_df_a = scalars_pandas_df_index[columns].iloc[0:6] + pd_df_a.columns = ["a", "b", "c"] + pd_df_b = scalars_pandas_df_index[columns].iloc[2:8] + pd_df_b.columns = ["b", "a", "d"] + pd_result = pd_df_a.combine_first(pd_df_b) + + # Some dtype inconsistency for all-NULL columns + pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + + +def test_df_corr_w_invalid_parameters(scalars_dfs): + columns = ["int64_too", "int64_col", "float64_col"] + scalars_df, _ = scalars_dfs + + with pytest.raises(NotImplementedError): + scalars_df[columns].corr(method="kendall") + + with pytest.raises(NotImplementedError): + scalars_df[columns].corr(min_periods=1) + + +@pytest.mark.parametrize( + ("columns", "numeric_only"), + [ + (["bool_col", "int64_col", "float64_col"], True), + (["bool_col", "int64_col", "float64_col"], False), + (["bool_col", "int64_col", "float64_col", "string_col"], True), + pytest.param( + ["bool_col", "int64_col", "float64_col", "string_col"], + False, + marks=pytest.mark.xfail( + raises=NotImplementedError, + ), + ), + ], +) +def test_cov_w_numeric_only(scalars_dfs, columns, numeric_only): + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[columns].cov(numeric_only=numeric_only).to_pandas() + pd_result = scalars_pandas_df[columns].cov(numeric_only=numeric_only) + # BigFrames and Pandas differ in their data type handling: + # - Column types: BigFrames uses Float64, Pandas uses float64. + # - Index types: BigFrames uses strign, Pandas uses object. + pd.testing.assert_index_equal(bf_result.columns, pd_result.columns) + # Only check row order in ordered mode. + pd.testing.assert_frame_equal( + bf_result, + pd_result, + check_dtype=False, + check_index_type=False, + check_like=~scalars_df._block.session._strictly_ordered, + ) + + +def test_df_corrwith_df(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + l_cols = ["int64_col", "float64_col", "int64_too"] + r_cols = ["int64_too", "float64_col"] + + bf_result = scalars_df[l_cols].corrwith(scalars_df[r_cols]).to_pandas() + pd_result = scalars_pandas_df[l_cols].corrwith(scalars_pandas_df[r_cols]) + + # BigFrames and Pandas differ in their data type handling: + # - Column types: BigFrames uses Float64, Pandas uses float64. + # - Index types: BigFrames uses strign, Pandas uses object. + pd.testing.assert_series_equal( + bf_result, pd_result, check_dtype=False, check_index_type=False + ) + + +def test_df_corrwith_df_numeric_only(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + l_cols = ["int64_col", "float64_col", "int64_too", "string_col"] + r_cols = ["int64_too", "float64_col", "bool_col"] + + bf_result = ( + scalars_df[l_cols].corrwith(scalars_df[r_cols], numeric_only=True).to_pandas() + ) + pd_result = scalars_pandas_df[l_cols].corrwith( + scalars_pandas_df[r_cols], numeric_only=True + ) + + # BigFrames and Pandas differ in their data type handling: + # - Column types: BigFrames uses Float64, Pandas uses float64. + # - Index types: BigFrames uses strign, Pandas uses object. + pd.testing.assert_series_equal( + bf_result, pd_result, check_dtype=False, check_index_type=False + ) + + +def test_df_corrwith_df_non_numeric_error(scalars_dfs): + scalars_df, _ = scalars_dfs + + l_cols = ["int64_col", "float64_col", "int64_too", "string_col"] + r_cols = ["int64_too", "float64_col", "bool_col"] + + with pytest.raises(NotImplementedError): + scalars_df[l_cols].corrwith(scalars_df[r_cols], numeric_only=False) + + +def test_df_corrwith_series(scalars_dfs): + # TODO: supply a reason why this isn't compatible with pandas 1.x + pytest.importorskip("pandas", minversion="2.0.0") + scalars_df, scalars_pandas_df = scalars_dfs + + l_cols = ["int64_col", "float64_col", "int64_too"] + r_col = "float64_col" + + bf_result = scalars_df[l_cols].corrwith(scalars_df[r_col]).to_pandas() + pd_result = scalars_pandas_df[l_cols].corrwith(scalars_pandas_df[r_col]) + + # BigFrames and Pandas differ in their data type handling: + # - Column types: BigFrames uses Float64, Pandas uses float64. + # - Index types: BigFrames uses strign, Pandas uses object. + pd.testing.assert_series_equal( + bf_result, pd_result, check_dtype=False, check_index_type=False + ) + + +@pytest.mark.parametrize( + ("op"), + [ + operator.add, + operator.sub, + operator.mul, + operator.truediv, + operator.floordiv, + operator.eq, + operator.ne, + operator.gt, + operator.ge, + operator.lt, + operator.le, + ], + ids=[ + "add", + "subtract", + "multiply", + "true_divide", + "floor_divide", + "eq", + "ne", + "gt", + "ge", + "lt", + "le", + ], +) +# TODO(garrettwu): deal with NA values +@pytest.mark.parametrize(("other_scalar"), [1, 2.5, 0, 0.0]) +@pytest.mark.parametrize(("reverse_operands"), [True, False]) +def test_scalar_binop(scalars_dfs, op, other_scalar, reverse_operands): + scalars_df, scalars_pandas_df = scalars_dfs + columns = ["int64_col", "float64_col"] + + maybe_reversed_op = (lambda x, y: op(y, x)) if reverse_operands else op + + bf_result = maybe_reversed_op(scalars_df[columns], other_scalar).to_pandas() + pd_result = maybe_reversed_op(scalars_pandas_df[columns], other_scalar) + + assert_frame_equal(bf_result, pd_result) + + +@pytest.mark.parametrize(("other_scalar"), [1, -2]) +def test_mod(scalars_dfs, other_scalar): + # Zero case excluded as pandas produces 0 result for Int64 inputs rather than NA/NaN. + # This is likely a pandas bug as mod 0 is undefined in other dtypes, and most programming languages. + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = (scalars_df[["int64_col", "int64_too"]] % other_scalar).to_pandas() + pd_result = scalars_pandas_df[["int64_col", "int64_too"]] % other_scalar + + assert_frame_equal(bf_result, pd_result) + + +def test_scalar_binop_str_exception(scalars_dfs): + scalars_df, _ = scalars_dfs + columns = ["string_col"] + with pytest.raises(TypeError, match="Cannot add dtypes"): + (scalars_df[columns] + 1).to_pandas() + + +@pytest.mark.parametrize( + ("op"), + [ + (lambda x, y: x.add(y, axis="index")), + (lambda x, y: x.radd(y, axis="index")), + (lambda x, y: x.sub(y, axis="index")), + (lambda x, y: x.rsub(y, axis="index")), + (lambda x, y: x.mul(y, axis="index")), + (lambda x, y: x.rmul(y, axis="index")), + (lambda x, y: x.truediv(y, axis="index")), + (lambda x, y: x.rtruediv(y, axis="index")), + (lambda x, y: x.floordiv(y, axis="index")), + (lambda x, y: x.floordiv(y, axis="index")), + (lambda x, y: x.gt(y, axis="index")), + (lambda x, y: x.ge(y, axis="index")), + (lambda x, y: x.lt(y, axis="index")), + (lambda x, y: x.le(y, axis="index")), + ], + ids=[ + "add", + "radd", + "sub", + "rsub", + "mul", + "rmul", + "truediv", + "rtruediv", + "floordiv", + "rfloordiv", + "gt", + "ge", + "lt", + "le", + ], +) +def test_series_binop_axis_index( + scalars_dfs, + op, +): + scalars_df, scalars_pandas_df = scalars_dfs + df_columns = ["int64_col", "float64_col"] + series_column = "int64_too" + + bf_result = op(scalars_df[df_columns], scalars_df[series_column]).to_pandas() + pd_result = op(scalars_pandas_df[df_columns], scalars_pandas_df[series_column]) + + assert_frame_equal(bf_result, pd_result) + + +@pytest.mark.parametrize( + ("input"), + [ + ((1000, 2000, 3000)), + (pd.Index([1000, 2000, 3000])), + (pd.Series((1000, 2000), index=["int64_too", "float64_col"])), + ], + ids=[ + "tuple", + "pd_index", + "pd_series", + ], +) +def test_listlike_binop_axis_1_in_memory_data(scalars_dfs, input): + # TODO: supply a reason why this isn't compatible with pandas 1.x + pytest.importorskip("pandas", minversion="2.0.0") + scalars_df, scalars_pandas_df = scalars_dfs + + df_columns = ["int64_col", "float64_col", "int64_too"] + + bf_result = scalars_df[df_columns].add(input, axis=1).to_pandas() + if hasattr(input, "to_pandas"): + input = input.to_pandas() + pd_result = scalars_pandas_df[df_columns].add(input, axis=1) + + assert_frame_equal(bf_result, pd_result, check_dtype=False) + + +def test_df_reverse_binop_pandas(scalars_dfs): + # TODO: supply a reason why this isn't compatible with pandas 1.x + pytest.importorskip("pandas", minversion="2.0.0") + scalars_df, scalars_pandas_df = scalars_dfs + + pd_series = pd.Series([100, 200, 300]) + + df_columns = ["int64_col", "float64_col", "int64_too"] + + bf_result = pd_series + scalars_df[df_columns].to_pandas() + pd_result = pd_series + scalars_pandas_df[df_columns] + + assert_frame_equal(bf_result, pd_result, check_dtype=False) + + +def test_listlike_binop_axis_1_bf_index(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + df_columns = ["int64_col", "float64_col", "int64_too"] + + bf_result = ( + scalars_df[df_columns] + .add(bf_indexes.Index([1000, 2000, 3000]), axis=1) + .to_pandas() + ) + pd_result = scalars_pandas_df[df_columns].add(pd.Index([1000, 2000, 3000]), axis=1) + + assert_frame_equal(bf_result, pd_result, check_dtype=False) + + +def test_binop_with_self_aggregate(session, scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + + df_columns = ["int64_col", "float64_col", "int64_too"] + + bf_df = scalars_df[df_columns] + bf_result = (bf_df - bf_df.mean()).to_pandas() + + pd_df = scalars_pandas_df[df_columns] + pd_result = pd_df - pd_df.mean() + + assert_frame_equal(bf_result, pd_result, check_dtype=False) + + +@pytest.mark.parametrize( + ("left_labels", "right_labels"), + [ + (["a", "a", "b"], ["c", "c", "d"]), + (["a", "b", "c"], ["c", "a", "b"]), + (["a", "c", "c"], ["c", "a", "c"]), + (["a", "b", "c"], ["a", "b", "c"]), + ], + ids=[ + "no_overlap", + "one_one_match", + "multi_match", + "exact_match", + ], +) +def test_binop_df_df_binary_op( + scalars_df_index, + scalars_df_2_index, + scalars_pandas_df_index, + left_labels, + right_labels, +): + if pd.__version__.startswith("1."): + pytest.skip("pd.NA vs NaN not handled well in pandas 1.x.") + columns = ["int64_too", "int64_col", "float64_col"] + + bf_df_a = scalars_df_index[columns] + bf_df_a.columns = left_labels + bf_df_b = scalars_df_2_index[columns] + bf_df_b.columns = right_labels + bf_result = (bf_df_a - bf_df_b).to_pandas() + + pd_df_a = scalars_pandas_df_index[columns] + pd_df_a.columns = left_labels + pd_df_b = scalars_pandas_df_index[columns] + pd_df_b.columns = right_labels + pd_result = pd_df_a - pd_df_b + + # Some dtype inconsistency for all-NULL columns + pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + + +# Differnt table will only work for explicit index, since default index orders are arbitrary. +@pytest.mark.parametrize( + ("ordered"), + [ + (True), + (False), + ], +) +def test_series_binop_add_different_table( + scalars_df_index, scalars_pandas_df_index, scalars_df_2_index, ordered +): + df_columns = ["int64_col", "float64_col"] + series_column = "int64_too" + + bf_result = ( + scalars_df_index[df_columns] + .add(scalars_df_2_index[series_column], axis="index") + .to_pandas(ordered=ordered) + ) + pd_result = scalars_pandas_df_index[df_columns].add( + scalars_pandas_df_index[series_column], axis="index" + ) + + assert_frame_equal(bf_result, pd_result, ignore_order=not ordered) + + +# TODO(garrettwu): Test series binop with different index + +all_joins = pytest.mark.parametrize( + ("how",), + (("outer",), ("left",), ("right",), ("inner",), ("cross",)), +) + + +@all_joins +def test_join_same_table(scalars_dfs, how): + bf_df, pd_df = scalars_dfs + if not bf_df._session._strictly_ordered and how == "cross": + pytest.skip("Cross join not supported in partial ordering mode.") + + bf_df_a = bf_df.set_index("int64_too")[["string_col", "int64_col"]] + bf_df_a = bf_df_a.sort_index() + + bf_df_b = bf_df.set_index("int64_too")[["float64_col"]] + bf_df_b = bf_df_b[bf_df_b.float64_col > 0] + bf_df_b = bf_df_b.sort_values("float64_col") + + bf_result = bf_df_a.join(bf_df_b, how=how).to_pandas() + + pd_df_a = pd_df.set_index("int64_too")[["string_col", "int64_col"]].sort_index() + pd_df_a = pd_df_a.sort_index() + + pd_df_b = pd_df.set_index("int64_too")[["float64_col"]] + pd_df_b = pd_df_b[pd_df_b.float64_col > 0] + pd_df_b = pd_df_b.sort_values("float64_col") + + pd_result = pd_df_a.join(pd_df_b, how=how) + + assert_frame_equal(bf_result, pd_result, ignore_order=True) + + +@all_joins +def test_join_different_table( + scalars_df_index, scalars_df_2_index, scalars_pandas_df_index, how +): + bf_df_a = scalars_df_index[["string_col", "int64_col"]] + bf_df_b = scalars_df_2_index.dropna()[["float64_col"]] + bf_result = bf_df_a.join(bf_df_b, how=how).to_pandas() + pd_df_a = scalars_pandas_df_index[["string_col", "int64_col"]] + pd_df_b = scalars_pandas_df_index.dropna()[["float64_col"]] + pd_result = pd_df_a.join(pd_df_b, how=how) + assert_frame_equal(bf_result, pd_result, ignore_order=True) + + +@all_joins +def test_join_raise_when_param_on_duplicate_with_column(scalars_df_index, how): + if how == "cross": + return + bf_df_a = scalars_df_index[["string_col", "int64_col"]].rename( + columns={"int64_col": "string_col"} + ) + bf_df_b = scalars_df_index.dropna()["string_col"] + with pytest.raises( + ValueError, match="The column label 'string_col' is not unique." + ): + bf_df_a.join(bf_df_b, on="string_col", how=how, lsuffix="_l", rsuffix="_r") + + +def test_join_duplicate_columns_raises_value_error(scalars_dfs): + scalars_df, _ = scalars_dfs + df_a = scalars_df[["string_col", "float64_col"]] + df_b = scalars_df[["float64_col"]] + with pytest.raises(ValueError, match="columns overlap but no suffix specified"): + df_a.join(df_b, how="outer") + + +@all_joins +def test_join_param_on_duplicate_with_index_raises_value_error(scalars_df_index, how): + if how == "cross": + return + bf_df_a = scalars_df_index[["string_col"]] + bf_df_a.index.name = "string_col" + bf_df_b = scalars_df_index.dropna()["string_col"] + with pytest.raises( + ValueError, + match="'string_col' is both an index level and a column label, which is ambiguous.", + ): + bf_df_a.join(bf_df_b, on="string_col", how=how, lsuffix="_l", rsuffix="_r") + + +@all_joins +def test_join_param_on(scalars_dfs, how): + bf_df, pd_df = scalars_dfs + + bf_df_a = bf_df[["string_col", "int64_col", "rowindex_2"]] + bf_df_a = bf_df_a.assign(rowindex_2=bf_df_a["rowindex_2"] + 2) + bf_df_b = bf_df[["float64_col"]] + + if how == "cross": + with pytest.raises(ValueError, match="'on' is not supported for cross join."): + bf_df_a.join(bf_df_b, on="rowindex_2", how=how) + else: + bf_result = bf_df_a.join(bf_df_b, on="rowindex_2", how=how).to_pandas() + + pd_df_a = pd_df[["string_col", "int64_col", "rowindex_2"]] + pd_df_a = pd_df_a.assign(rowindex_2=pd_df_a["rowindex_2"] + 2) + pd_df_b = pd_df[["float64_col"]] + pd_result = pd_df_a.join(pd_df_b, on="rowindex_2", how=how) + assert_frame_equal(bf_result, pd_result, ignore_order=True) + + +@all_joins +def test_df_join_series(scalars_dfs, how): + bf_df, pd_df = scalars_dfs + + bf_df_a = bf_df[["string_col", "int64_col", "rowindex_2"]] + bf_df_a = bf_df_a.assign(rowindex_2=bf_df_a["rowindex_2"] + 2) + bf_series_b = bf_df["float64_col"] + + if how == "cross": + with pytest.raises(ValueError): + bf_df_a.join(bf_series_b, on="rowindex_2", how=how) + else: + bf_result = bf_df_a.join(bf_series_b, on="rowindex_2", how=how).to_pandas() + + pd_df_a = pd_df[["string_col", "int64_col", "rowindex_2"]] + pd_df_a = pd_df_a.assign(rowindex_2=pd_df_a["rowindex_2"] + 2) + pd_series_b = pd_df["float64_col"] + pd_result = pd_df_a.join(pd_series_b, on="rowindex_2", how=how) + assert_frame_equal(bf_result, pd_result, ignore_order=True) + + +@pytest.mark.parametrize( + ("by", "ascending", "na_position"), + [ + ("int64_col", True, "first"), + (["bool_col", "int64_col"], True, "last"), + ("int64_col", False, "first"), + (["bool_col", "int64_col"], [False, True], "last"), + (["bool_col", "int64_col"], [True, False], "first"), + ], +) +def test_dataframe_sort_values( + scalars_df_index, scalars_pandas_df_index, by, ascending, na_position +): + # Test needs values to be unique + bf_result = scalars_df_index.sort_values( + by, ascending=ascending, na_position=na_position + ).to_pandas() + pd_result = scalars_pandas_df_index.sort_values( + by, ascending=ascending, na_position=na_position + ) + + pandas.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +@pytest.mark.parametrize( + ("by", "ascending", "na_position"), + [ + ("int64_col", True, "first"), + (["bool_col", "int64_col"], True, "last"), + ], +) +def test_dataframe_sort_values_inplace( + scalars_df_index, scalars_pandas_df_index, by, ascending, na_position +): + # Test needs values to be unique + bf_sorted = scalars_df_index.copy() + bf_sorted.sort_values( + by, ascending=ascending, na_position=na_position, inplace=True + ) + bf_result = bf_sorted.to_pandas() + pd_result = scalars_pandas_df_index.sort_values( + by, ascending=ascending, na_position=na_position + ) + + pandas.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_dataframe_sort_values_invalid_input(scalars_df_index): + with pytest.raises(KeyError): + scalars_df_index.sort_values(by=scalars_df_index["int64_col"]) + + +def test_dataframe_sort_values_stable(scalars_df_index, scalars_pandas_df_index): + bf_result = ( + scalars_df_index.sort_values("int64_col", kind="stable") + .sort_values("bool_col", kind="stable") + .to_pandas() + ) + pd_result = scalars_pandas_df_index.sort_values( + "int64_col", kind="stable" + ).sort_values("bool_col", kind="stable") + + pandas.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +@pytest.mark.parametrize( + ("operator", "columns"), + [ + pytest.param(lambda x: x.cumsum(), ["float64_col", "int64_too"]), + # pytest.param(lambda x: x.cumprod(), ["float64_col", "int64_too"]), + pytest.param( + lambda x: x.cumprod(), + ["string_col"], + marks=pytest.mark.xfail( + raises=ValueError, + ), + ), + ], + ids=[ + "cumsum", + # "cumprod", + "non-numeric", + ], +) +def test_dataframe_numeric_analytic_op( + scalars_df_index, scalars_pandas_df_index, operator, columns +): + # TODO: Add nullable ints (pandas 1.x has poor behavior on these) + bf_series = operator(scalars_df_index[columns]) + pd_series = operator(scalars_pandas_df_index[columns]) + bf_result = bf_series.to_pandas() + pd.testing.assert_frame_equal(pd_series, bf_result, check_dtype=False) + + +@pytest.mark.parametrize( + ("operator"), + [ + (lambda x: x.cummin()), + (lambda x: x.cummax()), + (lambda x: x.shift(2)), + (lambda x: x.shift(-2)), + ], + ids=[ + "cummin", + "cummax", + "shiftpostive", + "shiftnegative", + ], +) +def test_dataframe_general_analytic_op( + scalars_df_index, scalars_pandas_df_index, operator +): + col_names = ["int64_too", "float64_col", "int64_col", "bool_col"] + bf_series = operator(scalars_df_index[col_names]) + pd_series = operator(scalars_pandas_df_index[col_names]) + bf_result = bf_series.to_pandas() + pd.testing.assert_frame_equal( + pd_series, + bf_result, + ) + + +@pytest.mark.parametrize( + ("periods",), + [ + (1,), + (2,), + (-1,), + ], +) +def test_dataframe_diff(scalars_df_index, scalars_pandas_df_index, periods): + col_names = ["int64_too", "float64_col", "int64_col"] + bf_result = scalars_df_index[col_names].diff(periods=periods).to_pandas() + pd_result = scalars_pandas_df_index[col_names].diff(periods=periods) + pd.testing.assert_frame_equal( + pd_result, + bf_result, + ) + + +@pytest.mark.parametrize( + ("periods",), + [ + (1,), + (2,), + (-1,), + ], +) +def test_dataframe_pct_change(scalars_df_index, scalars_pandas_df_index, periods): + col_names = ["int64_too", "float64_col", "int64_col"] + bf_result = scalars_df_index[col_names].pct_change(periods=periods).to_pandas() + # pandas 3.0 does not automatically ffill anymore + pd_result = scalars_pandas_df_index[col_names].ffill().pct_change(periods=periods) + assert_frame_equal( + pd_result, + bf_result, + nulls_are_nan=True, + ) + + +def test_dataframe_agg_single_string(scalars_dfs): + numeric_cols = ["int64_col", "int64_too", "float64_col"] + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = scalars_df[numeric_cols].agg("sum").to_pandas() + pd_result = scalars_pandas_df[numeric_cols].agg("sum") + + assert bf_result.dtype == "Float64" + pd.testing.assert_series_equal( + pd_result, bf_result, check_dtype=False, check_index_type=False + ) + + +@pytest.mark.parametrize( + ("agg",), + ( + ("sum",), + ("size",), + ), +) +def test_dataframe_agg_int_single_string(scalars_dfs, agg): + numeric_cols = ["int64_col", "int64_too", "bool_col"] + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = scalars_df[numeric_cols].agg(agg).to_pandas() + pd_result = scalars_pandas_df[numeric_cols].agg(agg) + + assert bf_result.dtype == "Int64" + pd.testing.assert_series_equal( + pd_result, bf_result, check_dtype=False, check_index_type=False + ) + + +def test_dataframe_agg_multi_string(scalars_dfs): + numeric_cols = ["int64_col", "int64_too", "float64_col"] + aggregations = [ + "sum", + "mean", + "median", + "std", + "var", + "min", + "max", + "nunique", + "count", + ] + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[numeric_cols].agg(aggregations) + pd_result = scalars_pandas_df[numeric_cols].agg(aggregations) + + # Pandas may produce narrower numeric types, but bigframes always produces Float64 + pd_result = pd_result.astype("Float64") + + # Drop median, as it's an approximation. + bf_median = bf_result.loc["median", :] + bf_result = bf_result.drop(labels=["median"]) + pd_result = pd_result.drop(labels=["median"]) + + assert_dfs_equivalent(pd_result, bf_result, check_index_type=False) + + # Double-check that median is at least plausible. + assert ( + (bf_result.loc["min", :] <= bf_median) & (bf_median <= bf_result.loc["max", :]) + ).all() + + +def test_dataframe_agg_int_multi_string(scalars_dfs): + numeric_cols = ["int64_col", "int64_too", "bool_col"] + aggregations = [ + "sum", + "nunique", + "count", + "size", + ] + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[numeric_cols].agg(aggregations).to_pandas() + pd_result = scalars_pandas_df[numeric_cols].agg(aggregations) + + for dtype in bf_result.dtypes: + assert dtype == "Int64" + + # Pandas may produce narrower numeric types + # Pandas has object index type + pd.testing.assert_frame_equal( + pd_result, bf_result, check_dtype=False, check_index_type=False + ) + + +def test_df_transpose(): + # Include some floats to ensure type coercion + values = [[0, 3.5, True], [1, 4.5, False], [2, 6.5, None]] + # Test complex case of both axes being multi-indices with non-unique elements + + columns: pandas.Index = pd.Index( + ["A", "B", "A"], dtype=pd.StringDtype(storage="pyarrow") + ) + columns_multi = pd.MultiIndex.from_arrays([columns, columns], names=["c1", "c2"]) + + index: pandas.Index = pd.Index( + ["b", "a", "a"], dtype=pd.StringDtype(storage="pyarrow") + ) + rows_multi = pd.MultiIndex.from_arrays([index, index], names=["r1", "r2"]) + + pd_df = pandas.DataFrame(values, index=rows_multi, columns=columns_multi) + bf_df = dataframe.DataFrame(values, index=rows_multi, columns=columns_multi) + + pd_result = pd_df.T + bf_result = bf_df.T.to_pandas() + + assert_frame_equal(pd_result, bf_result, check_dtype=False, nulls_are_nan=True) + + +def test_df_transpose_error(): + with pytest.raises(TypeError, match="Cannot coerce.*to a common type."): + dataframe.DataFrame([[1, "hello"], [2, "world"]]).transpose() + + +def test_df_transpose_repeated_uses_cache(): + bf_df = dataframe.DataFrame([[1, 2.5], [2, 3.5]]) + pd_df = pandas.DataFrame([[1, 2.5], [2, 3.5]]) + # Transposing many times so that operation will fail from complexity if not using cache + for i in range(10): + # Cache still works even with simple scalar binop + bf_df = bf_df.transpose() + i + pd_df = pd_df.transpose() + i + + pd.testing.assert_frame_equal( + pd_df, bf_df.to_pandas(), check_dtype=False, check_index_type=False + ) + + +def test_df_stack(scalars_dfs): + if pandas.__version__.startswith("1.") or pandas.__version__.startswith("2.0"): + pytest.skip("pandas <2.1 uses different stack implementation") + scalars_df, scalars_pandas_df = scalars_dfs + # To match bigquery dataframes + scalars_pandas_df = scalars_pandas_df.copy() + scalars_pandas_df.columns = scalars_pandas_df.columns.astype("string[pyarrow]") + # Can only stack identically-typed columns + columns = ["int64_col", "int64_too", "rowindex_2"] + + bf_result = scalars_df[columns].stack().to_pandas() + pd_result = scalars_pandas_df[columns].stack(future_stack=True) + + # Pandas produces NaN, where bq dataframes produces pd.NA + assert_series_equal(bf_result, pd_result, check_dtype=False) + + +def test_df_melt_default(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + # To match bigquery dataframes + scalars_pandas_df = scalars_pandas_df.copy() + scalars_pandas_df.columns = scalars_pandas_df.columns.astype("string[pyarrow]") + # Can only stack identically-typed columns + columns = ["int64_col", "int64_too", "rowindex_2"] + + bf_result = scalars_df[columns].melt().to_pandas() + pd_result = scalars_pandas_df[columns].melt() + + # Pandas produces int64 index, Bigframes produces Int64 (nullable) + pd.testing.assert_frame_equal( + bf_result, + pd_result, + check_index_type=False, + check_dtype=False, + ) + + +def test_df_melt_parameterized(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + # To match bigquery dataframes + scalars_pandas_df = scalars_pandas_df.copy() + scalars_pandas_df.columns = scalars_pandas_df.columns.astype("string[pyarrow]") + # Can only stack identically-typed columns + + bf_result = scalars_df.melt( + var_name="alice", + value_name="bob", + id_vars=["string_col"], + value_vars=["int64_col", "int64_too"], + ).to_pandas() + pd_result = scalars_pandas_df.melt( + var_name="alice", + value_name="bob", + id_vars=["string_col"], + value_vars=["int64_col", "int64_too"], + ) + + # Pandas produces int64 index, Bigframes produces Int64 (nullable) + pd.testing.assert_frame_equal( + bf_result, pd_result, check_index_type=False, check_dtype=False + ) + + +@pytest.mark.parametrize( + ("ordered"), + [ + (True), + (False), + ], +) +def test_df_unstack(scalars_dfs, ordered): + scalars_df, scalars_pandas_df = scalars_dfs + # To match bigquery dataframes + scalars_pandas_df = scalars_pandas_df.copy() + scalars_pandas_df.columns = scalars_pandas_df.columns.astype("string[pyarrow]") + # Can only stack identically-typed columns + columns = [ + "rowindex_2", + "int64_col", + "int64_too", + ] + + # unstack on mono-index produces series + bf_result = scalars_df[columns].unstack().to_pandas(ordered=ordered) + pd_result = scalars_pandas_df[columns].unstack() + + # Pandas produces NaN, where bq dataframes produces pd.NA + assert_series_equal( + bf_result, pd_result, check_dtype=False, ignore_order=not ordered + ) + + +def test_ipython_key_completions_with_drop(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_names = "string_col" + bf_dataframe = scalars_df.drop(columns=col_names) + pd_dataframe = scalars_pandas_df.drop(columns=col_names) + expected = pd_dataframe.columns.tolist() + + results = bf_dataframe._ipython_key_completions_() + + assert col_names not in results + assert results == expected + # _ipython_key_completions_ is called with square brackets + # so only column names are relevant with tab completion + assert "to_gbq" not in results + assert "merge" not in results + assert "drop" not in results + + +def test_ipython_key_completions_with_rename(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_name_dict = {"string_col": "a_renamed_column"} + bf_dataframe = scalars_df.rename(columns=col_name_dict) + pd_dataframe = scalars_pandas_df.rename(columns=col_name_dict) + expected = pd_dataframe.columns.tolist() + + results = bf_dataframe._ipython_key_completions_() + + assert "string_col" not in results + assert "a_renamed_column" in results + assert results == expected + # _ipython_key_completions_ is called with square brackets + # so only column names are relevant with tab completion + assert "to_gbq" not in results + assert "merge" not in results + assert "drop" not in results + + +def test__dir__with_drop(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_names = "string_col" + bf_dataframe = scalars_df.drop(columns=col_names) + pd_dataframe = scalars_pandas_df.drop(columns=col_names) + expected = pd_dataframe.columns.tolist() + + results = dir(bf_dataframe) + + assert col_names not in results + assert frozenset(expected) <= frozenset(results) + # __dir__ is called with a '.' and displays all methods, columns names, etc. + assert "to_gbq" in results + assert "merge" in results + assert "drop" in results + + +def test__dir__with_rename(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + col_name_dict = {"string_col": "a_renamed_column"} + bf_dataframe = scalars_df.rename(columns=col_name_dict) + pd_dataframe = scalars_pandas_df.rename(columns=col_name_dict) + expected = pd_dataframe.columns.tolist() + + results = dir(bf_dataframe) + + assert "string_col" not in results + assert "a_renamed_column" in results + assert frozenset(expected) <= frozenset(results) + # __dir__ is called with a '.' and displays all methods, columns names, etc. + assert "to_gbq" in results + assert "merge" in results + assert "drop" in results + + +@pytest.mark.parametrize( + ("start", "stop", "step"), + [ + (0, 0, None), + (None, None, None), + (1, None, None), + (None, 4, None), + (None, None, 2), + (None, 50000000000, 1), + (5, 4, None), + (3, None, 2), + (1, 7, 2), + (1, 7, 50000000000), + ], +) +def test_iloc_slice(scalars_df_index, scalars_pandas_df_index, start, stop, step): + bf_result = scalars_df_index.iloc[start:stop:step].to_pandas() + pd_result = scalars_pandas_df_index.iloc[start:stop:step] + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_iloc_slice_zero_step(scalars_df_index): + with pytest.raises(ValueError): + scalars_df_index.iloc[0:0:0] + + +@pytest.mark.parametrize( + ("ordered"), + [ + (True), + (False), + ], +) +def test_iloc_slice_nested(scalars_df_index, scalars_pandas_df_index, ordered): + bf_result = scalars_df_index.iloc[1:].iloc[1:].to_pandas(ordered=ordered) + pd_result = scalars_pandas_df_index.iloc[1:].iloc[1:] + + assert_frame_equal(bf_result, pd_result, ignore_order=not ordered) + + +@pytest.mark.parametrize( + "index", + [0, 5, -2, (2,)], +) +def test_iloc_single_integer(scalars_df_index, scalars_pandas_df_index, index): + bf_result = scalars_df_index.iloc[index] + pd_result = scalars_pandas_df_index.iloc[index] + + pd.testing.assert_series_equal( + bf_result, + pd_result, + ) + + +@pytest.mark.parametrize( + "index", + [(2, 5), (5, 0), (0, 0)], +) +def test_iloc_tuple(scalars_df_index, scalars_pandas_df_index, index): + bf_result = scalars_df_index.iloc[index] + pd_result = scalars_pandas_df_index.iloc[index] + + assert bf_result == pd_result + + +@pytest.mark.parametrize( + "index", + [(slice(None), [1, 2, 3]), (slice(1, 7, 2), [2, 5, 3])], +) +def test_iloc_tuple_multi_columns(scalars_df_index, scalars_pandas_df_index, index): + bf_result = scalars_df_index.iloc[index].to_pandas() + pd_result = scalars_pandas_df_index.iloc[index] + + pd.testing.assert_frame_equal(bf_result, pd_result) + + +def test_iloc_tuple_multi_columns_single_row(scalars_df_index, scalars_pandas_df_index): + index = (2, [2, 1, 3, -4]) + bf_result = scalars_df_index.iloc[index] + pd_result = scalars_pandas_df_index.iloc[index] + pd.testing.assert_series_equal(bf_result, pd_result) + + +@pytest.mark.parametrize( + ("index", "error"), + [ + ((1, 1, 1), pd.errors.IndexingError), + (("asd", "asd", "asd"), pd.errors.IndexingError), + (("asd"), TypeError), + ], +) +def test_iloc_tuple_errors(scalars_df_index, scalars_pandas_df_index, index, error): + with pytest.raises(error): + scalars_df_index.iloc[index] + with pytest.raises(error): + scalars_pandas_df_index.iloc[index] + + +@pytest.mark.parametrize( + "index", + [(2, 5), (5, 0), (0, 0)], +) +def test_iat(scalars_df_index, scalars_pandas_df_index, index): + bf_result = scalars_df_index.iat[index] + pd_result = scalars_pandas_df_index.iat[index] + + assert bf_result == pd_result + + +@pytest.mark.parametrize( + ("index", "error"), + [ + (0, TypeError), + ("asd", ValueError), + ((1, 2, 3), TypeError), + (("asd", "asd"), ValueError), + ], +) +def test_iat_errors(scalars_df_index, scalars_pandas_df_index, index, error): + with pytest.raises(error): + scalars_pandas_df_index.iat[index] + with pytest.raises(error): + scalars_df_index.iat[index] + + +def test_iloc_single_integer_out_of_bound_error( + scalars_df_index, scalars_pandas_df_index +): + with pytest.raises(IndexError, match="single positional indexer is out-of-bounds"): + scalars_df_index.iloc[99] + + +def test_loc_bool_series(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.loc[scalars_df_index.bool_col].to_pandas() + pd_result = scalars_pandas_df_index.loc[scalars_pandas_df_index.bool_col] + + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_loc_select_column(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.loc[:, "int64_col"].to_pandas() + pd_result = scalars_pandas_df_index.loc[:, "int64_col"] + pd.testing.assert_series_equal( + bf_result, + pd_result, + ) + + +def test_loc_select_with_column_condition(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.loc[:, scalars_df_index.dtypes == "Int64"].to_pandas() + pd_result = scalars_pandas_df_index.loc[ + :, scalars_pandas_df_index.dtypes == "Int64" + ] + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_loc_select_with_column_condition_bf_series( + scalars_df_index, scalars_pandas_df_index +): + # (b/347072677) GEOGRAPH type doesn't support DISTINCT op + columns = [ + item for item in scalars_pandas_df_index.columns if item != "geography_col" + ] + scalars_df_index = scalars_df_index[columns] + scalars_pandas_df_index = scalars_pandas_df_index[columns] + + size_half = len(scalars_pandas_df_index) / 2 + bf_result = scalars_df_index.loc[ + :, scalars_df_index.nunique() > size_half + ].to_pandas() + pd_result = scalars_pandas_df_index.loc[ + :, scalars_pandas_df_index.nunique() > size_half + ] + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_loc_single_index_with_duplicate(scalars_df_index, scalars_pandas_df_index): + scalars_df_index = scalars_df_index.set_index("string_col", drop=False) + scalars_pandas_df_index = scalars_pandas_df_index.set_index( + "string_col", drop=False + ) + index = "Hello, World!" + bf_result = scalars_df_index.loc[index] + pd_result = scalars_pandas_df_index.loc[index] + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + +def test_loc_single_index_no_duplicate(scalars_df_index, scalars_pandas_df_index): + scalars_df_index = scalars_df_index.set_index("int64_too", drop=False) + scalars_pandas_df_index = scalars_pandas_df_index.set_index("int64_too", drop=False) + index = -2345 + bf_result = scalars_df_index.loc[index] + pd_result = scalars_pandas_df_index.loc[index] + pd.testing.assert_series_equal( + bf_result, + pd_result, + ) + + +def test_at_with_duplicate(scalars_df_index, scalars_pandas_df_index): + scalars_df_index = scalars_df_index.set_index("string_col", drop=False) + scalars_pandas_df_index = scalars_pandas_df_index.set_index( + "string_col", drop=False + ) + index = "Hello, World!" + bf_result = scalars_df_index.at[index, "int64_too"] + pd_result = scalars_pandas_df_index.at[index, "int64_too"] + pd.testing.assert_series_equal( + bf_result.to_pandas(), + pd_result, + ) + + +def test_at_no_duplicate(scalars_df_index, scalars_pandas_df_index): + scalars_df_index = scalars_df_index.set_index("int64_too", drop=False) + scalars_pandas_df_index = scalars_pandas_df_index.set_index("int64_too", drop=False) + index = -2345 + bf_result = scalars_df_index.at[index, "string_col"] + pd_result = scalars_pandas_df_index.at[index, "string_col"] + assert bf_result == pd_result + + +def test_loc_setitem_bool_series_scalar_new_col(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + bf_df.loc[bf_df["int64_too"] == 0, "new_col"] = 99 + pd_df.loc[pd_df["int64_too"] == 0, "new_col"] = 99 + + # pandas uses float64 instead + pd_df["new_col"] = pd_df["new_col"].astype("Float64") + + pd.testing.assert_frame_equal( + bf_df.to_pandas(), + pd_df, + ) + + +@pytest.mark.parametrize( + ("col", "value"), + [ + ("string_col", "hello"), + ("int64_col", 3), + ("float64_col", 3.5), + ], +) +def test_loc_setitem_bool_series_scalar_existing_col(scalars_dfs, col, value): + if pd.__version__.startswith("1."): + pytest.skip("this loc overload not supported in pandas 1.x.") + + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + bf_df.loc[bf_df["int64_too"] == 1, col] = value + pd_df.loc[pd_df["int64_too"] == 1, col] = value + + pd.testing.assert_frame_equal( + bf_df.to_pandas(), + pd_df, + ) + + +def test_loc_setitem_bool_series_scalar_error(scalars_dfs): + if pd.__version__.startswith("1."): + pytest.skip("this loc overload not supported in pandas 1.x.") + + scalars_df, scalars_pandas_df = scalars_dfs + bf_df = scalars_df.copy() + pd_df = scalars_pandas_df.copy() + + with pytest.raises(Exception): + bf_df.loc[bf_df["int64_too"] == 1, "string_col"] = 99 + with pytest.raises(Exception): + pd_df.loc[pd_df["int64_too"] == 1, "string_col"] = 99 + + +@pytest.mark.parametrize( + ("col", "op"), + [ + # Int aggregates + pytest.param("int64_col", lambda x: x.sum(), id="int-sum"), + pytest.param("int64_col", lambda x: x.min(), id="int-min"), + pytest.param("int64_col", lambda x: x.max(), id="int-max"), + pytest.param("int64_col", lambda x: x.count(), id="int-count"), + pytest.param("int64_col", lambda x: x.nunique(), id="int-nunique"), + # Float aggregates + pytest.param("float64_col", lambda x: x.count(), id="float-count"), + pytest.param("float64_col", lambda x: x.nunique(), id="float-nunique"), + # Bool aggregates + pytest.param("bool_col", lambda x: x.sum(), id="bool-sum"), + pytest.param("bool_col", lambda x: x.count(), id="bool-count"), + pytest.param("bool_col", lambda x: x.nunique(), id="bool-nunique"), + # String aggregates + pytest.param("string_col", lambda x: x.count(), id="string-count"), + pytest.param("string_col", lambda x: x.nunique(), id="string-nunique"), + ], +) +def test_dataframe_aggregate_int(scalars_df_index, scalars_pandas_df_index, col, op): + bf_result = op(scalars_df_index[[col]]).to_pandas() + pd_result = op(scalars_pandas_df_index[[col]]) + + # Check dtype separately + assert bf_result.dtype == "Int64" + # Is otherwise "object" dtype + pd_result.index = pd_result.index.astype("string[pyarrow]") + # Pandas may produce narrower numeric types + assert_series_equal(pd_result, bf_result, check_dtype=False, check_index_type=False) + + +@pytest.mark.parametrize( + ("col", "op"), + [ + pytest.param("bool_col", lambda x: x.min(), id="bool-min"), + pytest.param("bool_col", lambda x: x.max(), id="bool-max"), + ], +) +def test_dataframe_aggregate_bool(scalars_df_index, scalars_pandas_df_index, col, op): + bf_result = op(scalars_df_index[[col]]).to_pandas() + pd_result = op(scalars_pandas_df_index[[col]]) + + # Check dtype separately + assert bf_result.dtype == "boolean" + + # Pandas may produce narrower numeric types + # Pandas has object index type + pd_result.index = pd_result.index.astype("string[pyarrow]") + assert_series_equal(pd_result, bf_result, check_dtype=False, check_index_type=False) + + +@pytest.mark.parametrize( + ("op", "bf_dtype"), + [ + (lambda x: x.sum(numeric_only=True), "Float64"), + (lambda x: x.mean(numeric_only=True), "Float64"), + (lambda x: x.min(numeric_only=True), "Float64"), + (lambda x: x.max(numeric_only=True), "Float64"), + (lambda x: x.std(numeric_only=True), "Float64"), + (lambda x: x.var(numeric_only=True), "Float64"), + (lambda x: x.count(numeric_only=False), "Int64"), + (lambda x: x.nunique(), "Int64"), + ], + ids=["sum", "mean", "min", "max", "std", "var", "count", "nunique"], +) +def test_dataframe_aggregates(scalars_dfs, op, bf_dtype): + scalars_df_index, scalars_pandas_df_index = scalars_dfs + col_names = ["int64_too", "float64_col", "string_col", "int64_col", "bool_col"] + bf_series = op(scalars_df_index[col_names]) + bf_result = bf_series + pd_result = op(scalars_pandas_df_index[col_names]) + + # Check dtype separately + assert bf_result.dtype == bf_dtype + + # Pandas may produce narrower numeric types, but bigframes always produces Float64 + # Pandas has object index type + pd_result.index = pd_result.index.astype("string[pyarrow]") + assert_series_equivalent( + pd_result, + bf_result, + check_dtype=False, + check_index_type=False, + ) + + +@pytest.mark.parametrize( + ("op"), + [ + (lambda x: x.sum(axis=1, numeric_only=True)), + (lambda x: x.mean(axis=1, numeric_only=True)), + (lambda x: x.min(axis=1, numeric_only=True)), + (lambda x: x.max(axis=1, numeric_only=True)), + (lambda x: x.std(axis=1, numeric_only=True)), + (lambda x: x.var(axis=1, numeric_only=True)), + ], + ids=["sum", "mean", "min", "max", "std", "var"], +) +def test_dataframe_aggregates_axis_1(scalars_df_index, scalars_pandas_df_index, op): + col_names = ["int64_too", "int64_col", "float64_col", "bool_col", "string_col"] + bf_result = op(scalars_df_index[col_names]).to_pandas() + pd_result = op(scalars_pandas_df_index[col_names]) + + # Pandas may produce narrower numeric types, but bigframes always produces Float64 + # Pandas has object index type + assert_series_equal(pd_result, bf_result, check_index_type=False, check_dtype=False) + + +@pytest.mark.parametrize( + ("op"), + [ + (lambda x: x.all(bool_only=True)), + (lambda x: x.any(bool_only=True)), + (lambda x: x.all(axis=1, bool_only=True)), + (lambda x: x.any(axis=1, bool_only=True)), + ], + ids=["all_axis0", "any_axis0", "all_axis1", "any_axis1"], +) +def test_dataframe_bool_aggregates(scalars_df_index, scalars_pandas_df_index, op): + # Pandas will drop nullable 'boolean' dtype so we convert first to bool, then cast back later + scalars_df_index = scalars_df_index.assign( + bool_col=scalars_df_index.bool_col.fillna(False) + ) + scalars_pandas_df_index = scalars_pandas_df_index.assign( + bool_col=scalars_pandas_df_index.bool_col.fillna(False).astype("bool") + ) + bf_series = op(scalars_df_index) + pd_series = op(scalars_pandas_df_index).astype("boolean") + bf_result = bf_series.to_pandas() + + pd_series.index = pd_series.index.astype(bf_result.index.dtype) + pd.testing.assert_series_equal(pd_series, bf_result, check_index_type=False) + + +def test_dataframe_prod(scalars_df_index, scalars_pandas_df_index): + col_names = ["int64_too", "float64_col"] + bf_series = scalars_df_index[col_names].prod() + pd_series = scalars_pandas_df_index[col_names].prod() + bf_result = bf_series.to_pandas() + + # Pandas may produce narrower numeric types, but bigframes always produces Float64 + pd_series = pd_series.astype("Float64") + # Pandas has object index type + pd.testing.assert_series_equal(pd_series, bf_result, check_index_type=False) + + +def test_df_skew_too_few_values(scalars_dfs): + columns = ["float64_col", "int64_col"] + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[columns].head(2).skew().to_pandas() + pd_result = scalars_pandas_df[columns].head(2).skew() + + # Pandas may produce narrower numeric types, but bigframes always produces Float64 + pd_result = pd_result.astype("Float64") + + pd.testing.assert_series_equal(pd_result, bf_result, check_index_type=False) + + +@pytest.mark.parametrize( + ("ordered"), + [ + (True), + (False), + ], +) +def test_df_skew(scalars_dfs, ordered): + columns = ["float64_col", "int64_col"] + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[columns].skew().to_pandas(ordered=ordered) + pd_result = scalars_pandas_df[columns].skew() + + # Pandas may produce narrower numeric types, but bigframes always produces Float64 + pd_result = pd_result.astype("Float64") + + assert_series_equal( + pd_result, bf_result, check_index_type=False, ignore_order=not ordered + ) + + +def test_df_kurt_too_few_values(scalars_dfs): + columns = ["float64_col", "int64_col"] + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[columns].head(2).kurt().to_pandas() + pd_result = scalars_pandas_df[columns].head(2).kurt() + + # Pandas may produce narrower numeric types, but bigframes always produces Float64 + pd_result = pd_result.astype("Float64") + + pd.testing.assert_series_equal(pd_result, bf_result, check_index_type=False) + + +def test_df_kurt(scalars_dfs): + columns = ["float64_col", "int64_col"] + scalars_df, scalars_pandas_df = scalars_dfs + bf_result = scalars_df[columns].kurt().to_pandas() + pd_result = scalars_pandas_df[columns].kurt() + + # Pandas may produce narrower numeric types, but bigframes always produces Float64 + pd_result = pd_result.astype("Float64") + + pd.testing.assert_series_equal(pd_result, bf_result, check_index_type=False) + + +def test_sample_raises_value_error(scalars_dfs): + scalars_df, _ = scalars_dfs + with pytest.raises( + ValueError, match="Only one of 'n' or 'frac' parameter can be specified." + ): + scalars_df.sample(frac=0.5, n=4) + + +@pytest.mark.parametrize( + ("axis",), + [ + (None,), + (0,), + (1,), + ], +) +def test_df_add_prefix(scalars_df_index, scalars_pandas_df_index, axis): + if pd.__version__.startswith("1."): + pytest.skip("add_prefix axis parameter not supported in pandas 1.x.") + bf_result = scalars_df_index.add_prefix("prefix_", axis).to_pandas() + + pd_result = scalars_pandas_df_index.add_prefix("prefix_", axis) + + pd.testing.assert_frame_equal( + bf_result, + pd_result, + check_index_type=False, + ) + + +@pytest.mark.parametrize( + ("axis",), + [ + (0,), + (1,), + ], +) +def test_df_add_suffix(scalars_df_index, scalars_pandas_df_index, axis): + if pd.__version__.startswith("1."): + pytest.skip("add_prefix axis parameter not supported in pandas 1.x.") + bf_result = scalars_df_index.add_suffix("_suffix", axis).to_pandas() + + pd_result = scalars_pandas_df_index.add_suffix("_suffix", axis) + + pd.testing.assert_frame_equal( + bf_result, + pd_result, + check_index_type=False, + ) + + +def test_df_astype_error_error(session): + input = pd.DataFrame(["hello", "world", "3.11", "4000"]) + with pytest.raises(ValueError): + session.read_pandas(input).astype("Float64", errors="bad_value") + + +def test_df_columns_filter_items(scalars_df_index, scalars_pandas_df_index): + if pd.__version__.startswith("2.0") or pd.__version__.startswith("1."): + pytest.skip("pandas filter items behavior different pre-2.1") + bf_result = scalars_df_index.filter(items=["string_col", "int64_col"]).to_pandas() + + pd_result = scalars_pandas_df_index.filter(items=["string_col", "int64_col"]) + # Ignore column ordering as pandas order differently depending on version + pd.testing.assert_frame_equal( + bf_result.sort_index(axis=1), + pd_result.sort_index(axis=1), + ) + + +def test_df_columns_filter_like(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.filter(like="64_col").to_pandas() + + pd_result = scalars_pandas_df_index.filter(like="64_col") + + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_df_columns_filter_regex(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.filter(regex="^[^_]+$").to_pandas() + + pd_result = scalars_pandas_df_index.filter(regex="^[^_]+$") + + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_df_reindex_rows_list(scalars_dfs): + scalars_df_index, scalars_pandas_df_index = scalars_dfs + bf_result = scalars_df_index.reindex(index=[5, 1, 3, 99, 1]) + + pd_result = scalars_pandas_df_index.reindex(index=[5, 1, 3, 99, 1]) + + # Pandas uses int64 instead of Int64 (nullable) dtype. + pd_result.index = pd_result.index.astype(pd.Int64Dtype()) + assert_dfs_equivalent( + pd_result, + bf_result, + ) + + +def test_df_reindex_rows_index(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.reindex( + index=pd.Index([5, 1, 3, 99, 1], name="newname") + ).to_pandas() + + pd_result = scalars_pandas_df_index.reindex( + index=pd.Index([5, 1, 3, 99, 1], name="newname") + ) + + # Pandas uses int64 instead of Int64 (nullable) dtype. + pd_result.index = pd_result.index.astype(pd.Int64Dtype()) + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_df_reindex_nonunique(scalars_df_index): + with pytest.raises(ValueError): + # int64_too is non-unique + scalars_df_index.set_index("int64_too").reindex( + index=[5, 1, 3, 99, 1], validate=True + ) + + +def test_df_reindex_columns(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.reindex( + columns=["not_a_col", "int64_col", "int64_too"] + ).to_pandas() + + pd_result = scalars_pandas_df_index.reindex( + columns=["not_a_col", "int64_col", "int64_too"] + ) + + # Pandas uses float64 as default for newly created empty column, bf uses Float64 + pd_result.not_a_col = pd_result.not_a_col.astype(pandas.Float64Dtype()) + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_df_reindex_columns_with_same_order(scalars_df_index, scalars_pandas_df_index): + # First, make sure the two dataframes have the same columns in order. + columns = ["int64_col", "int64_too"] + bf = scalars_df_index[columns] + pd_df = scalars_pandas_df_index[columns] + + bf_result = bf.reindex(columns=columns).to_pandas() + pd_result = pd_df.reindex(columns=columns) + + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_df_equals_identical(scalars_df_index, scalars_pandas_df_index): + unsupported = [ + "geography_col", + ] + scalars_df_index = scalars_df_index.drop(columns=unsupported) + scalars_pandas_df_index = scalars_pandas_df_index.drop(columns=unsupported) + + bf_result = scalars_df_index.equals(scalars_df_index) + pd_result = scalars_pandas_df_index.equals(scalars_pandas_df_index) + + assert pd_result == bf_result + + +def test_df_equals_series(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index[["int64_col"]].equals(scalars_df_index["int64_col"]) + pd_result = scalars_pandas_df_index[["int64_col"]].equals( + scalars_pandas_df_index["int64_col"] + ) + + assert pd_result == bf_result + + +def test_df_equals_different_dtype(scalars_df_index, scalars_pandas_df_index): + columns = ["int64_col", "int64_too"] + scalars_df_index = scalars_df_index[columns] + scalars_pandas_df_index = scalars_pandas_df_index[columns] + + bf_modified = scalars_df_index.copy() + bf_modified = bf_modified.astype("Float64") + + pd_modified = scalars_pandas_df_index.copy() + pd_modified = pd_modified.astype("Float64") + + bf_result = scalars_df_index.equals(bf_modified) + pd_result = scalars_pandas_df_index.equals(pd_modified) + + assert pd_result == bf_result + + +def test_df_equals_different_values(scalars_df_index, scalars_pandas_df_index): + columns = ["int64_col", "int64_too"] + scalars_df_index = scalars_df_index[columns] + scalars_pandas_df_index = scalars_pandas_df_index[columns] + + bf_modified = scalars_df_index.copy() + bf_modified["int64_col"] = bf_modified.int64_col + 1 + + pd_modified = scalars_pandas_df_index.copy() + pd_modified["int64_col"] = pd_modified.int64_col + 1 + + bf_result = scalars_df_index.equals(bf_modified) + pd_result = scalars_pandas_df_index.equals(pd_modified) + + assert pd_result == bf_result + + +def test_df_equals_extra_column(scalars_df_index, scalars_pandas_df_index): + columns = ["int64_col", "int64_too"] + more_columns = ["int64_col", "int64_too", "float64_col"] + + bf_result = scalars_df_index[columns].equals(scalars_df_index[more_columns]) + pd_result = scalars_pandas_df_index[columns].equals( + scalars_pandas_df_index[more_columns] + ) + + assert pd_result == bf_result + + +def test_df_reindex_like(scalars_df_index, scalars_pandas_df_index): + reindex_target_bf = scalars_df_index.reindex( + columns=["not_a_col", "int64_col", "int64_too"], index=[5, 1, 3, 99, 1] + ) + bf_result = scalars_df_index.reindex_like(reindex_target_bf).to_pandas() + + reindex_target_pd = scalars_pandas_df_index.reindex( + columns=["not_a_col", "int64_col", "int64_too"], index=[5, 1, 3, 99, 1] + ) + pd_result = scalars_pandas_df_index.reindex_like(reindex_target_pd) + + # Pandas uses float64 as default for newly created empty column, bf uses Float64 + # Pandas uses int64 instead of Int64 (nullable) dtype. + pd_result.index = pd_result.index.astype(pd.Int64Dtype()) + # Pandas uses float64 as default for newly created empty column, bf uses Float64 + pd_result.not_a_col = pd_result.not_a_col.astype(pandas.Float64Dtype()) + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_df_values(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.values + + pd_result = scalars_pandas_df_index.values + # Numpy isn't equipped to compare non-numeric objects, so convert back to dataframe + pd.testing.assert_frame_equal( + pd.DataFrame(bf_result), pd.DataFrame(pd_result), check_dtype=False + ) + + +def test_df_to_numpy(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.to_numpy() + + pd_result = scalars_pandas_df_index.to_numpy() + # Numpy isn't equipped to compare non-numeric objects, so convert back to dataframe + pd.testing.assert_frame_equal( + pd.DataFrame(bf_result), pd.DataFrame(pd_result), check_dtype=False + ) + + +def test_df___array__(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.__array__() + + pd_result = scalars_pandas_df_index.__array__() + # Numpy isn't equipped to compare non-numeric objects, so convert back to dataframe + pd.testing.assert_frame_equal( + pd.DataFrame(bf_result), pd.DataFrame(pd_result), check_dtype=False + ) + + +def test_df_getattr_attribute_error_when_pandas_has(scalars_df_index): + # swapaxes is implemented in pandas but not in bigframes + with pytest.raises(AttributeError): + scalars_df_index.swapaxes() + + +def test_df_getattr_attribute_error(scalars_df_index): + with pytest.raises(AttributeError): + scalars_df_index.not_a_method() + + +def test_df_getattr_axes(): + df = dataframe.DataFrame( + [[1, 1, 1], [1, 1, 1]], columns=["index", "columns", "my_column"] + ) + assert isinstance(df.index, bigframes.core.indexes.Index) + assert isinstance(df.columns, pandas.Index) + assert isinstance(df.my_column, series.Series) + + +def test_df_setattr_index(): + pd_df = pandas.DataFrame( + [[1, 1, 1], [1, 1, 1]], columns=["index", "columns", "my_column"] + ) + bf_df = dataframe.DataFrame(pd_df) + + pd_df.index = pandas.Index([4, 5]) + bf_df.index = [4, 5] + + assert_frame_equal( + pd_df, bf_df.to_pandas(), check_index_type=False, check_dtype=False + ) + + +def test_df_setattr_columns(): + pd_df = pandas.DataFrame( + [[1, 1, 1], [1, 1, 1]], columns=["index", "columns", "my_column"] + ) + bf_df = dataframe.DataFrame(pd_df) + + pd_df.columns = typing.cast(pandas.Index, pandas.Index([4, 5, 6])) + + bf_df.columns = pandas.Index([4, 5, 6]) + + assert_frame_equal( + pd_df, bf_df.to_pandas(), check_index_type=False, check_dtype=False + ) + + +def test_df_setattr_modify_column(): + pd_df = pandas.DataFrame( + [[1, 1, 1], [1, 1, 1]], columns=["index", "columns", "my_column"] + ) + bf_df = dataframe.DataFrame(pd_df) + pd_df.my_column = [4, 5] + bf_df.my_column = [4, 5] + + assert_frame_equal( + pd_df, bf_df.to_pandas(), check_index_type=False, check_dtype=False + ) + + +def test_loc_list_string_index(scalars_df_index, scalars_pandas_df_index): + index_list = scalars_pandas_df_index.string_col.iloc[[0, 1, 1, 5]].values + + scalars_df_index = scalars_df_index.set_index("string_col") + scalars_pandas_df_index = scalars_pandas_df_index.set_index("string_col") + + bf_result = scalars_df_index.loc[index_list].to_pandas() + pd_result = scalars_pandas_df_index.loc[index_list] + + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_loc_list_integer_index(scalars_df_index, scalars_pandas_df_index): + index_list = [3, 2, 1, 3, 2, 1] + + bf_result = scalars_df_index.loc[index_list] + pd_result = scalars_pandas_df_index.loc[index_list] + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + +def test_loc_list_multiindex(scalars_dfs): + scalars_df_index, scalars_pandas_df_index = scalars_dfs + scalars_df_multiindex = scalars_df_index.set_index(["string_col", "int64_col"]) + scalars_pandas_df_multiindex = scalars_pandas_df_index.set_index( + ["string_col", "int64_col"] + ) + index_list = [("Hello, World!", -234892), ("Hello, World!", 123456789)] + + bf_result = scalars_df_multiindex.loc[index_list] + pd_result = scalars_pandas_df_multiindex.loc[index_list] + + assert_dfs_equivalent( + pd_result, + bf_result, + ) + + +@pytest.mark.parametrize( + "index_list", + [ + [0, 1, 2, 3, 4, 4], + [0, 0, 0, 5, 4, 7, -2, -5, 3], + [-1, -2, -3, -4, -5, -5], + ], +) +def test_iloc_list(scalars_df_index, scalars_pandas_df_index, index_list): + bf_result = scalars_df_index.iloc[index_list] + pd_result = scalars_pandas_df_index.iloc[index_list] + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + +def test_iloc_list_multiindex(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + scalars_df = scalars_df.copy() + scalars_pandas_df = scalars_pandas_df.copy() + scalars_df = scalars_df.set_index(["bytes_col", "numeric_col"]) + scalars_pandas_df = scalars_pandas_df.set_index(["bytes_col", "numeric_col"]) + + index_list = [0, 0, 0, 5, 4, 7] + + bf_result = scalars_df.iloc[index_list] + pd_result = scalars_pandas_df.iloc[index_list] + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + +def test_iloc_empty_list(scalars_df_index, scalars_pandas_df_index): + index_list: List[int] = [] + + bf_result = scalars_df_index.iloc[index_list] + pd_result = scalars_pandas_df_index.iloc[index_list] + + bf_result = bf_result.to_pandas() + assert bf_result.shape == pd_result.shape # types are known to be different + + +def test_rename_axis(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.rename_axis("newindexname") + pd_result = scalars_pandas_df_index.rename_axis("newindexname") + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + +def test_rename_axis_nonstring(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.rename_axis((4,)) + pd_result = scalars_pandas_df_index.rename_axis((4,)) + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + +def test_loc_bf_series_string_index(scalars_df_index, scalars_pandas_df_index): + pd_string_series = scalars_pandas_df_index.string_col.iloc[[0, 5, 1, 1, 5]] + bf_string_series = scalars_df_index.string_col.iloc[[0, 5, 1, 1, 5]] + + scalars_df_index = scalars_df_index.set_index("string_col") + scalars_pandas_df_index = scalars_pandas_df_index.set_index("string_col") + + bf_result = scalars_df_index.loc[bf_string_series] + pd_result = scalars_pandas_df_index.loc[pd_string_series] + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + +def test_loc_bf_series_multiindex(scalars_df_index, scalars_pandas_df_index): + pd_string_series = scalars_pandas_df_index.string_col.iloc[[0, 5, 1, 1, 5]] + bf_string_series = scalars_df_index.string_col.iloc[[0, 5, 1, 1, 5]] + + scalars_df_multiindex = scalars_df_index.set_index(["string_col", "int64_col"]) + scalars_pandas_df_multiindex = scalars_pandas_df_index.set_index( + ["string_col", "int64_col"] + ) + + bf_result = scalars_df_multiindex.loc[bf_string_series] + pd_result = scalars_pandas_df_multiindex.loc[pd_string_series] + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + +def test_loc_bf_index_integer_index(scalars_df_index, scalars_pandas_df_index): + pd_index = scalars_pandas_df_index.iloc[[0, 5, 1, 1, 5]].index + bf_index = scalars_df_index.iloc[[0, 5, 1, 1, 5]].index + + bf_result = scalars_df_index.loc[bf_index] + pd_result = scalars_pandas_df_index.loc[pd_index] + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + +def test_loc_bf_index_integer_index_renamed_col( + scalars_df_index, scalars_pandas_df_index +): + scalars_df_index = scalars_df_index.rename(columns={"int64_col": "rename"}) + scalars_pandas_df_index = scalars_pandas_df_index.rename( + columns={"int64_col": "rename"} + ) + + pd_index = scalars_pandas_df_index.iloc[[0, 5, 1, 1, 5]].index + bf_index = scalars_df_index.iloc[[0, 5, 1, 1, 5]].index + + bf_result = scalars_df_index.loc[bf_index] + pd_result = scalars_pandas_df_index.loc[pd_index] + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + +@pytest.mark.parametrize( + ("subset"), + [ + None, + "bool_col", + ["bool_col", "int64_too"], + ], +) +@pytest.mark.parametrize( + ("keep",), + [ + (False,), + ], +) +def test_df_drop_duplicates(scalars_df_index, scalars_pandas_df_index, keep, subset): + columns = ["bool_col", "int64_too", "int64_col"] + bf_df = scalars_df_index[columns].drop_duplicates(subset, keep=keep).to_pandas() + pd_df = scalars_pandas_df_index[columns].drop_duplicates(subset, keep=keep) + pd.testing.assert_frame_equal( + pd_df, + bf_df, + ) + + +@pytest.mark.parametrize( + ("subset"), + [ + None, + ["bool_col"], + ], +) +@pytest.mark.parametrize( + ("keep",), + [ + (False,), + ], +) +def test_df_duplicated(scalars_df_index, scalars_pandas_df_index, keep, subset): + columns = ["bool_col", "int64_too", "int64_col"] + bf_series = scalars_df_index[columns].duplicated(subset, keep=keep).to_pandas() + pd_series = scalars_pandas_df_index[columns].duplicated(subset, keep=keep) + pd.testing.assert_series_equal(pd_series, bf_series, check_dtype=False) + + +def test_df_from_dict_columns_orient(): + data = {"a": [1, 2], "b": [3.3, 2.4]} + bf_result = dataframe.DataFrame.from_dict(data, orient="columns").to_pandas() + pd_result = pd.DataFrame.from_dict(data, orient="columns") + assert_frame_equal(pd_result, bf_result, check_dtype=False, check_index_type=False) + + +def test_df_from_dict_index_orient(): + data = {"a": [1, 2], "b": [3.3, 2.4]} + bf_result = dataframe.DataFrame.from_dict( + data, orient="index", columns=["col1", "col2"] + ).to_pandas() + pd_result = pd.DataFrame.from_dict(data, orient="index", columns=["col1", "col2"]) + assert_frame_equal(pd_result, bf_result, check_dtype=False, check_index_type=False) + + +def test_df_from_dict_tight_orient(): + data = { + "index": [("i1", "i2"), ("i3", "i4")], + "columns": ["col1", "col2"], + "data": [[1, 2.6], [3, 4.5]], + "index_names": ["in1", "in2"], + "column_names": ["column_axis"], + } + + bf_result = dataframe.DataFrame.from_dict(data, orient="tight").to_pandas() + pd_result = pd.DataFrame.from_dict(data, orient="tight") + assert_frame_equal(pd_result, bf_result, check_dtype=False, check_index_type=False) + + +def test_df_from_records(): + records = ((1, "a"), (2.5, "b"), (3.3, "c"), (4.9, "d")) + + bf_result = dataframe.DataFrame.from_records( + records, columns=["c1", "c2"] + ).to_pandas() + pd_result = pd.DataFrame.from_records(records, columns=["c1", "c2"]) + assert_frame_equal(pd_result, bf_result, check_dtype=False, check_index_type=False) + + +def test_df_to_dict(scalars_df_index, scalars_pandas_df_index): + unsupported = ["numeric_col"] # formatted differently + bf_result = scalars_df_index.drop(columns=unsupported).to_dict() + pd_result = scalars_pandas_df_index.drop(columns=unsupported).to_dict() + + assert bf_result == pd_result + + +def test_df_to_json_local_str(scalars_df_index, scalars_pandas_df_index): + # pandas 3.0 bugged for serializing date col + bf_result = scalars_df_index.drop(columns="date_col").to_json() + # default_handler for arrow types that have no default conversion + pd_result = scalars_pandas_df_index.drop(columns="date_col").to_json( + default_handler=str + ) + + assert bf_result == pd_result + + +def test_df_to_json_local_file(scalars_df_index, scalars_pandas_df_index): + # TODO: supply a reason why this isn't compatible with pandas 1.x + pytest.importorskip("pandas", minversion="2.0.0") + # duration not fully supported at pandas level + scalars_df_index = scalars_df_index.drop(columns="duration_col") + scalars_pandas_df_index = scalars_pandas_df_index.drop(columns="duration_col") + with ( + tempfile.TemporaryFile() as bf_result_file, + tempfile.TemporaryFile() as pd_result_file, + ): + scalars_df_index.to_json(bf_result_file, orient="table") + # default_handler for arrow types that have no default conversion + scalars_pandas_df_index.to_json( + pd_result_file, orient="table", default_handler=str + ) + + bf_result = bf_result_file.read() + pd_result = pd_result_file.read() + + assert bf_result == pd_result + + +def test_df_to_csv_local_str(scalars_df_index, scalars_pandas_df_index): + bf_result = scalars_df_index.to_csv() + # default_handler for arrow types that have no default conversion + pd_result = scalars_pandas_df_index.to_csv() + + assert bf_result == pd_result + + +def test_df_to_csv_local_file(scalars_df_index, scalars_pandas_df_index): + with ( + tempfile.TemporaryFile() as bf_result_file, + tempfile.TemporaryFile() as pd_result_file, + ): + scalars_df_index.to_csv(bf_result_file) + scalars_pandas_df_index.to_csv(pd_result_file) + + bf_result = bf_result_file.read() + pd_result = pd_result_file.read() + + assert bf_result == pd_result + + +def test_df_to_parquet_local_bytes(scalars_df_index, scalars_pandas_df_index): + # GEOGRAPHY not supported in parquet export. + unsupported = ["geography_col"] + + bf_result = scalars_df_index.drop(columns=unsupported).to_parquet() + # default_handler for arrow types that have no default conversion + pd_result = scalars_pandas_df_index.drop(columns=unsupported).to_parquet() + + assert bf_result == pd_result + + +def test_df_to_parquet_local_file(scalars_df_index, scalars_pandas_df_index): + # GEOGRAPHY not supported in parquet export. + unsupported = ["geography_col"] + with ( + tempfile.TemporaryFile() as bf_result_file, + tempfile.TemporaryFile() as pd_result_file, + ): + scalars_df_index.drop(columns=unsupported).to_parquet(bf_result_file) + scalars_pandas_df_index.drop(columns=unsupported).to_parquet(pd_result_file) + + bf_result = bf_result_file.read() + pd_result = pd_result_file.read() + + assert bf_result == pd_result + + +def test_df_to_records(scalars_df_index, scalars_pandas_df_index): + unsupported = ["numeric_col"] + bf_result = scalars_df_index.drop(columns=unsupported).to_records() + pd_result = scalars_pandas_df_index.drop(columns=unsupported).to_records() + + for bfi, pdi in zip(bf_result, pd_result): + for bfj, pdj in zip(bfi, pdi): + assert pd.isna(bfj) and pd.isna(pdj) or bfj == pdj + + +def test_df_to_string(scalars_df_index, scalars_pandas_df_index): + unsupported = ["numeric_col"] # formatted differently + + bf_result = scalars_df_index.drop(columns=unsupported).to_string() + pd_result = scalars_pandas_df_index.drop(columns=unsupported).to_string() + + assert bf_result == pd_result + + +def test_df_to_html(scalars_df_index, scalars_pandas_df_index): + unsupported = ["numeric_col"] # formatted differently + + bf_result = scalars_df_index.drop(columns=unsupported).to_html() + pd_result = scalars_pandas_df_index.drop(columns=unsupported).to_html() + + assert bf_result == pd_result + + +def test_df_to_markdown(scalars_df_index, scalars_pandas_df_index): + # Nulls have bug from tabulate https://github.com/astanin/python-tabulate/issues/231 + bf_result = scalars_df_index.dropna().to_markdown() + pd_result = scalars_pandas_df_index.dropna().to_markdown() + + assert bf_result == pd_result + + +def test_df_to_pickle(scalars_df_index, scalars_pandas_df_index): + with ( + tempfile.TemporaryFile() as bf_result_file, + tempfile.TemporaryFile() as pd_result_file, + ): + scalars_df_index.to_pickle(bf_result_file) + scalars_pandas_df_index.to_pickle(pd_result_file) + bf_result = bf_result_file.read() + pd_result = pd_result_file.read() + + assert bf_result == pd_result + + +def test_df_to_orc(scalars_df_index, scalars_pandas_df_index): + pytest.importorskip("pyarrow.orc") + unsupported = [ + "numeric_col", + "bytes_col", + "date_col", + "datetime_col", + "time_col", + "timestamp_col", + "geography_col", + "duration_col", + ] + + bf_result_file = tempfile.TemporaryFile() + pd_result_file = tempfile.TemporaryFile() + scalars_df_index.drop(columns=unsupported).to_orc(bf_result_file) + scalars_pandas_df_index.drop(columns=unsupported).reset_index().to_orc( + pd_result_file + ) + bf_result = bf_result_file.read() + pd_result = bf_result_file.read() + + assert bf_result == pd_result + + +@pytest.mark.parametrize( + ("expr",), + [ + ("new_col = int64_col + int64_too",), + ("new_col = (rowindex > 3) | bool_col",), + ("int64_too = bool_col\nnew_col2 = rowindex",), + ], +) +def test_df_eval(scalars_dfs, expr): + # TODO: supply a reason why this isn't compatible with pandas 1.x + pytest.importorskip("pandas", minversion="2.0.0") + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = scalars_df.eval(expr).to_pandas() + pd_result = scalars_pandas_df.eval(expr) + + pd.testing.assert_frame_equal(bf_result, pd_result) + + +@pytest.mark.parametrize( + ("expr",), + [ + ("int64_col > int64_too",), + ("bool_col",), + ("((int64_col - int64_too) % @local_var) == 0",), + ], +) +def test_df_query(scalars_dfs, expr): + # TODO: supply a reason why this isn't compatible with pandas 1.x + pytest.importorskip("pandas", minversion="2.0.0") + # local_var is referenced in expressions + local_var = 3 # NOQA + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = scalars_df.query(expr).to_pandas() + pd_result = scalars_pandas_df.query(expr) + + pd.testing.assert_frame_equal(bf_result, pd_result) + + +@pytest.mark.parametrize( + ("subset", "normalize", "ascending", "dropna"), + [ + (None, False, False, False), + (None, True, True, True), + ("bool_col", True, False, True), + ], +) +def test_df_value_counts(scalars_dfs, subset, normalize, ascending, dropna): + if pd.__version__.startswith("1."): + pytest.skip("pandas 1.x produces different column labels.") + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = ( + scalars_df[["string_col", "bool_col"]] + .value_counts(subset, normalize=normalize, ascending=ascending, dropna=dropna) + .to_pandas() + ) + pd_result = scalars_pandas_df[["string_col", "bool_col"]].value_counts( + subset, normalize=normalize, ascending=ascending, dropna=dropna + ) + + assert_series_equal( + bf_result, + pd_result, + check_dtype=False, + check_index_type=False, + # different pandas versions inconsistent for tie-handling + ignore_order=True, + ) + + +def test_df_bool_interpretation_error(scalars_df_index): + with pytest.raises(ValueError): + True if scalars_df_index else False + + +def test_assign_after_binop_row_joins(): + pd_df = pd.DataFrame( + { + "idx1": [1, 1, 1, 1, 2, 2, 2, 2], + "idx2": [10, 10, 20, 20, 10, 10, 20, 20], + "metric1": [10, 14, 2, 13, 6, 2, 9, 5], + "metric2": [25, -3, 8, 2, -1, 0, 0, -4], + }, + dtype=pd.Int64Dtype(), + ).set_index(["idx1", "idx2"]) + bf_df = dataframe.DataFrame(pd_df) + + # Expect implicit joiner to be used, preserving input cardinality rather than getting relational join + bf_df["metric_diff"] = bf_df.metric1 - bf_df.metric2 + pd_df["metric_diff"] = pd_df.metric1 - pd_df.metric2 + + assert_frame_equal(bf_df.to_pandas(), pd_df) + + +def test_df_dot_inline(session): + df1 = pd.DataFrame([[1, 2, 3], [2, 5, 7]]) + df2 = pd.DataFrame([[2, 4, 8], [1, 5, 10], [3, 6, 9]]) + + bf1 = session.read_pandas(df1) + bf2 = session.read_pandas(df2) + bf_result = bf1.dot(bf2).to_pandas() + pd_result = df1.dot(df2) + + # Patch pandas dtypes for testing parity + # Pandas uses int64 instead of Int64 (nullable) dtype. + for name in pd_result.columns: + pd_result[name] = pd_result[name].astype(pd.Int64Dtype()) + pd_result.index = pd_result.index.astype(pd.Int64Dtype()) + + pd.testing.assert_frame_equal( + bf_result, + pd_result, + ) + + +def test_df_dot_series_inline(): + left = [[1, 2, 3], [2, 5, 7]] + right = [2, 1, 3] + + bf1 = dataframe.DataFrame(left) + bf2 = series.Series(right) + bf_result = bf1.dot(bf2).to_pandas() + + df1 = pd.DataFrame(left) + df2 = pd.Series(right) + pd_result = df1.dot(df2) + + # Patch pandas dtypes for testing parity + # Pandas result is int64 instead of Int64 (nullable) dtype. + pd_result = pd_result.astype(pd.Int64Dtype()) + pd_result.index = pd_result.index.astype(pd.Int64Dtype()) + + pd.testing.assert_series_equal( + bf_result, + pd_result, + ) + + +@pytest.mark.parametrize( + ("col_names", "ignore_index"), + [ + pytest.param(["A"], False, id="one_array_false"), + pytest.param(["A"], True, id="one_array_true"), + pytest.param(["B"], False, id="one_float_false"), + pytest.param(["B"], True, id="one_float_true"), + pytest.param(["A", "C"], False, id="two_arrays_false"), + pytest.param(["A", "C"], True, id="two_arrays_true"), + ], +) +def test_dataframe_explode(col_names, ignore_index, session): + data = { + "A": [[0, 1, 2], [], [3, 4]], + "B": 3, + "C": [["a", "b", "c"], np.nan, ["d", "e"]], + } + + df = bpd.DataFrame(data, session=session) + pd_df = df.to_pandas() + pd_result = pd_df.explode(col_names, ignore_index=ignore_index) + bf_result = df.explode(col_names, ignore_index=ignore_index) + + # Check that to_pandas() results in at most a single query execution + bf_materialized = bf_result.to_pandas() + + pd.testing.assert_frame_equal( + bf_materialized, + pd_result, + check_index_type=False, + check_dtype=False, + ) + + +@pytest.mark.parametrize( + ("ignore_index", "ordered"), + [ + pytest.param(True, True, id="include_index_ordered"), + pytest.param(True, False, id="include_index_unordered"), + pytest.param(False, True, id="ignore_index_ordered"), + ], +) +def test_dataframe_explode_reserve_order(session, ignore_index, ordered): + data = { + "a": [np.random.randint(0, 10, 10) for _ in range(10)], + "b": [np.random.randint(0, 10, 10) for _ in range(10)], + } + df = bpd.DataFrame(data) + pd_df = pd.DataFrame(data) + + res = df.explode(["a", "b"], ignore_index=ignore_index).to_pandas(ordered=ordered) + pd_res = pd_df.explode(["a", "b"], ignore_index=ignore_index).astype( + pd.Int64Dtype() + ) + pd.testing.assert_frame_equal( + res if ordered else res.sort_index(), + pd_res, + check_index_type=False, + ) + + +@pytest.mark.parametrize( + ("col_names"), + [ + pytest.param([], id="empty", marks=pytest.mark.xfail(raises=ValueError)), + pytest.param( + ["A", "A"], id="duplicate", marks=pytest.mark.xfail(raises=ValueError) + ), + pytest.param("unknown", id="unknown", marks=pytest.mark.xfail(raises=KeyError)), + ], +) +def test_dataframe_explode_xfail(col_names): + df = bpd.DataFrame({"A": [[0, 1, 2], [], [3, 4]]}) + df.explode(col_names) + + +def test_recursion_limit_unit(scalars_df_index): + scalars_df_index = scalars_df_index[["int64_too", "int64_col", "float64_col"]] + for i in range(250): + scalars_df_index = scalars_df_index + 4 + scalars_df_index.to_pandas() From 1ebcfcf45deaa24111481a6cc1f8c4d52ef146ac Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 3 Apr 2026 23:21:43 +0000 Subject: [PATCH 2/5] add more ops to substrait compiler --- .../core/compile/substrait/compiler.py | 58 +++--- .../unit/session/test_substrait_executor.py | 181 ++++++++++++++++++ 2 files changed, 213 insertions(+), 26 deletions(-) diff --git a/packages/bigframes/bigframes/core/compile/substrait/compiler.py b/packages/bigframes/bigframes/core/compile/substrait/compiler.py index 755c232d92c4..8c1f19a4bed9 100644 --- a/packages/bigframes/bigframes/core/compile/substrait/compiler.py +++ b/packages/bigframes/bigframes/core/compile/substrait/compiler.py @@ -53,12 +53,15 @@ def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: ("sub", 2), ("mul", 3), ("div", 4), - ("eq", 5), + ("equal", 5), ("ne", 6), ("lt", 7), ("gt", 8), ("le", 9), ("ge", 10), + ("sum", 11), + ("max", 12), + ("and", 13), ] for name, anchor in extensions: ext = pb_plan.extensions.add() @@ -90,8 +93,6 @@ def _compile_node(self, node: bigframe_node.BigFrameNode) -> Dict[str, Any]: return self._compile_selection(node) elif isinstance(node, nodes.FilterNode): return self._compile_filter(node) - elif isinstance(node, nodes.SliceNode): - return self._compile_slice(node) elif isinstance(node, nodes.ProjectionNode): return self._compile_projection(node) elif isinstance(node, nodes.JoinNode): @@ -114,24 +115,28 @@ def _compile_read(self, node: nodes.ReadLocalNode) -> Dict[str, Any]: return json_format.MessageToDict(rel, preserving_proto_field_name=True) def _compile_selection(self, node: nodes.SelectionNode) -> Dict[str, Any]: - # Selection usually maps to ProjectRel or FilterRel depending on if it filters or just selects columns. - # If it's just column selection (Projection), it's a ProjectRel. - # Let's assume it's a ProjectRel for now. input_rel = self._compile_node(node.child) + expressions = [] + child_ids = list(node.child.ids) + for aliased_ref in node.input_output_pairs: + source_id = aliased_ref.ref.id + idx = child_ids.index(source_id) + expressions.append({"selection": {"direct_reference": {"struct_field": {"field": idx}}}}) return { "project": { + "common": { + "emit": { + "outputMapping": [len(child_ids) + i for i in range(len(expressions))] + } + }, "input": input_rel, - "expressions": [ - # Skeletal expression mapping - {"selection": {"direct_reference": {"struct_field": {"field": i}}}} - for i in range(len(node.schema)) - ] + "expressions": expressions } } def _compile_filter(self, node: nodes.FilterNode) -> Dict[str, Any]: input_rel = self._compile_node(node.child) - condition_rel = self._compile_expression(node.condition, node.child) + condition_rel = self._compile_expression(node.predicate, node.child) return { "filter": { "input": input_rel, @@ -139,18 +144,6 @@ def _compile_filter(self, node: nodes.FilterNode) -> Dict[str, Any]: } } - def _compile_slice(self, node: nodes.SliceNode) -> Dict[str, Any]: - input_rel = self._compile_node(node.child) - count = node.stop if node.stop is not None else -1 - offset = node.start if node.start is not None else 0 - - return { - "fetch": { - "input": input_rel, - "offset": offset, - "count": count - } - } def _compile_projection(self, node: nodes.ProjectionNode) -> Dict[str, Any]: input_rel_dict = self._compile_node(node.child) @@ -193,7 +186,7 @@ def _compile_join(self, node: nodes.JoinNode) -> Dict[str, Any]: eq_expressions.append({ "scalar_function": { - "function_reference": 0, + "function_reference": 5, # eq "arguments": [ {"value": {"selection": {"direct_reference": {"struct_field": {"field": left_idx}}}}}, {"value": {"selection": {"direct_reference": {"struct_field": {"field": right_idx}}}}} @@ -203,6 +196,13 @@ def _compile_join(self, node: nodes.JoinNode) -> Dict[str, Any]: if len(eq_expressions) > 1: expr = eq_expressions[0] + for e in eq_expressions[1:]: + expr = { + "scalar_function": { + "function_reference": 13, # and + "arguments": [{"value": expr}, {"value": e}] + } + } elif len(eq_expressions) == 1: expr = eq_expressions[0] else: @@ -229,8 +229,14 @@ def _compile_aggregate(self, node: nodes.AggregateNode) -> Dict[str, Any]: groupings.append({"grouping_expressions": grouping_expressions}) measures = [] + import bigframes.operations.aggregations as agg_ops for agg, _ in node.aggregations: - func_ref = 1 if "Sum" in type(agg).__name__ else 2 + if isinstance(agg.op, agg_ops.SumOp): + func_ref = 11 + elif isinstance(agg.op, agg_ops.MaxOp): + func_ref = 12 + else: + raise NotImplementedError(f"Aggregation {type(agg.op)} not supported in Substrait compiler yet") args = [] if hasattr(agg, "column_references"): for col_id in agg.column_references: diff --git a/packages/bigframes/tests/unit/session/test_substrait_executor.py b/packages/bigframes/tests/unit/session/test_substrait_executor.py index bdd21bed2b79..d5a6a5cd2da7 100644 --- a/packages/bigframes/tests/unit/session/test_substrait_executor.py +++ b/packages/bigframes/tests/unit/session/test_substrait_executor.py @@ -133,3 +133,184 @@ def test_execute_projection_add_with_datafusion(): assert "b" in result_table.column_names assert "a" in result_table.column_names assert result_table.column("b").to_pylist() == [43, 44, 45] + + +def test_execute_filter_with_datafusion(): + from bigframes.session.substrait_executor import DataFusionSubstraitConsumer + from bigframes.operations.comparison_ops import gt_op + + consumer = DataFusionSubstraitConsumer() + executor = substrait_executor.SubstraitExecutor(consumer) + + read_node = create_read_local_node() + + # a > 1 + filter_expr = ex.OpExpression( + op=gt_op, + inputs=( + ex.DerefOp(identifiers.ColumnId("a")), + ex.ScalarConstantExpression(1), + ), + ) + plan = nodes.FilterNode( + child=read_node, + predicate=filter_expr, + ) + + result = executor.execute(plan, ordered=True) + assert result is not None + + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 2 + assert "a" in result_table.column_names + assert result_table.column("a").to_pylist() == [2, 3] + + +def test_execute_aggregate_sum_with_datafusion(): + from bigframes.session.substrait_executor import DataFusionSubstraitConsumer + from bigframes.operations.aggregations import sum_op + from bigframes.core.agg_expressions import UnaryAggregation + + consumer = DataFusionSubstraitConsumer() + executor = substrait_executor.SubstraitExecutor(consumer) + + read_node = create_read_local_node() + + # sum(a) + sum_agg = UnaryAggregation( + op=sum_op, + arg=ex.DerefOp(identifiers.ColumnId("a")), + ) + + plan = nodes.AggregateNode( + child=read_node, + aggregations=((sum_agg, identifiers.ColumnId("sum_a")),), + by_column_ids=(), + ) + + result = executor.execute(plan, ordered=True) + assert result is not None + + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 1 + assert "sum_a" in result_table.column_names + assert result_table.column("sum_a").to_pylist() == [6] + + +def test_execute_aggregate_max_with_datafusion(): + from bigframes.session.substrait_executor import DataFusionSubstraitConsumer + from bigframes.operations.aggregations import max_op + from bigframes.core.agg_expressions import UnaryAggregation + + consumer = DataFusionSubstraitConsumer() + executor = substrait_executor.SubstraitExecutor(consumer) + + read_node = create_read_local_node() + + # max(a) + max_agg = UnaryAggregation( + op=max_op, + arg=ex.DerefOp(identifiers.ColumnId("a")), + ) + + plan = nodes.AggregateNode( + child=read_node, + aggregations=((max_agg, identifiers.ColumnId("max_a")),), + by_column_ids=(), + ) + + result = executor.execute(plan, ordered=True) + assert result is not None + + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 1 + assert "max_a" in result_table.column_names + assert result_table.column("max_a").to_pylist() == [3] + + +def test_execute_join_with_datafusion(): + from bigframes.session.substrait_executor import DataFusionSubstraitConsumer + + consumer = DataFusionSubstraitConsumer() + executor = substrait_executor.SubstraitExecutor(consumer) + + # Table 1: a + session1 = mocks.create_bigquery_session() + table1 = pa.Table.from_pydict({"a": [1, 2, 3]}) + source1 = local_data.ManagedArrowTable.from_pyarrow(table1) + col_id_a = identifiers.ColumnId("a") + read_node1 = nodes.ReadLocalNode( + local_data_source=source1, + session=session1, + scan_list=nodes.ScanList(items=(nodes.ScanItem(id=col_id_a, source_id="a"),)), + ) + + # Table 2: b + session2 = mocks.create_bigquery_session() + table2 = pa.Table.from_pydict({"b": [2, 3, 4]}) + source2 = local_data.ManagedArrowTable.from_pyarrow(table2) + col_id_b = identifiers.ColumnId("b") + read_node2 = nodes.ReadLocalNode( + local_data_source=source2, + session=session2, + scan_list=nodes.ScanList(items=(nodes.ScanItem(id=col_id_b, source_id="b"),)), + ) + + # Join on a = b + join_node = nodes.JoinNode( + left_child=read_node1, + right_child=read_node2, + conditions=((ex.DerefOp(col_id_a), ex.DerefOp(col_id_b)),), + type="inner", + propogate_order=False, + ) + + result = executor.execute(join_node, ordered=True) + assert result is not None + + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 2 + assert "a" in result_table.column_names + assert "b" in result_table.column_names + assert result_table.column("a").to_pylist() == [2, 3] + assert result_table.column("b").to_pylist() == [2, 3] + + +def test_execute_selection_with_datafusion(): + from bigframes.session.substrait_executor import DataFusionSubstraitConsumer + from bigframes.core.nodes import AliasedRef + + consumer = DataFusionSubstraitConsumer() + executor = substrait_executor.SubstraitExecutor(consumer) + + # Table with a and b + session = mocks.create_bigquery_session() + table = pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + source = local_data.ManagedArrowTable.from_pyarrow(table) + col_id_a = identifiers.ColumnId("a") + col_id_b = identifiers.ColumnId("b") + read_node = nodes.ReadLocalNode( + local_data_source=source, + session=session, + scan_list=nodes.ScanList( + items=( + nodes.ScanItem(id=col_id_a, source_id="a"), + nodes.ScanItem(id=col_id_b, source_id="b"), + ) + ), + ) + + # Select only a, and rename it to c + col_id_c = identifiers.ColumnId("c") + selection_node = nodes.SelectionNode( + child=read_node, + input_output_pairs=(AliasedRef(ex.DerefOp(col_id_a), col_id_c),), + ) + + result = executor.execute(selection_node, ordered=True) + assert result is not None + + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 3 + assert result_table.column_names == ["c"] + assert result_table.column("c").to_pylist() == [1, 2, 3] From daed87efd59241d4cc492461fc6c36778d406970 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Mon, 6 Apr 2026 20:15:07 +0000 Subject: [PATCH 3/5] more ops, types --- .../core/compile/substrait/compiler.py | 378 ++++++++++-------- .../bigframes/session/substrait_executor.py | 18 +- .../unit/session/test_substrait_executor.py | 34 ++ 3 files changed, 250 insertions(+), 180 deletions(-) diff --git a/packages/bigframes/bigframes/core/compile/substrait/compiler.py b/packages/bigframes/bigframes/core/compile/substrait/compiler.py index 8c1f19a4bed9..04adca253cee 100644 --- a/packages/bigframes/bigframes/core/compile/substrait/compiler.py +++ b/packages/bigframes/bigframes/core/compile/substrait/compiler.py @@ -14,16 +14,20 @@ from __future__ import annotations +from functools import singledispatchmethod import json -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Sequence import substrait.algebra_pb2 as algebra_pb2 import substrait.plan_pb2 as plan_pb2 +import substrait.type_pb2 as type_pb2 from google.protobuf import json_format from bigframes.core import bigframe_node, nodes import bigframes.core.expression as ex import pandas as pd +import bigframes.operations.numeric_ops as numeric_ops +import bigframes.operations.comparison_ops as comparison_ops class SubstraitCompiler: @@ -38,37 +42,22 @@ def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: if not self.can_compile(plan): return None - plan_dict = self._compile_node(plan) + pb_rel = self._compile_node(plan) pb_plan = plan_pb2.Plan() pb_plan.version.minor_number = 42 plan_rel = pb_plan.relations.add() - json_format.ParseDict(plan_dict, plan_rel.root.input) + plan_rel.root.input.CopyFrom(pb_rel) plan_rel.root.names.extend([item.column for item in plan.schema.items]) - extensions = [ - ("add", 1), - ("sub", 2), - ("mul", 3), - ("div", 4), - ("equal", 5), - ("ne", 6), - ("lt", 7), - ("gt", 8), - ("le", 9), - ("ge", 10), - ("sum", 11), - ("max", 12), - ("and", 13), - ] - for name, anchor in extensions: + for name, anchor in self._EXTENSIONS.items(): ext = pb_plan.extensions.add() ext.extension_function.function_anchor = anchor ext.extension_function.name = name - return json_format.MessageToJson(pb_plan).encode('utf-8') + return pb_plan.SerializeToString() def can_compile(self, plan: bigframe_node.BigFrameNode) -> bool: """ @@ -84,25 +73,25 @@ def can_compile(self, plan: bigframe_node.BigFrameNode) -> bool: nodes.JoinNode, nodes.AggregateNode, ) - return all(isinstance(node, supported_nodes) for node in plan.unique_nodes()) + return isinstance(plan, supported_nodes) - def _compile_node(self, node: bigframe_node.BigFrameNode) -> Dict[str, Any]: + def _compile_node(self, node: bigframe_node.BigFrameNode) -> algebra_pb2.Rel: if isinstance(node, nodes.ReadLocalNode): - return self._compile_read(node) + return self._compile_read(node) elif isinstance(node, nodes.SelectionNode): - return self._compile_selection(node) + return self._compile_selection(node) elif isinstance(node, nodes.FilterNode): - return self._compile_filter(node) + return self._compile_filter(node) elif isinstance(node, nodes.ProjectionNode): - return self._compile_projection(node) + return self._compile_projection(node) elif isinstance(node, nodes.JoinNode): - return self._compile_join(node) + return self._compile_join(node) elif isinstance(node, nodes.AggregateNode): - return self._compile_aggregate(node) + return self._compile_aggregate(node) else: - raise NotImplementedError(f"Node type {type(node)} not supported in Substrait compiler yet") + raise NotImplementedError(f"Node type {type(node)} not supported in Substrait compiler yet") - def _compile_read(self, node: nodes.ReadLocalNode) -> Dict[str, Any]: + def _compile_read(self, node: nodes.ReadLocalNode) -> algebra_pb2.Rel: table_name = f"table_{node.local_data_source.id.hex}" rel = algebra_pb2.Rel() @@ -112,70 +101,71 @@ def _compile_read(self, node: nodes.ReadLocalNode) -> Dict[str, Any]: schema_dict = self._convert_schema(node.local_data_source.schema) json_format.ParseDict(schema_dict, read_rel.base_schema) - return json_format.MessageToDict(rel, preserving_proto_field_name=True) + return rel - def _compile_selection(self, node: nodes.SelectionNode) -> Dict[str, Any]: + def _compile_selection(self, node: nodes.SelectionNode) -> algebra_pb2.Rel: input_rel = self._compile_node(node.child) - expressions = [] + + rel = algebra_pb2.Rel() + project_rel = rel.project + project_rel.input.CopyFrom(input_rel) + child_ids = list(node.child.ids) + num_exprs = 0 for aliased_ref in node.input_output_pairs: source_id = aliased_ref.ref.id idx = child_ids.index(source_id) - expressions.append({"selection": {"direct_reference": {"struct_field": {"field": idx}}}}) - return { - "project": { - "common": { - "emit": { - "outputMapping": [len(child_ids) + i for i in range(len(expressions))] - } - }, - "input": input_rel, - "expressions": expressions - } - } + expr = project_rel.expressions.add() + expr.selection.direct_reference.struct_field.field = idx + num_exprs += 1 + + project_rel.common.emit.output_mapping.extend([len(child_ids) + i for i in range(num_exprs)]) + + return rel - def _compile_filter(self, node: nodes.FilterNode) -> Dict[str, Any]: + def _compile_filter(self, node: nodes.FilterNode) -> algebra_pb2.Rel: input_rel = self._compile_node(node.child) - condition_rel = self._compile_expression(node.predicate, node.child) - return { - "filter": { - "input": input_rel, - "condition": condition_rel - } - } + + rel = algebra_pb2.Rel() + filter_rel = rel.filter + filter_rel.input.CopyFrom(input_rel) + + condition_expr = self._compile_expression(node.predicate, node.child) + filter_rel.condition.CopyFrom(condition_expr) + + return rel - def _compile_projection(self, node: nodes.ProjectionNode) -> Dict[str, Any]: - input_rel_dict = self._compile_node(node.child) + def _compile_projection(self, node: nodes.ProjectionNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) rel = algebra_pb2.Rel() project_rel = rel.project + project_rel.input.CopyFrom(input_rel) - json_format.ParseDict(input_rel_dict, project_rel.input) - - # DataFusion ProjectRel seems to be additive (appends to input). - # So we don't need to add passthrough expressions for input fields. - - # Add new assignments for expr, _ in node.assignments: - expr_dict = self._compile_expression(expr, node.child) - expr_pb = project_rel.expressions.add() - json_format.ParseDict(expr_dict, expr_pb) + expr_pb = self._compile_expression(expr, node.child) + project_rel.expressions.add().CopyFrom(expr_pb) - return json_format.MessageToDict(rel, preserving_proto_field_name=True) + return rel - def _compile_join(self, node: nodes.JoinNode) -> Dict[str, Any]: + def _compile_join(self, node: nodes.JoinNode) -> algebra_pb2.Rel: left_rel = self._compile_node(node.left_child) right_rel = self._compile_node(node.right_child) + rel = algebra_pb2.Rel() + join_rel = rel.join + + join_rel.left.CopyFrom(left_rel) + join_rel.right.CopyFrom(right_rel) + type_map = { - "inner": "JOIN_TYPE_INNER", - "left": "JOIN_TYPE_LEFT", - "right": "JOIN_TYPE_RIGHT", - "outer": "JOIN_TYPE_OUTER", - "cross": "JOIN_TYPE_CROSS", + "inner": algebra_pb2.JoinRel.JOIN_TYPE_INNER, + "left": algebra_pb2.JoinRel.JOIN_TYPE_LEFT, + "right": algebra_pb2.JoinRel.JOIN_TYPE_RIGHT, + "outer": algebra_pb2.JoinRel.JOIN_TYPE_OUTER, } - join_type = type_map.get(node.type, "JOIN_TYPE_UNSPECIFIED") + join_rel.type = type_map.get(node.type, algebra_pb2.JoinRel.JOIN_TYPE_UNSPECIFIED) left_len = len(node.left_child.schema) @@ -184,51 +174,49 @@ def _compile_join(self, node: nodes.JoinNode) -> Dict[str, Any]: left_idx = list(node.left_child.ids).index(left_deref.id) right_idx = list(node.right_child.ids).index(right_deref.id) + left_len - eq_expressions.append({ - "scalar_function": { - "function_reference": 5, # eq - "arguments": [ - {"value": {"selection": {"direct_reference": {"struct_field": {"field": left_idx}}}}}, - {"value": {"selection": {"direct_reference": {"struct_field": {"field": right_idx}}}}} - ] - } - }) + eq_expr = algebra_pb2.Expression() + eq_expr.scalar_function.function_reference = 5 # equal + + arg1 = eq_expr.scalar_function.arguments.add() + arg1.value.selection.direct_reference.struct_field.field = left_idx + + arg2 = eq_expr.scalar_function.arguments.add() + arg2.value.selection.direct_reference.struct_field.field = right_idx + + eq_expressions.append(eq_expr) if len(eq_expressions) > 1: expr = eq_expressions[0] for e in eq_expressions[1:]: - expr = { - "scalar_function": { - "function_reference": 13, # and - "arguments": [{"value": expr}, {"value": e}] - } - } + and_expr = algebra_pb2.Expression() + and_expr.scalar_function.function_reference = 13 # and + and_expr.scalar_function.arguments.add().value.CopyFrom(expr) + and_expr.scalar_function.arguments.add().value.CopyFrom(e) + expr = and_expr elif len(eq_expressions) == 1: expr = eq_expressions[0] else: - expr = {"literal": {"boolean": True}} + expr = algebra_pb2.Expression() + expr.literal.boolean = True - return { - "join": { - "left": left_rel, - "right": right_rel, - "expression": expr, - "type": join_type - } - } + join_rel.expression.CopyFrom(expr) + + return rel - def _compile_aggregate(self, node: nodes.AggregateNode) -> Dict[str, Any]: + def _compile_aggregate(self, node: nodes.AggregateNode) -> algebra_pb2.Rel: input_rel = self._compile_node(node.child) - groupings = [] - grouping_expressions = [] - for deref in node.by_column_ids: - idx = list(node.child.ids).index(deref.id) - grouping_expressions.append({"selection": {"direct_reference": {"struct_field": {"field": idx}}}}) - if grouping_expressions: - groupings.append({"grouping_expressions": grouping_expressions}) - - measures = [] + rel = algebra_pb2.Rel() + agg_rel = rel.aggregate + agg_rel.input.CopyFrom(input_rel) + + if node.by_column_ids: + grouping = agg_rel.groupings.add() + for deref in node.by_column_ids: + idx = list(node.child.ids).index(deref.id) + expr = grouping.grouping_expressions.add() + expr.selection.direct_reference.struct_field.field = idx + import bigframes.operations.aggregations as agg_ops for agg, _ in node.aggregations: if isinstance(agg.op, agg_ops.SumOp): @@ -237,78 +225,110 @@ def _compile_aggregate(self, node: nodes.AggregateNode) -> Dict[str, Any]: func_ref = 12 else: raise NotImplementedError(f"Aggregation {type(agg.op)} not supported in Substrait compiler yet") - args = [] + + measure = agg_rel.measures.add() + measure.measure.function_reference = func_ref + if hasattr(agg, "column_references"): for col_id in agg.column_references: try: idx = list(node.child.ids).index(col_id) - args.append({"value": {"selection": {"direct_reference": {"struct_field": {"field": idx}}}}}) + arg = measure.measure.arguments.add() + arg.value.selection.direct_reference.struct_field.field = idx except ValueError: pass - measures.append({ - "measure": { - "function_reference": func_ref, - "arguments": args - } - }) - - return { - "aggregate": { - "input": input_rel, - "groupings": groupings, - "measures": measures - } - } + + return rel - def _compile_expression(self, expr: ex.Expression, child: nodes.BigFrameNode) -> Dict[str, Any]: - if isinstance(expr, ex.ScalarConstantExpression): - val = expr.value - if isinstance(val, int): - return {"literal": {"i64": val}} - elif isinstance(val, float): - return {"literal": {"fp64": val}} - elif isinstance(val, str): - return {"literal": {"string": val}} - elif pd.isna(val): - return {"literal": {"null": {"varchar": {"length": 0}}}} - else: - return {"literal": {"string": str(val)}} - - elif isinstance(expr, ex.DerefOp): - try: - # print(f"DerefOp: id={expr.id}, child.ids={list(child.ids)}") # Debug - idx = list(child.ids).index(expr.id) - return {"selection": {"direct_reference": {"struct_field": {"field": idx}}}} - except ValueError: - raise ValueError(f"Column {expr.id} not found in child schema") - - elif isinstance(expr, ex.OpExpression): - op_name = expr.op.name - op_mapping = { - "add": 1, - "sub": 2, - "mul": 3, - "div": 4, - "eq": 5, - "ne": 6, - "lt": 7, - "gt": 8, - "le": 9, - "ge": 10, - } - if op_name not in op_mapping: - raise NotImplementedError(f"Operation {op_name} not supported in Substrait compiler yet") - func_ref = op_mapping[op_name] - - args = [self._compile_expression(arg, child) for arg in expr.inputs] - return { - "scalar_function": { - "function_reference": func_ref, - "arguments": [{"value": arg} for arg in args] - } - } - else: - raise NotImplementedError(f"Expression type {type(expr)} not supported in Substrait compiler yet") + _EXTENSIONS = { + "add": 1, + "subtract": 2, + "multiply": 3, + "divide": 4, + "equal": 5, + "ne": 6, + "lt": 7, + "gt": 8, + "lte": 9, + "gte": 10, + "sum": 11, + "max": 12, + "and": 13, + } + + _OP_TO_EXTENSION = { + numeric_ops.AddOp: "add", + numeric_ops.SubOp: "subtract", + numeric_ops.MulOp: "multiply", + numeric_ops.DivOp: "divide", + comparison_ops.EqOp: "equal", + comparison_ops.NeOp: "ne", + comparison_ops.LtOp: "lt", + comparison_ops.GtOp: "gt", + comparison_ops.LeOp: "lte", + comparison_ops.GeOp: "gte", + } + + @singledispatchmethod + def _compile_expression(self, expr: ex.Expression, child: nodes.BigFrameNode) -> algebra_pb2.Expression: + raise NotImplementedError(f"Expression type {type(expr)} not supported in Substrait compiler yet") + + @_compile_expression.register + def _compile_scalar_constant(self, expr: ex.ScalarConstantExpression, child: nodes.BigFrameNode) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + val = expr.value + if isinstance(val, int): + pb_expr.literal.i64 = val + elif isinstance(val, float): + pb_expr.literal.fp64 = val + elif isinstance(val, str): + pb_expr.literal.string = val + elif pd.isna(val): + pb_expr.literal.null.varchar.length = 0 + else: + pb_expr.literal.string = str(val) + return pb_expr + + @_compile_expression.register + def _compile_deref(self, expr: ex.DerefOp, child: nodes.BigFrameNode) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + try: + idx = list(child.ids).index(expr.id) + pb_expr.selection.direct_reference.struct_field.field = idx + return pb_expr + except ValueError: + raise ValueError(f"Column {expr.id} not found in child schema") + + @_compile_expression.register + def _compile_op_expr(self, expr: ex.OpExpression, child: nodes.BigFrameNode) -> algebra_pb2.Expression: + return self._compile_op(expr.op, expr.inputs, child) + + @singledispatchmethod + def _compile_op(self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + raise NotImplementedError(f"Op type {type(op)} not supported in Substrait compiler yet") + + @_compile_op.register(numeric_ops.AddOp) + @_compile_op.register(numeric_ops.SubOp) + @_compile_op.register(numeric_ops.MulOp) + @_compile_op.register(numeric_ops.DivOp) + @_compile_op.register(comparison_ops.EqOp) + @_compile_op.register(comparison_ops.NeOp) + @_compile_op.register(comparison_ops.LtOp) + @_compile_op.register(comparison_ops.GtOp) + @_compile_op.register(comparison_ops.LeOp) + @_compile_op.register(comparison_ops.GeOp) + def _compile_basic_binops(self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + op_class = type(op) + ext_name = self._OP_TO_EXTENSION[op_class] + return self._compile_basic_binop(ext_name, inputs, child) + + def _compile_basic_binop(self, ext_name: str, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS[ext_name] + for arg in inputs: + arg_expr = self._compile_expression(arg, child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return pb_expr def _convert_schema(self, schema: Any) -> Dict[str, Any]: # Convert bigframes schema to Substrait Type.NamedStruct @@ -335,6 +355,22 @@ def _convert_type(self, dtype: Any) -> Dict[str, Any]: return {"bool": {}} elif dtype == bigframes.dtypes.STRING_DTYPE: return {"string": {}} + elif dtype == bigframes.dtypes.BYTES_DTYPE: + return {"binary": {}} + elif dtype == bigframes.dtypes.DATE_DTYPE: + return {"date": {}} + elif dtype == bigframes.dtypes.DATETIME_DTYPE: + return {"precision_timestamp": {"precision": 6}} + elif dtype == bigframes.dtypes.TIMESTAMP_DTYPE: + return {"precision_timestamp_tz": {"precision": 6}} + elif dtype == bigframes.dtypes.TIME_DTYPE: + # type_variation_reference 1 is for time64, precision 6 is for microseconds + return {"precision_time": {"precision": 6, "type_variation_reference": 1}} + elif dtype in (bigframes.dtypes.NUMERIC_DTYPE, bigframes.dtypes.BIGNUMERIC_DTYPE): + arrow_dtype = dtype.pyarrow_dtype + return {"decimal": {"precision": arrow_dtype.precision, "scale": arrow_dtype.scale}} + elif dtype == bigframes.dtypes.TIMEDELTA_DTYPE: + return {"interval_day": {"precision": 6, "type_variation_reference": 1}} else: # Fallback to string for now return {"string": {}} diff --git a/packages/bigframes/bigframes/session/substrait_executor.py b/packages/bigframes/bigframes/session/substrait_executor.py index 243c51d8c634..1363255c9db9 100644 --- a/packages/bigframes/bigframes/session/substrait_executor.py +++ b/packages/bigframes/bigframes/session/substrait_executor.py @@ -52,7 +52,7 @@ class DataFusionSubstraitConsumer(SubstraitConsumer): Executes Substrait plans using Apache DataFusion. """ - def consume(self, plan: bytes, tables: dict[str, pa.Table]) -> pa.Table: + def consume(self, plan_proto: bytes, tables: dict[str, pa.Table]) -> pa.Table: # Import datafusion lazily to avoid hard dependency try: import datafusion @@ -76,11 +76,11 @@ def consume(self, plan: bytes, tables: dict[str, pa.Table]) -> pa.Table: import datafusion.substrait - json_str = plan.decode('utf-8') - plan_obj = datafusion.substrait.Plan.from_json(json_str) - print("DEBUG RE-SERIALIZED JSON SUBSTRAIT PLAN:") - print(plan_obj.to_json()) - logical_plan = datafusion.substrait.Consumer.from_substrait_plan(ctx, plan_obj) + #json_str = plan.decode('utf-8') + #print("DEBUG RE-SERIALIZED JSON SUBSTRAIT PLAN:") + #print(plan_obj.to_json()) + datafusion_substrait_plan = datafusion.substrait.Serde.deserialize_bytes(plan_proto) + logical_plan = datafusion.substrait.Consumer.from_substrait_plan(ctx, datafusion_substrait_plan) df = ctx.create_dataframe_from_logical_plan(logical_plan) return df.to_arrow_table() @@ -107,9 +107,9 @@ def execute( if not self._can_execute(rewritten_plan): return None - substrait_plan = self._compiler.compile(rewritten_plan) + substrait_plan_proto = self._compiler.compile(rewritten_plan) - if substrait_plan is None: + if substrait_plan_proto is None: return None tables = {} @@ -118,7 +118,7 @@ def execute( table_name = f"table_{node.local_data_source.id.hex}" tables[table_name] = node.local_data_source.data - pa_table = self._consumer.consume(substrait_plan, tables) + pa_table = self._consumer.consume(substrait_plan_proto, tables) if peek is not None: pa_table = pa_table.slice(0, peek) diff --git a/packages/bigframes/tests/unit/session/test_substrait_executor.py b/packages/bigframes/tests/unit/session/test_substrait_executor.py index d5a6a5cd2da7..aecc22a3fb6b 100644 --- a/packages/bigframes/tests/unit/session/test_substrait_executor.py +++ b/packages/bigframes/tests/unit/session/test_substrait_executor.py @@ -314,3 +314,37 @@ def test_execute_selection_with_datafusion(): assert result_table.num_rows == 3 assert result_table.column_names == ["c"] assert result_table.column("c").to_pylist() == [1, 2, 3] + + +def test_execute_various_types_with_datafusion(): + from bigframes.session.substrait_executor import DataFusionSubstraitConsumer + import datetime + import pandas as pd + + consumer = DataFusionSubstraitConsumer() + executor = substrait_executor.SubstraitExecutor(consumer) + + session = mocks.create_bigquery_session() + table = pa.Table.from_pydict({ + "bin": [b"a", b"b"], + "dat": [datetime.date(2023, 1, 1), datetime.date(2023, 1, 2)], + "dt": [datetime.datetime(2023, 1, 1, 12, 0), datetime.datetime(2023, 1, 2, 12, 0)], + }) + source = local_data.ManagedArrowTable.from_pyarrow(table) + + scan_items = [] + for name in table.column_names: + scan_items.append(nodes.ScanItem(id=identifiers.ColumnId(name), source_id=name)) + + read_node = nodes.ReadLocalNode( + local_data_source=source, + session=session, + scan_list=nodes.ScanList(items=tuple(scan_items)), + ) + + result = executor.execute(read_node, ordered=True) + assert result is not None + + result_table = pa.Table.from_batches(result.batches().arrow_batches) + assert result_table.num_rows == 2 + assert result_table.column_names == ["bin", "dat", "dt"] From a26aeba7b7a948a596302e53ed4ce6becccd5001 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Tue, 19 May 2026 22:45:53 +0000 Subject: [PATCH 4/5] support more ops to substrait --- .../core/compile/substrait/compiler.py | 1188 ++++++++++++++++- .../compile/substrait/test_multiindex_drop.py | 28 + .../bigframes/session/substrait_executor.py | 152 ++- .../tests/unit/test_dataframe_substrait.py | 31 +- 4 files changed, 1368 insertions(+), 31 deletions(-) create mode 100644 packages/bigframes/bigframes/core/compile/substrait/test_multiindex_drop.py diff --git a/packages/bigframes/bigframes/core/compile/substrait/compiler.py b/packages/bigframes/bigframes/core/compile/substrait/compiler.py index 04adca253cee..46a266be2011 100644 --- a/packages/bigframes/bigframes/core/compile/substrait/compiler.py +++ b/packages/bigframes/bigframes/core/compile/substrait/compiler.py @@ -25,9 +25,14 @@ from bigframes.core import bigframe_node, nodes import bigframes.core.expression as ex +import bigframes.dtypes as dtypes import pandas as pd +import bigframes.operations as ops import bigframes.operations.numeric_ops as numeric_ops import bigframes.operations.comparison_ops as comparison_ops +import bigframes.operations.generic_ops as generic_ops +import bigframes.operations.bool_ops as bool_ops +import bigframes.operations.struct_ops as struct_ops class SubstraitCompiler: @@ -35,6 +40,17 @@ class SubstraitCompiler: Compiles BigFrameNode plans to Substrait schema (JSON representation). """ + def _print_node_tree(self, node: bigframe_node.BigFrameNode, indent: int = 0): + import sys + try: + ids = list(node.ids) + except Exception as e: + ids = f"" + sys.stderr.write(" " * indent + f"- {type(node).__name__}: ids={ids}\n") + sys.stderr.flush() + for child in node.child_nodes: + self._print_node_tree(child, indent + 1) + def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: """ Compiles a BigFrameNode to Substrait bytes (JSON encoded via protobuf). @@ -42,6 +58,11 @@ def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: if not self.can_compile(plan): return None + import sys + sys.stderr.write("DEBUG TREE:\n") + sys.stderr.flush() + self._print_node_tree(plan) + pb_rel = self._compile_node(plan) pb_plan = plan_pb2.Plan() @@ -50,7 +71,8 @@ def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: plan_rel = pb_plan.relations.add() plan_rel.root.input.CopyFrom(pb_rel) - plan_rel.root.names.extend([item.column for item in plan.schema.items]) + for item in plan.schema.items: + plan_rel.root.names.extend(self._get_substrait_names(item.column, item.dtype)) for name, anchor in self._EXTENSIONS.items(): ext = pb_plan.extensions.add() @@ -72,8 +94,17 @@ def can_compile(self, plan: bigframe_node.BigFrameNode) -> bool: nodes.ProjectionNode, nodes.JoinNode, nodes.AggregateNode, + nodes.OrderByNode, + nodes.PromoteOffsetsNode, + nodes.WindowOpNode, + nodes.ConcatNode, ) - return isinstance(plan, supported_nodes) + import sys + for n in plan.unique_nodes(): + if not isinstance(n, supported_nodes): + sys.stderr.write(f"UNSUPPORTED NODE TYPE: {type(n).__name__} -> {n}\n") + sys.stderr.flush() + return all(isinstance(n, supported_nodes) for n in plan.unique_nodes()) def _compile_node(self, node: bigframe_node.BigFrameNode) -> algebra_pb2.Rel: if isinstance(node, nodes.ReadLocalNode): @@ -88,17 +119,43 @@ def _compile_node(self, node: bigframe_node.BigFrameNode) -> algebra_pb2.Rel: return self._compile_join(node) elif isinstance(node, nodes.AggregateNode): return self._compile_aggregate(node) + elif isinstance(node, nodes.OrderByNode): + return self._compile_orderby(node) + elif isinstance(node, nodes.SliceNode): + return self._compile_slice(node) + elif isinstance(node, nodes.PromoteOffsetsNode): + return self._compile_promote_offsets(node) + elif isinstance(node, nodes.WindowOpNode): + return self._compile_window(node) + elif isinstance(node, nodes.ConcatNode): + return self._compile_concat(node) else: raise NotImplementedError(f"Node type {type(node)} not supported in Substrait compiler yet") def _compile_read(self, node: nodes.ReadLocalNode) -> algebra_pb2.Rel: - table_name = f"table_{node.local_data_source.id.hex}" + table_name = f"table_{id(node)}" rel = algebra_pb2.Rel() read_rel = rel.read read_rel.named_table.names.append(table_name) - schema_dict = self._convert_schema(node.local_data_source.schema) + import bigframes.dtypes as dtypes + fields = [] + types = [] + for item in node.scan_list.items: + col_dtype = node.local_data_source.schema.get_type(item.source_id) + fields.extend(self._get_substrait_names(item.id.sql, col_dtype)) + types.append(self._convert_type(col_dtype)) + + if node.offsets_col is not None: + fields.append(node.offsets_col.sql) + types.append(self._convert_type(dtypes.INT_DTYPE)) + + schema_dict = { + "names": fields, + "struct": {"types": types} + } + print("SCHEMA_DICT:", schema_dict) json_format.ParseDict(schema_dict, read_rel.base_schema) return rel @@ -123,6 +180,22 @@ def _compile_selection(self, node: nodes.SelectionNode) -> algebra_pb2.Rel: return rel + def _compile_promote_offsets(self, node: nodes.PromoteOffsetsNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + project_rel = rel.project + project_rel.input.CopyFrom(input_rel) + + # Add a dummy literal i64 = 0 for the offsets column + expr = project_rel.expressions.add() + expr.literal.i64 = 0 + + child_ids = list(node.child.ids) + project_rel.common.emit.output_mapping.extend(range(len(child_ids) + 1)) + + return rel + def _compile_filter(self, node: nodes.FilterNode) -> algebra_pb2.Rel: input_rel = self._compile_node(node.child) @@ -147,6 +220,10 @@ def _compile_projection(self, node: nodes.ProjectionNode) -> algebra_pb2.Rel: expr_pb = self._compile_expression(expr, node.child) project_rel.expressions.add().CopyFrom(expr_pb) + child_ids = list(node.child.ids) + num_exprs = len(node.assignments) + project_rel.common.emit.output_mapping.extend(range(len(child_ids) + num_exprs)) + return rel def _compile_join(self, node: nodes.JoinNode) -> algebra_pb2.Rel: @@ -154,8 +231,13 @@ def _compile_join(self, node: nodes.JoinNode) -> algebra_pb2.Rel: right_rel = self._compile_node(node.right_child) rel = algebra_pb2.Rel() + if node.type == "cross": + cross_rel = rel.cross + cross_rel.left.CopyFrom(left_rel) + cross_rel.right.CopyFrom(right_rel) + return rel + join_rel = rel.join - join_rel.left.CopyFrom(left_rel) join_rel.right.CopyFrom(right_rel) @@ -169,21 +251,45 @@ def _compile_join(self, node: nodes.JoinNode) -> algebra_pb2.Rel: left_len = len(node.left_child.schema) + import sys + sys.stderr.write(f"JOIN CONDITIONS: {node.conditions}\n") + sys.stderr.flush() + eq_expressions = [] for left_deref, right_deref in node.conditions: left_idx = list(node.left_child.ids).index(left_deref.id) right_idx = list(node.right_child.ids).index(right_deref.id) + left_len + arg1 = algebra_pb2.Expression() + arg1.selection.direct_reference.struct_field.field = left_idx + + arg2 = algebra_pb2.Expression() + arg2.selection.direct_reference.struct_field.field = right_idx + eq_expr = algebra_pb2.Expression() - eq_expr.scalar_function.function_reference = 5 # equal + eq_expr.scalar_function.function_reference = self._EXTENSIONS["equal"] + eq_expr.scalar_function.arguments.add().value.CopyFrom(arg1) + eq_expr.scalar_function.arguments.add().value.CopyFrom(arg2) + + isnull1_expr = algebra_pb2.Expression() + isnull1_expr.scalar_function.function_reference = self._EXTENSIONS["is_null"] + isnull1_expr.scalar_function.arguments.add().value.CopyFrom(arg1) + + isnull2_expr = algebra_pb2.Expression() + isnull2_expr.scalar_function.function_reference = self._EXTENSIONS["is_null"] + isnull2_expr.scalar_function.arguments.add().value.CopyFrom(arg2) - arg1 = eq_expr.scalar_function.arguments.add() - arg1.value.selection.direct_reference.struct_field.field = left_idx + both_null_expr = algebra_pb2.Expression() + both_null_expr.scalar_function.function_reference = self._EXTENSIONS["and"] + both_null_expr.scalar_function.arguments.add().value.CopyFrom(isnull1_expr) + both_null_expr.scalar_function.arguments.add().value.CopyFrom(isnull2_expr) - arg2 = eq_expr.scalar_function.arguments.add() - arg2.value.selection.direct_reference.struct_field.field = right_idx + null_safe_eq = algebra_pb2.Expression() + null_safe_eq.scalar_function.function_reference = self._EXTENSIONS["or"] + null_safe_eq.scalar_function.arguments.add().value.CopyFrom(eq_expr) + null_safe_eq.scalar_function.arguments.add().value.CopyFrom(both_null_expr) - eq_expressions.append(eq_expr) + eq_expressions.append(null_safe_eq) if len(eq_expressions) > 1: expr = eq_expressions[0] @@ -203,6 +309,331 @@ def _compile_join(self, node: nodes.JoinNode) -> algebra_pb2.Rel: return rel + def _compile_window(self, node: nodes.WindowOpNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + # ProjectRel 1: Evaluate all window functions (standard ones + lag/lead for DiffOp) + rel1 = algebra_pb2.Rel() + project_rel1 = rel1.project + project_rel1.input.CopyFrom(input_rel) + + import bigframes.dtypes as dtypes + import bigframes.operations.aggregations as agg_ops + import bigframes.core.window_spec as window_spec_module + + has_diff = False + has_struct_fill = False + for cdef in node.agg_exprs: + op = cdef.expression.op + if isinstance(op, agg_ops.DiffOp): + has_diff = True + if isinstance(op, (agg_ops.LastNonNullOp, agg_ops.FirstNonNullOp)) and len(node.window_spec.ordering) > 0: + has_struct_fill = True + + expr = project_rel1.expressions.add() + wf = expr.window_function + + if isinstance(op, agg_ops.DiffOp): + periods = op.periods + if periods >= 0: + func_name = "lag" + else: + func_name = "lead" + periods = -periods + elif isinstance(op, agg_ops.SumOp): + func_name = "sum" + elif isinstance(op, agg_ops.MaxOp): + func_name = "max" + elif isinstance(op, agg_ops.MinOp): + func_name = "min" + elif isinstance(op, agg_ops.MeanOp): + func_name = "mean" + elif isinstance(op, agg_ops.CountOp): + func_name = "count" + elif isinstance(op, (agg_ops.SizeOp, agg_ops.SizeUnaryOp)): + func_name = "count" + elif isinstance(op, agg_ops.StdOp): + func_name = "stddev" + elif isinstance(op, agg_ops.VarOp): + func_name = "var" + elif isinstance(op, agg_ops.PopVarOp): + func_name = "var_pop" + elif isinstance(op, agg_ops.RowNumberOp): + func_name = "row_number" + elif isinstance(op, agg_ops.RankOp): + func_name = "rank" + elif isinstance(op, agg_ops.DenseRankOp): + func_name = "dense_rank" + elif isinstance(op, agg_ops.FirstOp): + func_name = "first_value" + elif isinstance(op, agg_ops.FirstNonNullOp): + if len(node.window_spec.ordering) > 0: + func_name = "min" + else: + func_name = "first_value" + elif isinstance(op, agg_ops.LastOp): + func_name = "last_value" + elif isinstance(op, agg_ops.LastNonNullOp): + if len(node.window_spec.ordering) > 0: + func_name = "max" + else: + func_name = "last_value" + elif isinstance(op, agg_ops.ShiftOp): + periods = op.periods + if periods >= 0: + func_name = "lag" + else: + func_name = "lead" + periods = -periods + else: + raise NotImplementedError(f"Aggregation operator {type(op).__name__} not supported in window compilation yet") + + wf.function_reference = self._EXTENSIONS[func_name] + + if isinstance(op, (agg_ops.LastNonNullOp, agg_ops.FirstNonNullOp)) and len(node.window_spec.ordering) > 0: + import bigframes.operations.generic_ops as generic_ops + from google.protobuf import json_format + + input_expr = cdef.expression.children[0] + order_expr = node.window_spec.ordering[0].scalar_expression + in_dtype_val = self._get_expression_dtype(input_expr, node.child) + + # Build NOT_NULL(input_expr) + cond = algebra_pb2.Expression() + cond.scalar_function.function_reference = self._EXTENSIONS[self._OP_TO_EXTENSION[generic_ops.NotNullOp]] + cond_arg = cond.scalar_function.arguments.add() + cond_arg.value.CopyFrom(self._compile_expression(input_expr, node.child)) + json_format.ParseDict({"bool": {}}, cond.scalar_function.output_type) + + # Build struct_expr struct(order_expr, input_expr) + struct_expr = algebra_pb2.Expression() + struct_expr.scalar_function.function_reference = self._EXTENSIONS["struct"] + + arg0 = struct_expr.scalar_function.arguments.add() + arg0.value.CopyFrom(self._compile_expression(order_expr, node.child)) + + arg1 = struct_expr.scalar_function.arguments.add() + arg1.value.CopyFrom(self._compile_expression(input_expr, node.child)) + + # Output type of struct is struct + struct_type = { + "struct": { + "types": [ + {"i64": {}}, + self._convert_type(in_dtype_val) + ] + } + } + json_format.ParseDict(struct_type, struct_expr.scalar_function.output_type) + + is_min = isinstance(op, agg_ops.FirstNonNullOp) + dummy_idx = 9223372036854775807 if is_min else -1 + + # Build dummy struct struct(dummy_idx, NULL) + dummy_struct = algebra_pb2.Expression() + dummy_struct.scalar_function.function_reference = self._EXTENSIONS["struct"] + + arg0_d = dummy_struct.scalar_function.arguments.add() + arg0_d.value.literal.i64 = dummy_idx + + arg1_d = dummy_struct.scalar_function.arguments.add() + target_type_proto = type_pb2.Type() + json_format.ParseDict(self._convert_type(in_dtype_val), target_type_proto) + arg1_d.value.literal.null.CopyFrom(target_type_proto) + + json_format.ParseDict(struct_type, dummy_struct.scalar_function.output_type) + + # Build IfThen expression + ifthen_expr = algebra_pb2.Expression() + if_clause = ifthen_expr.if_then.ifs.add() + getattr(if_clause, "if").CopyFrom(cond) + if_clause.then.CopyFrom(struct_expr) + getattr(ifthen_expr.if_then, "else").CopyFrom(dummy_struct) + + arg = wf.arguments.add() + arg.value.CopyFrom(ifthen_expr) + else: + for child_expr in cdef.expression.children: + arg = wf.arguments.add() + arg.value.CopyFrom(self._compile_expression(child_expr, node.child)) + + if isinstance(op, (agg_ops.ShiftOp, agg_ops.DiffOp)): + arg = wf.arguments.add() + arg.value.literal.i64 = periods + + if isinstance(op, (agg_ops.LastNonNullOp, agg_ops.FirstNonNullOp)) and len(node.window_spec.ordering) > 0: + import bigframes.dtypes as dtypes + import pandas as pd + import pyarrow as pa + input_expr = cdef.expression.children[0] + order_expr = node.window_spec.ordering[0].scalar_expression + in_dtype_order = self._get_expression_dtype(order_expr, node.child) + in_dtype_val = self._get_expression_dtype(input_expr, node.child) + arrow_order = dtypes.bigframes_dtype_to_arrow_dtype(in_dtype_order) + arrow_val = dtypes.bigframes_dtype_to_arrow_dtype(in_dtype_val) + + struct_dtype = pd.ArrowDtype(pa.struct([ + pa.field("c0", arrow_order, nullable=True), + pa.field("c1", arrow_val, nullable=True) + ])) + from google.protobuf import json_format + json_format.ParseDict(self._convert_type(struct_dtype), wf.output_type) + else: + dt = node.field_by_id[cdef.id].dtype + from google.protobuf import json_format + json_format.ParseDict(self._convert_type(dt), wf.output_type) + + # Compile Bounds + if isinstance(node.window_spec.bounds, window_spec_module.RowsWindowBounds): + wf.bounds_type = algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_ROWS + self._compile_bound(node.window_spec.bounds.start, wf.lower_bound) + self._compile_bound(node.window_spec.bounds.end, wf.upper_bound) + elif isinstance(node.window_spec.bounds, window_spec_module.RangeWindowBounds): + wf.bounds_type = algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_RANGE + self._compile_bound(node.window_spec.bounds.start, wf.lower_bound) + self._compile_bound(node.window_spec.bounds.end, wf.upper_bound) + else: + wf.bounds_type = algebra_pb2.Expression.WindowFunction.BoundsType.BOUNDS_TYPE_ROWS + wf.lower_bound.unbounded.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.Unbounded()) + wf.upper_bound.unbounded.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.Unbounded()) + + # Partition Expressions + for part_expr in node.window_spec.grouping_keys: + expr_pb = self._compile_expression(part_expr, node.child) + wf.partitions.add().CopyFrom(expr_pb) + + # Sorts + for sort_expr in node.window_spec.ordering: + expr_pb = self._compile_expression(sort_expr.scalar_expression, node.child) + sort_field = wf.sorts.add() + sort_field.expr.CopyFrom(expr_pb) + if sort_expr.direction.is_ascending: + if sort_expr.na_last: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_LAST + else: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_FIRST + else: + if sort_expr.na_last: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_LAST + else: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_FIRST + + child_ids = list(node.child.ids) + num_exprs = len(node.agg_exprs) + + if not (has_diff or has_struct_fill): + project_rel1.common.emit.output_mapping.extend(range(len(child_ids) + num_exprs)) + return rel1 + + # ProjectRel 2: Compute subtraction for DiffOp columns or extract struct fields + rel2 = algebra_pb2.Rel() + project_rel2 = rel2.project + project_rel2.input.CopyFrom(rel1) + + num_cols_rel1 = len(child_ids) + num_exprs + new_expr_idx = 0 + output_mapping = list(range(len(child_ids))) + + for cdef_idx, cdef in enumerate(node.agg_exprs): + op = cdef.expression.op + if isinstance(op, agg_ops.DiffOp): + expr = project_rel2.expressions.add() + expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + + # Arg 1: current column + col_idx = list(node.child.ids).index(cdef.expression.children[0].id) + arg1 = expr.scalar_function.arguments.add() + arg1.value.selection.direct_reference.struct_field.field = col_idx + + # Arg 2: lag column from rel1 + lag_idx = len(child_ids) + cdef_idx + arg2 = expr.scalar_function.arguments.add() + arg2.value.selection.direct_reference.struct_field.field = lag_idx + + dt = node.field_by_id[cdef.id].dtype + from google.protobuf import json_format + json_format.ParseDict(self._convert_type(dt), expr.scalar_function.output_type) + + # Output mapping points to this newly projected expression + output_mapping.append(num_cols_rel1 + new_expr_idx) + new_expr_idx += 1 + elif isinstance(op, (agg_ops.LastNonNullOp, agg_ops.FirstNonNullOp)) and len(node.window_spec.ordering) > 0: + from google.protobuf import json_format + + window_col_idx = len(child_ids) + cdef_idx + is_min = isinstance(op, agg_ops.FirstNonNullOp) + dummy_idx = 9223372036854775807 if is_min else -1 + + # 1. Build get_field(window_col_idx, "c0") + get_c0 = algebra_pb2.Expression() + get_c0.scalar_function.function_reference = self._EXTENSIONS["get_field"] + arg0 = get_c0.scalar_function.arguments.add() + arg0.value.selection.direct_reference.struct_field.field = window_col_idx + arg1 = get_c0.scalar_function.arguments.add() + arg1.value.literal.string = "c0" + json_format.ParseDict({"i64": {}}, get_c0.scalar_function.output_type) + + # 2. Build get_c0 == dummy_idx + eq_c0 = algebra_pb2.Expression() + eq_c0.scalar_function.function_reference = self._EXTENSIONS["equal"] + arg0 = eq_c0.scalar_function.arguments.add() + arg0.value.CopyFrom(get_c0) + arg1 = eq_c0.scalar_function.arguments.add() + arg1.value.literal.i64 = dummy_idx + json_format.ParseDict({"bool": {}}, eq_c0.scalar_function.output_type) + + # 3. Build the IfThen expression + expr = project_rel2.expressions.add() + if_clause = expr.if_then.ifs.add() + getattr(if_clause, "if").CopyFrom(eq_c0) + + # Then value: NULL of target type + dt = node.field_by_id[cdef.id].dtype + target_type_proto = type_pb2.Type() + json_format.ParseDict(self._convert_type(dt), target_type_proto) + if_clause.then.literal.null.CopyFrom(target_type_proto) + + # Else value: get_field(window_col_idx, "c1") + get_c1 = algebra_pb2.Expression() + get_c1.scalar_function.function_reference = self._EXTENSIONS["get_field"] + arg0 = get_c1.scalar_function.arguments.add() + arg0.value.selection.direct_reference.struct_field.field = window_col_idx + arg1 = get_c1.scalar_function.arguments.add() + arg1.value.literal.string = "c1" + json_format.ParseDict(self._convert_type(dt), get_c1.scalar_function.output_type) + + getattr(expr.if_then, "else").CopyFrom(get_c1) + + # Output mapping points to this newly projected expression + output_mapping.append(num_cols_rel1 + new_expr_idx) + new_expr_idx += 1 + else: + # Output mapping points to the direct window column from rel1 + output_mapping.append(len(child_ids) + cdef_idx) + + project_rel2.common.emit.output_mapping.extend(output_mapping) + return rel2 + + def _compile_bound(self, val: typing.Optional[int], bound_msg: algebra_pb2.Expression.WindowFunction.Bound): + if val is None: + bound_msg.unbounded.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.Unbounded()) + elif val == 0: + bound_msg.current_row.CopyFrom(algebra_pb2.Expression.WindowFunction.Bound.CurrentRow()) + elif val < 0: + bound_msg.preceding.offset = -val + else: + bound_msg.following.offset = val + + def _compile_concat(self, node: nodes.ConcatNode) -> algebra_pb2.Rel: + rel = algebra_pb2.Rel() + set_rel = rel.set + set_rel.op = algebra_pb2.SetRel.SetOp.SET_OP_UNION_ALL + + for child in node.children: + set_rel.inputs.append(self._compile_node(child)) + + return rel + def _compile_aggregate(self, node: nodes.AggregateNode) -> algebra_pb2.Rel: input_rel = self._compile_node(node.child) @@ -218,26 +649,151 @@ def _compile_aggregate(self, node: nodes.AggregateNode) -> algebra_pb2.Rel: expr.selection.direct_reference.struct_field.field = idx import bigframes.operations.aggregations as agg_ops - for agg, _ in node.aggregations: + import bigframes.dtypes as dtypes + size_count = 0 + for agg, out_col_id in node.aggregations: + distinct = False if isinstance(agg.op, agg_ops.SumOp): - func_ref = 11 + func_ref = self._EXTENSIONS["sum"] elif isinstance(agg.op, agg_ops.MaxOp): - func_ref = 12 + func_ref = self._EXTENSIONS["max"] + elif isinstance(agg.op, agg_ops.MinOp): + func_ref = self._EXTENSIONS["min"] + elif isinstance(agg.op, agg_ops.MeanOp): + func_ref = self._EXTENSIONS["mean"] + elif isinstance(agg.op, agg_ops.CountOp): + func_ref = self._EXTENSIONS["count"] + elif isinstance(agg.op, (agg_ops.SizeOp, agg_ops.SizeUnaryOp)): + func_ref = self._EXTENSIONS["count"] + elif isinstance(agg.op, agg_ops.NuniqueOp): + func_ref = self._EXTENSIONS["count"] + distinct = True + elif isinstance(agg.op, agg_ops.StdOp): + func_ref = self._EXTENSIONS["stddev"] + elif isinstance(agg.op, agg_ops.VarOp): + func_ref = self._EXTENSIONS["var"] + elif isinstance(agg.op, agg_ops.PopVarOp): + func_ref = self._EXTENSIONS["var_pop"] + elif isinstance(agg.op, agg_ops.AnyValueOp): + func_ref = self._EXTENSIONS["min"] + elif isinstance(agg.op, agg_ops.AllOp): + func_ref = self._EXTENSIONS["bool_and"] + elif isinstance(agg.op, agg_ops.AnyOp): + func_ref = self._EXTENSIONS["bool_or"] + elif isinstance(agg.op, agg_ops.ProductOp): + func_ref = self._EXTENSIONS["product"] + elif isinstance(agg.op, agg_ops.MedianOp): + func_ref = self._EXTENSIONS["median"] else: raise NotImplementedError(f"Aggregation {type(agg.op)} not supported in Substrait compiler yet") measure = agg_rel.measures.add() measure.measure.function_reference = func_ref + if distinct or isinstance(agg.op, agg_ops.NuniqueOp): + measure.measure.invocation = algebra_pb2.AggregateFunction.AGGREGATION_INVOCATION_DISTINCT - if hasattr(agg, "column_references"): + if isinstance(agg.op, (agg_ops.SizeOp, agg_ops.SizeUnaryOp)): + size_count += 1 + arg = measure.measure.arguments.add() + arg.value.literal.i64 = size_count + elif hasattr(agg, "column_references"): for col_id in agg.column_references: try: idx = list(node.child.ids).index(col_id) - arg = measure.measure.arguments.add() - arg.value.selection.direct_reference.struct_field.field = idx + field_expr = algebra_pb2.Expression() + field_expr.selection.direct_reference.struct_field.field = idx + + col_dtype = node.child.schema.items[idx].dtype + is_bool = col_dtype == dtypes.BOOL_DTYPE + if isinstance(agg.op, (agg_ops.StdOp, agg_ops.VarOp, agg_ops.PopVarOp)) or (isinstance(agg.op, (agg_ops.SumOp, agg_ops.MeanOp)) and is_bool): + casted_expr = self._compile_cast(field_expr, dtypes.FLOAT_DTYPE) + arg = measure.measure.arguments.add() + arg.value.CopyFrom(casted_expr) + else: + arg = measure.measure.arguments.add() + arg.value.CopyFrom(field_expr) except ValueError: pass - + + if node.dropna and node.by_column_ids: + not_null_exprs = [] + for idx in range(len(node.by_column_ids)): + key_expr = algebra_pb2.Expression() + key_expr.selection.direct_reference.struct_field.field = idx + + not_null_op = algebra_pb2.Expression() + not_null_op.scalar_function.function_reference = self._EXTENSIONS["is_not_null"] + not_null_op.scalar_function.arguments.add().value.CopyFrom(key_expr) + not_null_exprs.append(not_null_op) + + if len(not_null_exprs) > 1: + expr = not_null_exprs[0] + for e in not_null_exprs[1:]: + and_expr = algebra_pb2.Expression() + and_expr.scalar_function.function_reference = self._EXTENSIONS["and"] + and_expr.scalar_function.arguments.add().value.CopyFrom(expr) + and_expr.scalar_function.arguments.add().value.CopyFrom(e) + expr = and_expr + else: + expr = not_null_exprs[0] + + filter_rel = algebra_pb2.Rel() + filter_rel.filter.input.CopyFrom(rel) + filter_rel.filter.condition.CopyFrom(expr) + rel = filter_rel + + return rel + + def _compile_orderby(self, node: nodes.OrderByNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + sort_rel = rel.sort + sort_rel.input.CopyFrom(input_rel) + + for ord_expr in node.by: + sort_field = sort_rel.sorts.add() + + # Compile the expression: + expr_pb = self._compile_expression(ord_expr.scalar_expression, node.child) + sort_field.expr.CopyFrom(expr_pb) + + # Map sort direction: + is_asc = ord_expr.direction.is_ascending + if is_asc: + if ord_expr.na_last: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_LAST + else: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_ASC_NULLS_FIRST + else: + if ord_expr.na_last: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_LAST + else: + sort_field.direction = algebra_pb2.SortField.SortDirection.SORT_DIRECTION_DESC_NULLS_FIRST + + return rel + + def _compile_slice(self, node: nodes.SliceNode) -> algebra_pb2.Rel: + input_rel = self._compile_node(node.child) + + rel = algebra_pb2.Rel() + fetch_rel = rel.fetch + fetch_rel.input.CopyFrom(input_rel) + + is_simple = (node.start is None or node.start >= 0) and (node.stop is None or node.stop >= 0) and (node.step is None or node.step == 1) + + if is_simple: + start = node.start if node.start is not None else 0 + fetch_rel.offset = start + + if node.stop is not None: + fetch_rel.count = max(0, node.stop - start) + else: + fetch_rel.count = -1 + else: + fetch_rel.offset = 0 + fetch_rel.count = -1 + return rel _EXTENSIONS = { @@ -246,7 +802,7 @@ def _compile_aggregate(self, node: nodes.AggregateNode) -> algebra_pb2.Rel: "multiply": 3, "divide": 4, "equal": 5, - "ne": 6, + "not_equal": 6, "lt": 7, "gt": 8, "lte": 9, @@ -254,6 +810,42 @@ def _compile_aggregate(self, node: nodes.AggregateNode) -> algebra_pb2.Rel: "sum": 11, "max": 12, "and": 13, + "min": 14, + "mean": 15, + "count": 16, + "stddev": 17, + "var": 18, + "any_value": 19, + "all": 20, + "any": 21, + "coalesce": 22, + "or": 23, + "least": 24, + "greatest": 25, + "is_null": 26, + "is_not_null": 27, + "nullif": 28, + "sqrt": 29, + "bool_and": 30, + "bool_or": 31, + "product": 32, + "not": 33, + "mod": 34, + "floor": 35, + "abs": 36, + "ceil": 37, + "median": 38, + "xor": 40, + "var_pop": 53, + "row_number": 60, + "rank": 61, + "dense_rank": 62, + "first_value": 63, + "last_value": 64, + "lag": 65, + "lead": 66, + "struct": 67, + "get_field": 68, } _OP_TO_EXTENSION = { @@ -261,12 +853,24 @@ def _compile_aggregate(self, node: nodes.AggregateNode) -> algebra_pb2.Rel: numeric_ops.SubOp: "subtract", numeric_ops.MulOp: "multiply", numeric_ops.DivOp: "divide", + numeric_ops.ModOp: "mod", comparison_ops.EqOp: "equal", - comparison_ops.NeOp: "ne", + comparison_ops.NeOp: "not_equal", comparison_ops.LtOp: "lt", comparison_ops.GtOp: "gt", comparison_ops.LeOp: "lte", comparison_ops.GeOp: "gte", + generic_ops.FillNaOp: "coalesce", + generic_ops.CoalesceOp: "coalesce", + bool_ops.AndOp: "and", + bool_ops.OrOp: "or", + bool_ops.XorOp: "xor", + generic_ops.InvertOp: "not", + numeric_ops.AbsOp: "abs", + numeric_ops.CeilOp: "ceil", + numeric_ops.FloorOp: "floor", + generic_ops.IsNullOp: "is_null", + generic_ops.NotNullOp: "is_not_null", } @singledispatchmethod @@ -275,14 +879,32 @@ def _compile_expression(self, expr: ex.Expression, child: nodes.BigFrameNode) -> @_compile_expression.register def _compile_scalar_constant(self, expr: ex.ScalarConstantExpression, child: nodes.BigFrameNode) -> algebra_pb2.Expression: + import datetime pb_expr = algebra_pb2.Expression() val = expr.value - if isinstance(val, int): + if isinstance(val, bool): + pb_expr.literal.boolean = val + elif isinstance(val, int): pb_expr.literal.i64 = val elif isinstance(val, float): pb_expr.literal.fp64 = val elif isinstance(val, str): pb_expr.literal.string = val + elif isinstance(val, (pd.Timestamp, datetime.datetime)): + if getattr(val, "tzinfo", None) is not None: + epoch = pd.Timestamp("1970-01-01", tz=val.tzinfo) + us = int((val - epoch).total_seconds() * 1_000_000) + pb_expr.literal.precision_timestamp_tz.precision = 6 + pb_expr.literal.precision_timestamp_tz.value = us + else: + epoch = pd.Timestamp("1970-01-01") + us = int((val - epoch).total_seconds() * 1_000_000) + pb_expr.literal.precision_timestamp.precision = 6 + pb_expr.literal.precision_timestamp.value = us + elif isinstance(val, datetime.date): + epoch = datetime.date(1970, 1, 1) + days = (val - epoch).days + pb_expr.literal.date = days elif pd.isna(val): pb_expr.literal.null.varchar.length = 0 else: @@ -307,16 +929,242 @@ def _compile_op_expr(self, expr: ex.OpExpression, child: nodes.BigFrameNode) -> def _compile_op(self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: raise NotImplementedError(f"Op type {type(op)} not supported in Substrait compiler yet") + @_compile_op.register(ops.AsTypeOp) + def _compile_astype(self, op: ops.AsTypeOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + arg_expr = self._compile_expression(inputs[0], child) + return self._compile_cast(arg_expr, op.to_type) + + @_compile_op.register(struct_ops.StructOp) + def _compile_struct_op(self, op: struct_ops.StructOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["struct"] + for arg in inputs: + arg_expr = self._compile_expression(arg, child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return pb_expr + + @_compile_op.register(struct_ops.StructFieldOp) + def _compile_struct_field_op(self, op: struct_ops.StructFieldOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["get_field"] + + # Arg 0: the struct + arg_expr = self._compile_expression(inputs[0], child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + + # Arg 1: the field name as string literal + literal_expr = algebra_pb2.Expression() + literal_expr.literal.string = str(op.name_or_index) + pb_expr.scalar_function.arguments.add().value.CopyFrom(literal_expr) + return pb_expr + + def _compile_cast(self, input_expr: algebra_pb2.Expression, target_dtype: Any) -> algebra_pb2.Expression: + if input_expr.HasField("literal") and input_expr.literal.HasField("null"): + pb_expr = algebra_pb2.Expression() + type_dict = self._convert_type(target_dtype) + json_format.ParseDict(type_dict, pb_expr.literal.null) + return pb_expr + + pb_expr = algebra_pb2.Expression() + cast = pb_expr.cast + cast.input.CopyFrom(input_expr) + + type_dict = self._convert_type(target_dtype) + json_format.ParseDict(type_dict, cast.type) + + cast.failure_behavior = algebra_pb2.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL + return pb_expr + + def _get_expression_dtype(self, expr: ex.Expression, child: nodes.BigFrameNode) -> Any: + import bigframes.dtypes as dtypes + if isinstance(expr, ex.ScalarConstantExpression): + if expr.value is None or pd.isna(expr.value): + return None + return expr.dtype or dtypes.infer_literal_type(expr.value) + elif isinstance(expr, ex.DerefOp): + try: + idx = list(child.ids).index(expr.id) + return child.schema.items[idx].dtype + except ValueError: + pass + elif isinstance(expr, ex.OpExpression): + try: + input_dtypes = [self._get_expression_dtype(inp, child) for inp in expr.inputs] + return expr.op.output_type(*input_dtypes) + except Exception: + pass + return dtypes.STRING_DTYPE + + def _get_common_type(self, dtypes_list: Sequence[Any]) -> Any: + import bigframes.dtypes as dtypes + non_null_dtypes = [dt for dt in dtypes_list if dt is not None] + if not non_null_dtypes: + return dtypes.STRING_DTYPE + if len(set(non_null_dtypes)) == 1: + return non_null_dtypes[0] + if any(dt == dtypes.STRING_DTYPE for dt in non_null_dtypes): + return dtypes.STRING_DTYPE + if any(dt == dtypes.FLOAT_DTYPE for dt in non_null_dtypes): + return dtypes.FLOAT_DTYPE + if any(dt == dtypes.INT_DTYPE for dt in non_null_dtypes): + return dtypes.INT_DTYPE + return dtypes.STRING_DTYPE + + @_compile_op.register(ops.CaseWhenOp) + def _compile_casewhen(self, op: ops.CaseWhenOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + ifthen = pb_expr.if_then + + then_dtypes = [self._get_expression_dtype(inputs[idx], child) for idx in range(1, len(inputs), 2)] + common_dtype = self._get_common_type(then_dtypes) + + for idx in range(0, len(inputs), 2): + pred = self._compile_expression(inputs[idx], child) + val_expr = self._compile_expression(inputs[idx+1], child) + + val_dtype = then_dtypes[idx // 2] + if val_dtype != common_dtype: + val = self._compile_cast(val_expr, common_dtype) + else: + val = val_expr + + if_clause = ifthen.ifs.add() + getattr(if_clause, "if").CopyFrom(pred) + if_clause.then.CopyFrom(val) + + type_dict = self._convert_type(common_dtype) + json_format.ParseDict(type_dict, getattr(ifthen, "else").literal.null) + return pb_expr + + @_compile_op.register(generic_ops.WhereOp) + def _compile_where(self, op: generic_ops.WhereOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + ifthen = pb_expr.if_then + + pred = self._compile_expression(inputs[1], child) + then_val = self._compile_expression(inputs[0], child) + else_val = self._compile_expression(inputs[2], child) + + then_dtype = self._get_expression_dtype(inputs[0], child) + else_dtype = self._get_expression_dtype(inputs[2], child) + common_dtype = self._get_common_type([then_dtype, else_dtype]) + + casted_then = self._compile_cast(then_val, common_dtype) + casted_else = self._compile_cast(else_val, common_dtype) + + if_clause = ifthen.ifs.add() + getattr(if_clause, "if").CopyFrom(pred) + if_clause.then.CopyFrom(casted_then) + + getattr(ifthen, "else").CopyFrom(casted_else) + return pb_expr + + @_compile_op.register(numeric_ops.DivOp) + def _compile_div_op(self, op: numeric_ops.DivOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + import bigframes.dtypes as dtypes + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["divide"] + for arg in inputs: + arg_expr = self._compile_expression(arg, child) + casted_arg = self._compile_cast(arg_expr, dtypes.FLOAT_DTYPE) + pb_expr.scalar_function.arguments.add().value.CopyFrom(casted_arg) + return pb_expr + + @_compile_op.register(numeric_ops.FloorDivOp) + def _compile_floor_div_op(self, op: numeric_ops.FloorDivOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + import bigframes.dtypes as dtypes + + dividend_expr = self._compile_expression(inputs[0], child) + divisor_expr = self._compile_expression(inputs[1], child) + + # Calculate standard floor division + div_expr = algebra_pb2.Expression() + div_expr.scalar_function.function_reference = self._EXTENSIONS["divide"] + + # Cast to float for standard division + casted_dividend = self._compile_cast(dividend_expr, dtypes.FLOAT_DTYPE) + casted_divisor = self._compile_cast(divisor_expr, dtypes.FLOAT_DTYPE) + + div_expr.scalar_function.arguments.add().value.CopyFrom(casted_dividend) + div_expr.scalar_function.arguments.add().value.CopyFrom(casted_divisor) + + floor_expr = algebra_pb2.Expression() + floor_expr.scalar_function.function_reference = self._EXTENSIONS["floor"] + floor_expr.scalar_function.arguments.add().value.CopyFrom(div_expr) + + # If both operands are integer/boolean, we short-circuit division by 0 to return 0 + left_dtype = self._get_expression_dtype(inputs[0], child) + right_dtype = self._get_expression_dtype(inputs[1], child) + + is_left_int = left_dtype == dtypes.INT_DTYPE or left_dtype == dtypes.BOOL_DTYPE + is_right_int = right_dtype == dtypes.INT_DTYPE or right_dtype == dtypes.BOOL_DTYPE + + if is_left_int and is_right_int: + # If divisor is 0, return 0 * dividend (to propagate nulls) + zero_i64 = algebra_pb2.Expression() + zero_i64.literal.i64 = 0 + + eq_expr = algebra_pb2.Expression() + eq_expr.scalar_function.function_reference = self._EXTENSIONS["equal"] + eq_expr.scalar_function.arguments.add().value.CopyFrom(divisor_expr) + eq_expr.scalar_function.arguments.add().value.CopyFrom(zero_i64) + + zero_result = algebra_pb2.Expression() + zero_result.scalar_function.function_reference = self._EXTENSIONS["multiply"] + zero_result.scalar_function.arguments.add().value.CopyFrom(dividend_expr) + zero_result.scalar_function.arguments.add().value.CopyFrom(zero_i64) + + pb_expr = algebra_pb2.Expression() + ifthen = pb_expr.if_then + if_clause = ifthen.ifs.add() + getattr(if_clause, "if").CopyFrom(eq_expr) + if_clause.then.CopyFrom(zero_result) + + # Else, cast float floor_expr to int64 + casted_floor = self._compile_cast(floor_expr, dtypes.INT_DTYPE) + getattr(ifthen, "else").CopyFrom(casted_floor) + return pb_expr + + return floor_expr + + @_compile_op.register(generic_ops.IsInOp) + def _compile_isin(self, op: generic_ops.IsInOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.singular_or_list.value.CopyFrom(self._compile_expression(inputs[0], child)) + for val in op.values: + opt_expr = self._compile_expression(ex.const(val), child) + pb_expr.singular_or_list.options.add().CopyFrom(opt_expr) + return pb_expr + + @_compile_op.register(generic_ops.FillNaOp) + def _compile_fillna_op(self, op: generic_ops.FillNaOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + first_expr = self._compile_expression(inputs[0], child) + first_dtype = self._get_expression_dtype(inputs[0], child) + second_expr = self._compile_expression(inputs[1], child) + second_dtype = self._get_expression_dtype(inputs[1], child) + + if first_dtype is not None and second_dtype != first_dtype: + second_expr = self._compile_cast(second_expr, first_dtype) + + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS["coalesce"] + pb_expr.scalar_function.arguments.add().value.CopyFrom(first_expr) + pb_expr.scalar_function.arguments.add().value.CopyFrom(second_expr) + return pb_expr + + @_compile_op.register(generic_ops.CoalesceOp) @_compile_op.register(numeric_ops.AddOp) @_compile_op.register(numeric_ops.SubOp) @_compile_op.register(numeric_ops.MulOp) - @_compile_op.register(numeric_ops.DivOp) @_compile_op.register(comparison_ops.EqOp) @_compile_op.register(comparison_ops.NeOp) @_compile_op.register(comparison_ops.LtOp) @_compile_op.register(comparison_ops.GtOp) @_compile_op.register(comparison_ops.LeOp) @_compile_op.register(comparison_ops.GeOp) + @_compile_op.register(bool_ops.AndOp) + @_compile_op.register(bool_ops.OrOp) + @_compile_op.register(bool_ops.XorOp) def _compile_basic_binops(self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: op_class = type(op) ext_name = self._OP_TO_EXTENSION[op_class] @@ -330,6 +1178,112 @@ def _compile_basic_binop(self, ext_name: str, inputs: Sequence[ex.Expression], c pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) return pb_expr + @_compile_op.register(numeric_ops.ModOp) + def _compile_mod_op(self, op: numeric_ops.ModOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + import bigframes.dtypes as dtypes + a_expr = self._compile_expression(inputs[0], child) + b_expr = self._compile_expression(inputs[1], child) + + div_expr = algebra_pb2.Expression() + div_expr.scalar_function.function_reference = self._EXTENSIONS["divide"] + + a_float = self._compile_cast(a_expr, dtypes.FLOAT_DTYPE) + b_float = self._compile_cast(b_expr, dtypes.FLOAT_DTYPE) + div_expr.scalar_function.arguments.add().value.CopyFrom(a_float) + div_expr.scalar_function.arguments.add().value.CopyFrom(b_float) + + floor_expr = algebra_pb2.Expression() + floor_expr.scalar_function.function_reference = self._EXTENSIONS["floor"] + floor_expr.scalar_function.arguments.add().value.CopyFrom(div_expr) + + mul_expr = algebra_pb2.Expression() + mul_expr.scalar_function.function_reference = self._EXTENSIONS["multiply"] + mul_expr.scalar_function.arguments.add().value.CopyFrom(b_float) + mul_expr.scalar_function.arguments.add().value.CopyFrom(floor_expr) + + sub_expr = algebra_pb2.Expression() + sub_expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + sub_expr.scalar_function.arguments.add().value.CopyFrom(a_float) + sub_expr.scalar_function.arguments.add().value.CopyFrom(mul_expr) + + a_dtype = self._get_expression_dtype(inputs[0], child) + b_dtype = self._get_expression_dtype(inputs[1], child) + common_dtype = self._get_common_type([a_dtype, b_dtype]) + + if common_dtype == dtypes.INT_DTYPE: + return self._compile_cast(sub_expr, dtypes.INT_DTYPE) + return sub_expr + + @_compile_op.register(numeric_ops.AbsOp) + @_compile_op.register(numeric_ops.CeilOp) + @_compile_op.register(numeric_ops.FloorOp) + @_compile_op.register(generic_ops.IsNullOp) + @_compile_op.register(generic_ops.NotNullOp) + def _compile_standard_unaryops(self, op: Any, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + op_class = type(op) + ext_name = self._OP_TO_EXTENSION[op_class] + return self._compile_basic_unaryop(ext_name, inputs, child) + + @_compile_op.register(numeric_ops.PosOp) + def _compile_pos_op(self, op: numeric_ops.PosOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + # Unary plus is a no-op + return self._compile_expression(inputs[0], child) + + @_compile_op.register(numeric_ops.NegOp) + def _compile_neg_op(self, op: numeric_ops.NegOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + # Compile negation as subtraction: 0 - x + arg_expr = self._compile_expression(inputs[0], child) + arg_dtype = self._get_expression_dtype(inputs[0], child) + + zero_expr = algebra_pb2.Expression() + if arg_dtype == dtypes.FLOAT_DTYPE: + zero_expr.literal.fp64 = 0.0 + else: + zero_expr.literal.i64 = 0 + + sub_expr = algebra_pb2.Expression() + sub_expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + sub_expr.scalar_function.arguments.add().value.CopyFrom(zero_expr) + sub_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return sub_expr + + @_compile_op.register(generic_ops.InvertOp) + def _compile_invert_op(self, op: generic_ops.InvertOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + arg_expr = self._compile_expression(inputs[0], child) + arg_dtype = self._get_expression_dtype(inputs[0], child) + + if arg_dtype == dtypes.BOOL_DTYPE: + # Logical negation + not_expr = algebra_pb2.Expression() + not_expr.scalar_function.function_reference = self._EXTENSIONS["not"] + not_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return not_expr + else: + # Bitwise negation (two's complement mathematically equivalent to: -x - 1) + zero_i64 = algebra_pb2.Expression() + zero_i64.literal.i64 = 0 + + neg_expr = algebra_pb2.Expression() + neg_expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + neg_expr.scalar_function.arguments.add().value.CopyFrom(zero_i64) + neg_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + + one_i64 = algebra_pb2.Expression() + one_i64.literal.i64 = 1 + + result_expr = algebra_pb2.Expression() + result_expr.scalar_function.function_reference = self._EXTENSIONS["subtract"] + result_expr.scalar_function.arguments.add().value.CopyFrom(neg_expr) + result_expr.scalar_function.arguments.add().value.CopyFrom(one_i64) + return result_expr + + def _compile_basic_unaryop(self, ext_name: str, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + pb_expr = algebra_pb2.Expression() + pb_expr.scalar_function.function_reference = self._EXTENSIONS[ext_name] + arg_expr = self._compile_expression(inputs[0], child) + pb_expr.scalar_function.arguments.add().value.CopyFrom(arg_expr) + return pb_expr + def _convert_schema(self, schema: Any) -> Dict[str, Any]: # Convert bigframes schema to Substrait Type.NamedStruct fields = [] @@ -345,6 +1299,15 @@ def _convert_schema(self, schema: Any) -> Dict[str, Any]: "struct": {"types": types} } + def _get_substrait_names(self, name: str, dtype: Any) -> list[str]: + import bigframes.dtypes as dtypes + names = [name] + if dtypes.is_struct_like(dtype): + fields_dict = dtypes.get_struct_fields(dtype) + for f_name, f_dtype in fields_dict.items(): + names.extend(self._get_substrait_names(f_name, f_dtype)) + return names + def _convert_type(self, dtype: Any) -> Dict[str, Any]: import bigframes.dtypes if dtype == bigframes.dtypes.INT_DTYPE: @@ -371,6 +1334,187 @@ def _convert_type(self, dtype: Any) -> Dict[str, Any]: return {"decimal": {"precision": arrow_dtype.precision, "scale": arrow_dtype.scale}} elif dtype == bigframes.dtypes.TIMEDELTA_DTYPE: return {"interval_day": {"precision": 6, "type_variation_reference": 1}} + elif bigframes.dtypes.is_struct_like(dtype): + fields_dict = bigframes.dtypes.get_struct_fields(dtype) + return {"struct": {"types": [self._convert_type(f_dtype) for f_dtype in fields_dict.values()]}} + elif bigframes.dtypes.is_array_like(dtype): + inner_dtype = bigframes.dtypes.get_array_inner_type(dtype) + return {"list": {"type": self._convert_type(inner_dtype)}} else: # Fallback to string for now return {"string": {}} + + @_compile_op.register(ops.ArrayReduceOp) + def _compile_array_reduce(self, op: ops.ArrayReduceOp, inputs: Sequence[ex.Expression], child: nodes.BigFrameNode) -> algebra_pb2.Expression: + import bigframes.operations.array_ops as arr_ops + import bigframes.operations.aggregations as agg_ops + import bigframes.dtypes as dtypes + + arr_expr = inputs[0] + if not (isinstance(arr_expr, ex.OpExpression) and isinstance(arr_expr.op, arr_ops.ToArrayOp)): + raise NotImplementedError(f"ArrayReduceOp only supported on ToArrayOp in Substrait compiler, got {type(arr_expr)}") + + array_inputs = arr_expr.inputs + if not array_inputs: + pb_expr = algebra_pb2.Expression() + pb_expr.literal.null.varchar.length = 0 + return pb_expr + + compiled_inputs = [self._compile_expression(inp, child) for inp in array_inputs] + input_dtypes = [self._get_expression_dtype(inp, child) for inp in array_inputs] + common_dtype = self._get_common_type(input_dtypes) + + # For boolean aggregates, all operands must be Boolean. + # For others, cast all elements to the common dtype first to resolve type coercion issues. + if isinstance(op.aggregation, (agg_ops.AllOp, agg_ops.AnyOp)): + casted_inputs = [self._compile_cast(expr, dtypes.BOOL_DTYPE) for expr in compiled_inputs] + else: + casted_inputs = [self._compile_cast(expr, common_dtype) for expr in compiled_inputs] + + def get_zero(): + z = algebra_pb2.Expression() + if common_dtype == dtypes.FLOAT_DTYPE: + z.literal.fp64 = 0.0 + else: + z.literal.i64 = 0 + return z + + def call_binary(ext_name, left, right): + pb = algebra_pb2.Expression() + pb.scalar_function.function_reference = self._EXTENSIONS[ext_name] + pb.scalar_function.arguments.add().value.CopyFrom(left) + pb.scalar_function.arguments.add().value.CopyFrom(right) + return pb + + def call_unary(ext_name, arg): + pb = algebra_pb2.Expression() + pb.scalar_function.function_reference = self._EXTENSIONS[ext_name] + pb.scalar_function.arguments.add().value.CopyFrom(arg) + return pb + + def cast_expr(expr, target): + return self._compile_cast(expr, target) + + if isinstance(op.aggregation, agg_ops.AllOp): + coalesced = [] + for ci in casted_inputs: + true_expr = algebra_pb2.Expression() + true_expr.literal.boolean = True + coalesced.append(call_binary("coalesce", ci, true_expr)) + + res = coalesced[0] + for next_expr in coalesced[1:]: + res = call_binary("and", res, next_expr) + return res + + elif isinstance(op.aggregation, agg_ops.AnyOp): + coalesced = [] + for ci in casted_inputs: + false_expr = algebra_pb2.Expression() + false_expr.literal.boolean = False + coalesced.append(call_binary("coalesce", ci, false_expr)) + + res = coalesced[0] + for next_expr in coalesced[1:]: + res = call_binary("or", res, next_expr) + return res + + elif isinstance(op.aggregation, agg_ops.SumOp): + coalesced = [] + for ci in casted_inputs: + coalesced.append(call_binary("coalesce", ci, get_zero())) + + res = coalesced[0] + for next_expr in coalesced[1:]: + res = call_binary("add", res, next_expr) + return res + + elif isinstance(op.aggregation, agg_ops.MinOp): + res = casted_inputs[0] + for next_expr in casted_inputs[1:]: + res = call_binary("least", res, next_expr) + return res + + elif isinstance(op.aggregation, agg_ops.MaxOp): + res = casted_inputs[0] + for next_expr in casted_inputs[1:]: + res = call_binary("greatest", res, next_expr) + return res + + elif isinstance(op.aggregation, agg_ops.MeanOp): + coalesced = [] + for ci in casted_inputs: + coalesced.append(call_binary("coalesce", ci, get_zero())) + sum_expr = coalesced[0] + for next_expr in coalesced[1:]: + sum_expr = call_binary("add", sum_expr, next_expr) + + counts = [] + for ci in compiled_inputs: + is_not_null_expr = call_unary("is_not_null", ci) + counts.append(cast_expr(is_not_null_expr, dtypes.INT_DTYPE)) + + count_expr = counts[0] + for next_expr in counts[1:]: + count_expr = call_binary("add", count_expr, next_expr) + + zero_i64 = algebra_pb2.Expression() + zero_i64.literal.i64 = 0 + denom = call_binary("nullif", count_expr, zero_i64) + sum_float = cast_expr(sum_expr, dtypes.FLOAT_DTYPE) + denom_float = cast_expr(denom, dtypes.FLOAT_DTYPE) + return call_binary("divide", sum_float, denom_float) + + elif isinstance(op.aggregation, (agg_ops.VarOp, agg_ops.PopVarOp, agg_ops.StdOp)): + coalesced = [] + for ci in casted_inputs: + coalesced.append(call_binary("coalesce", ci, get_zero())) + sum_expr = coalesced[0] + for next_expr in coalesced[1:]: + sum_expr = call_binary("add", sum_expr, next_expr) + + counts = [] + for ci in compiled_inputs: + is_not_null_expr = call_unary("is_not_null", ci) + counts.append(cast_expr(is_not_null_expr, dtypes.INT_DTYPE)) + + count_expr = counts[0] + for next_expr in counts[1:]: + count_expr = call_binary("add", count_expr, next_expr) + + zero_i64 = algebra_pb2.Expression() + zero_i64.literal.i64 = 0 + denom = call_binary("nullif", count_expr, zero_i64) + sum_float = cast_expr(sum_expr, dtypes.FLOAT_DTYPE) + denom_float = cast_expr(denom, dtypes.FLOAT_DTYPE) + mean_expr = call_binary("divide", sum_float, denom_float) + + sq_diffs = [] + for ci in casted_inputs: + diff = call_binary("subtract", ci, mean_expr) + sq_diff = call_binary("multiply", diff, diff) + + zero_float = algebra_pb2.Expression() + zero_float.literal.fp64 = 0.0 + sq_diffs.append(call_binary("coalesce", sq_diff, zero_float)) + + sum_sq_diff = sq_diffs[0] + for next_expr in sq_diffs[1:]: + sum_sq_diff = call_binary("add", sum_sq_diff, next_expr) + + if isinstance(op.aggregation, agg_ops.PopVarOp): + denom_var = call_binary("nullif", count_expr, zero_i64) + else: + one_i64 = algebra_pb2.Expression() + one_i64.literal.i64 = 1 + count_minus_one = call_binary("subtract", count_expr, one_i64) + denom_var = call_binary("nullif", count_minus_one, zero_i64) + denom_var_float = cast_expr(denom_var, dtypes.FLOAT_DTYPE) + var_expr = call_binary("divide", sum_sq_diff, denom_var_float) + + if isinstance(op.aggregation, agg_ops.StdOp): + return call_unary("sqrt", var_expr) + return var_expr + + else: + raise NotImplementedError(f"Array reduction aggregate {type(op.aggregation)} not supported in Substrait compiler yet") diff --git a/packages/bigframes/bigframes/core/compile/substrait/test_multiindex_drop.py b/packages/bigframes/bigframes/core/compile/substrait/test_multiindex_drop.py new file mode 100644 index 000000000000..35a7e3e1d858 --- /dev/null +++ b/packages/bigframes/bigframes/core/compile/substrait/test_multiindex_drop.py @@ -0,0 +1,28 @@ +import bigframes +import bigframes.pandas as bpd +import pandas as pd +import sys + +# Initialize session +bpd.options.compute.backend = "substrait" + +df = bpd.read_pandas(pd.DataFrame({ + "bytes_col": [b"a", b"b", b"c", b"d", b"e", b"f", b"g"], + "numeric_col": [1, 2, 3, 4, 5, 6, 7], + "val": [10, 20, 30, 40, 50, 60, 70] +})) + +sub_df = df.iloc[[4, 1, 2]] +sub_df = sub_df.set_index(["bytes_col", "numeric_col"]) +drop_index = sub_df.index + +df = df.set_index(["bytes_col", "numeric_col"]) + +print("DF INDEX:") +print(df.index) +print("DROP INDEX:") +print(drop_index) + +res = df.drop(index=drop_index) +print("RESULT:") +print(res.to_pandas()) diff --git a/packages/bigframes/bigframes/session/substrait_executor.py b/packages/bigframes/bigframes/session/substrait_executor.py index 1363255c9db9..94a148d5bace 100644 --- a/packages/bigframes/bigframes/session/substrait_executor.py +++ b/packages/bigframes/bigframes/session/substrait_executor.py @@ -76,9 +76,11 @@ def consume(self, plan_proto: bytes, tables: dict[str, pa.Table]) -> pa.Table: import datafusion.substrait - #json_str = plan.decode('utf-8') - #print("DEBUG RE-SERIALIZED JSON SUBSTRAIT PLAN:") - #print(plan_obj.to_json()) + import substrait.plan_pb2 as plan_pb2 + from google.protobuf import json_format + plan_obj = plan_pb2.Plan.FromString(plan_proto) + print("DEBUG PLAN JSON:") + print(json_format.MessageToJson(plan_obj)) datafusion_substrait_plan = datafusion.substrait.Serde.deserialize_bytes(plan_proto) logical_plan = datafusion.substrait.Consumer.from_substrait_plan(ctx, datafusion_substrait_plan) df = ctx.create_dataframe_from_logical_plan(logical_plan) @@ -102,24 +104,158 @@ def execute( ordered: bool, peek: Optional[int] = None, ) -> Optional[executor.ExecuteResult]: - rewritten_plan = plan.bottom_up(slices_rewrite.rewrite_slice) - + def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode: + if isinstance(node, nodes.PromoteOffsetsNode): + res = self.execute(node.child, ordered=ordered) + if res is None: + return node + table = res.batches().to_arrow_table() + import pyarrow as pa + table = table.append_column(node.col_id.name, pa.array(range(len(table)), type=pa.int64())) + + from bigframes.core import local_data, identifiers + from bigframes.core.schema import ArraySchema, SchemaItem + import bigframes.dtypes + + schema_items = [] + for col_name in table.column_names: + if col_name == node.col_id.name: + schema_items.append(SchemaItem(col_name, bigframes.dtypes.INT_DTYPE)) + else: + schema_items.append(SchemaItem(col_name, node.child.schema.get_type(col_name))) + new_schema = ArraySchema(tuple(schema_items)) + + scan_items = [] + for col_name in table.column_names: + col_id = identifiers.ColumnId(col_name) + scan_items.append(nodes.ScanItem(col_id, col_name)) + scan_list = nodes.ScanList(tuple(scan_items)) + + session = None + for child_node in node.child.unique_nodes(): + if isinstance(child_node, nodes.ReadLocalNode): + session = child_node.session + break + + managed_table = local_data.ManagedArrowTable.from_pyarrow(table, schema=new_schema) + new_node = nodes.ReadLocalNode( + local_data_source=managed_table, + scan_list=scan_list, + session=session, + offsets_col=None, + ) + return new_node + return node + + # 1. Rewrite all SliceNodes to standard Selection/Filter/Projection/PromoteOffsetsNodes + plan = plan.bottom_up(slices_rewrite.rewrite_slice) + + # 2. Resolve all PromoteOffsetsNodes to concrete local tables + plan = plan.bottom_up(resolve_promote_offsets) + + # 3. Wrap plan in a ResultNode to apply defer_order + from bigframes.core import expression, rewrite + output_cols = tuple((expression.DerefOp(id), id.name) for id in plan.ids) + result_node = nodes.ResultNode( + plan, + output_cols=output_cols, + ) + import typing + result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node)) + result_node = rewrite.defer_order(result_node, output_hidden_row_keys=False) + + rewritten_plan = result_node.child + + # 4. Apply outermost sorting if ordered + if ordered and result_node.order_by and result_node.order_by.all_ordering_columns: + rewritten_plan = nodes.OrderByNode( + rewritten_plan, + by=tuple(result_node.order_by.all_ordering_columns), + ) + + # 5. Project only the original output columns to preserve correct result schema + original_ids = tuple(id for id in plan.ids) + if rewritten_plan.ids != original_ids: + rewritten_plan = nodes.SelectionNode( + rewritten_plan, + input_output_pairs=tuple(nodes.AliasedRef.identity(id) for id in original_ids) + ) + if not self._can_execute(rewritten_plan): return None substrait_plan_proto = self._compiler.compile(rewritten_plan) - if substrait_plan_proto is None: return None + import google.protobuf.json_format as json_format + from substrait.plan_pb2 import Plan + plan_proto = Plan() + plan_proto.ParseFromString(substrait_plan_proto) + import os + import uuid + os.makedirs("/usr/local/google/home/tbergeron/src/google-cloud-python/packages/bigframes/scratch", exist_ok=True) + filename = f"/usr/local/google/home/tbergeron/src/google-cloud-python/packages/bigframes/scratch/plan_{rewritten_plan.__class__.__name__}_{uuid.uuid4().hex[:8]}.json" + with open(filename, "w") as f: + f.write(json_format.MessageToJson(plan_proto)) + tables = {} for node in rewritten_plan.unique_nodes(): if isinstance(node, nodes.ReadLocalNode): - table_name = f"table_{node.local_data_source.id.hex}" - tables[table_name] = node.local_data_source.data + table_name = f"table_{id(node)}" + table = node.local_data_source.data + table = table.select([item.source_id for item in node.scan_list.items]) + table = table.rename_columns([item.id.sql for item in node.scan_list.items]) + if node.offsets_col is not None: + from bigframes.core import pyarrow_utils + table = pyarrow_utils.append_offsets(table, node.offsets_col.sql) + tables[table_name] = table pa_table = self._consumer.consume(substrait_plan_proto, tables) + # Sanitize pa_table: replace inf/nan/is_inf with null for INT_DTYPE columns + import pyarrow.compute as pc + import bigframes.dtypes as dtypes + import pyarrow as pa + sanitized_columns = [] + for col_name in pa_table.column_names: + col_data = pa_table.column(col_name) + try: + expected_dtype = rewritten_plan.schema.get_type(col_name) + except ValueError: + expected_dtype = None + + if expected_dtype == dtypes.INT_DTYPE and pa.types.is_floating(col_data.type): + is_nan = pc.is_nan(col_data) + is_inf = pc.is_inf(col_data) + is_invalid = pc.or_(is_nan, is_inf) + null_val = pa.scalar(None, type=col_data.type) + col_data = pc.if_else(is_invalid, null_val, col_data) + sanitized_columns.append(col_data) + pa_table = pa.Table.from_arrays(sanitized_columns, names=pa_table.column_names) + + # Handle SliceNode post-processing + for node in rewritten_plan.unique_nodes(): + if isinstance(node, nodes.SliceNode): + is_simple = (node.start is None or node.start >= 0) and (node.stop is None or node.stop >= 0) and (node.step is None or node.step == 1) + if not is_simple: + df = pa_table.to_pandas() + df = df.iloc[node.start:node.stop:node.step] + pa_table = pa.Table.from_pandas(df, schema=pa_table.schema) + offset_cols = set() + for node in rewritten_plan.unique_nodes(): + if isinstance(node, nodes.PromoteOffsetsNode): + offset_cols.add(node.col_id.name) + + for col_name in pa_table.column_names: + if col_name in offset_cols: + idx = pa_table.column_names.index(col_name) + pa_table = pa_table.set_column(idx, col_name, pa.array(range(len(pa_table)), type=pa.int64())) + + import sys + sys.stderr.write(f"PA_TABLE ON EXECUTE:\n{pa_table.to_pandas()}\n") + sys.stderr.flush() + if peek is not None: pa_table = pa_table.slice(0, peek) diff --git a/packages/bigframes/tests/unit/test_dataframe_substrait.py b/packages/bigframes/tests/unit/test_dataframe_substrait.py index 701e3f413f84..bb2b5acb19fe 100644 --- a/packages/bigframes/tests/unit/test_dataframe_substrait.py +++ b/packages/bigframes/tests/unit/test_dataframe_substrait.py @@ -1219,7 +1219,6 @@ def test_df_fillna(scalars_dfs, col, fill_value): pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) -@pytest.mark.skip("b/436316698 unit test failed for python 3.12") def test_df_ffill(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs bf_result = scalars_df[["int64_col", "float64_col"]].ffill(limit=1).to_pandas() @@ -3369,6 +3368,8 @@ def test_dataframe_aggregates(scalars_dfs, op, bf_dtype): pd_result = op(scalars_pandas_df_index[col_names]) # Check dtype separately + print("PD_RESULT:\n", pd_result) + print("BF_RESULT:\n", bf_result.to_pandas()) assert bf_result.dtype == bf_dtype # Pandas may produce narrower numeric types, but bigframes always produces Float64 @@ -4490,3 +4491,31 @@ def test_recursion_limit_unit(scalars_df_index): for i in range(250): scalars_df_index = scalars_df_index + 4 scalars_df_index.to_pandas() + + +def test_dataframe_popvar(scalars_dfs): + scalars_df_index, scalars_pandas_df_index = scalars_dfs + col_names = ["int64_too", "float64_col", "int64_col"] + from bigframes.operations import aggregations as agg_ops + from bigframes.core import agg_expressions + import bigframes.core.expression as ex + + col_ids = [scalars_df_index._block.resolve_label_exact(col) for col in col_names] + aggs = [agg_expressions.UnaryAggregation(agg_ops.PopVarOp(), ex.deref(col)) for col in col_ids] + + agg_block = scalars_df_index._block.aggregate(aggs, column_labels=pandas.Index(col_names)) + + import bigframes.dataframe as bfd + bf_result = bfd.DataFrame(agg_block).to_pandas() + bf_result = bf_result.iloc[0] + + pd_result = scalars_pandas_df_index[col_names].var(ddof=0) + pd_result.index = pd_result.index.astype("string[pyarrow]") + + pandas.testing.assert_series_equal( + pd_result, + bf_result, + check_dtype=False, + check_index_type=False, + check_names=False, + ) From e34b4abbc24ef734dc5eb8f0da3c2c4cacb60963 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Wed, 20 May 2026 20:47:31 +0000 Subject: [PATCH 5/5] more work --- .../core/compile/substrait/compiler.py | 48 ++----- .../bigframes/session/substrait_executor.py | 122 +----------------- .../tests/system/small/engines/conftest.py | 15 ++- .../system/small/engines/test_aggregation.py | 2 +- 4 files changed, 29 insertions(+), 158 deletions(-) diff --git a/packages/bigframes/bigframes/core/compile/substrait/compiler.py b/packages/bigframes/bigframes/core/compile/substrait/compiler.py index 46a266be2011..3ad4867eec64 100644 --- a/packages/bigframes/bigframes/core/compile/substrait/compiler.py +++ b/packages/bigframes/bigframes/core/compile/substrait/compiler.py @@ -39,18 +39,6 @@ class SubstraitCompiler: """ Compiles BigFrameNode plans to Substrait schema (JSON representation). """ - - def _print_node_tree(self, node: bigframe_node.BigFrameNode, indent: int = 0): - import sys - try: - ids = list(node.ids) - except Exception as e: - ids = f"" - sys.stderr.write(" " * indent + f"- {type(node).__name__}: ids={ids}\n") - sys.stderr.flush() - for child in node.child_nodes: - self._print_node_tree(child, indent + 1) - def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: """ Compiles a BigFrameNode to Substrait bytes (JSON encoded via protobuf). @@ -58,11 +46,6 @@ def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: if not self.can_compile(plan): return None - import sys - sys.stderr.write("DEBUG TREE:\n") - sys.stderr.flush() - self._print_node_tree(plan) - pb_rel = self._compile_node(plan) pb_plan = plan_pb2.Plan() @@ -84,7 +67,6 @@ def compile(self, plan: bigframe_node.BigFrameNode) -> Optional[bytes]: def can_compile(self, plan: bigframe_node.BigFrameNode) -> bool: """ Checks if the plan can be compiled to Substrait. - For the skeleton, we support ReadLocalNode, SelectionNode, and FilterNode. """ supported_nodes = ( nodes.ReadLocalNode, @@ -95,7 +77,6 @@ def can_compile(self, plan: bigframe_node.BigFrameNode) -> bool: nodes.JoinNode, nodes.AggregateNode, nodes.OrderByNode, - nodes.PromoteOffsetsNode, nodes.WindowOpNode, nodes.ConcatNode, ) @@ -123,8 +104,6 @@ def _compile_node(self, node: bigframe_node.BigFrameNode) -> algebra_pb2.Rel: return self._compile_orderby(node) elif isinstance(node, nodes.SliceNode): return self._compile_slice(node) - elif isinstance(node, nodes.PromoteOffsetsNode): - return self._compile_promote_offsets(node) elif isinstance(node, nodes.WindowOpNode): return self._compile_window(node) elif isinstance(node, nodes.ConcatNode): @@ -180,22 +159,6 @@ def _compile_selection(self, node: nodes.SelectionNode) -> algebra_pb2.Rel: return rel - def _compile_promote_offsets(self, node: nodes.PromoteOffsetsNode) -> algebra_pb2.Rel: - input_rel = self._compile_node(node.child) - - rel = algebra_pb2.Rel() - project_rel = rel.project - project_rel.input.CopyFrom(input_rel) - - # Add a dummy literal i64 = 0 for the offsets column - expr = project_rel.expressions.add() - expr.literal.i64 = 0 - - child_ids = list(node.child.ids) - project_rel.common.emit.output_mapping.extend(range(len(child_ids) + 1)) - - return rel - def _compile_filter(self, node: nodes.FilterNode) -> algebra_pb2.Rel: input_rel = self._compile_node(node.child) @@ -684,6 +647,10 @@ def _compile_aggregate(self, node: nodes.AggregateNode) -> algebra_pb2.Rel: func_ref = self._EXTENSIONS["product"] elif isinstance(agg.op, agg_ops.MedianOp): func_ref = self._EXTENSIONS["median"] + elif isinstance(agg.op, agg_ops.CovOp): + func_ref = self._EXTENSIONS["cov"] + elif isinstance(agg.op, agg_ops.CorrOp): + func_ref = self._EXTENSIONS["corr"] else: raise NotImplementedError(f"Aggregation {type(agg.op)} not supported in Substrait compiler yet") @@ -846,6 +813,9 @@ def _compile_slice(self, node: nodes.SliceNode) -> algebra_pb2.Rel: "lead": 66, "struct": 67, "get_field": 68, + "pow": 69, + "cov": 70, + "corr": 71, } _OP_TO_EXTENSION = { @@ -854,6 +824,8 @@ def _compile_slice(self, node: nodes.SliceNode) -> algebra_pb2.Rel: numeric_ops.MulOp: "multiply", numeric_ops.DivOp: "divide", numeric_ops.ModOp: "mod", + numeric_ops.PowOp: "pow", + numeric_ops.UnsafePowOp: "pow", comparison_ops.EqOp: "equal", comparison_ops.NeOp: "not_equal", comparison_ops.LtOp: "lt", @@ -1156,6 +1128,8 @@ def _compile_fillna_op(self, op: generic_ops.FillNaOp, inputs: Sequence[ex.Expre @_compile_op.register(numeric_ops.AddOp) @_compile_op.register(numeric_ops.SubOp) @_compile_op.register(numeric_ops.MulOp) + @_compile_op.register(numeric_ops.PowOp) + @_compile_op.register(numeric_ops.UnsafePowOp) @_compile_op.register(comparison_ops.EqOp) @_compile_op.register(comparison_ops.NeOp) @_compile_op.register(comparison_ops.LtOp) diff --git a/packages/bigframes/bigframes/session/substrait_executor.py b/packages/bigframes/bigframes/session/substrait_executor.py index 94a148d5bace..138d6b7e7b37 100644 --- a/packages/bigframes/bigframes/session/substrait_executor.py +++ b/packages/bigframes/bigframes/session/substrait_executor.py @@ -21,6 +21,7 @@ from bigframes.session import executor, semi_executor import bigframes.core.rewrite.slices as slices_rewrite from bigframes.core import nodes +import asyncio if TYPE_CHECKING: import pyarrow as pa @@ -62,25 +63,14 @@ def consume(self, plan_proto: bytes, tables: dict[str, pa.Table]) -> pa.Table: "Install it with `pip install datafusion`." ) - # Create a DataFusion context ctx = datafusion.SessionContext() for name, table in tables.items(): df = ctx.from_arrow_table(table) ctx.register_table(name, df) - # NOTE: The actual API for running Substrait in DataFusion python bindings may vary. - # Assuming something like ctx.from_substrait(plan) or ctx.execute_substrait(plan). - # We will need to verify this with the actual datafusion python package if available. - # For now, we raise NotImplementedError if we cannot find the method, or try a likely one. - import datafusion.substrait - import substrait.plan_pb2 as plan_pb2 - from google.protobuf import json_format - plan_obj = plan_pb2.Plan.FromString(plan_proto) - print("DEBUG PLAN JSON:") - print(json_format.MessageToJson(plan_obj)) datafusion_substrait_plan = datafusion.substrait.Serde.deserialize_bytes(plan_proto) logical_plan = datafusion.substrait.Consumer.from_substrait_plan(ctx, datafusion_substrait_plan) df = ctx.create_dataframe_from_logical_plan(logical_plan) @@ -98,62 +88,14 @@ def __init__(self, consumer: SubstraitConsumer): from bigframes.core.compile.substrait.compiler import SubstraitCompiler self._compiler = SubstraitCompiler() - def execute( + async def execute( self, plan: bigframe_node.BigFrameNode, ordered: bool, peek: Optional[int] = None, ) -> Optional[executor.ExecuteResult]: - def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode: - if isinstance(node, nodes.PromoteOffsetsNode): - res = self.execute(node.child, ordered=ordered) - if res is None: - return node - table = res.batches().to_arrow_table() - import pyarrow as pa - table = table.append_column(node.col_id.name, pa.array(range(len(table)), type=pa.int64())) - - from bigframes.core import local_data, identifiers - from bigframes.core.schema import ArraySchema, SchemaItem - import bigframes.dtypes - - schema_items = [] - for col_name in table.column_names: - if col_name == node.col_id.name: - schema_items.append(SchemaItem(col_name, bigframes.dtypes.INT_DTYPE)) - else: - schema_items.append(SchemaItem(col_name, node.child.schema.get_type(col_name))) - new_schema = ArraySchema(tuple(schema_items)) - - scan_items = [] - for col_name in table.column_names: - col_id = identifiers.ColumnId(col_name) - scan_items.append(nodes.ScanItem(col_id, col_name)) - scan_list = nodes.ScanList(tuple(scan_items)) - - session = None - for child_node in node.child.unique_nodes(): - if isinstance(child_node, nodes.ReadLocalNode): - session = child_node.session - break - - managed_table = local_data.ManagedArrowTable.from_pyarrow(table, schema=new_schema) - new_node = nodes.ReadLocalNode( - local_data_source=managed_table, - scan_list=scan_list, - session=session, - offsets_col=None, - ) - return new_node - return node - - # 1. Rewrite all SliceNodes to standard Selection/Filter/Projection/PromoteOffsetsNodes plan = plan.bottom_up(slices_rewrite.rewrite_slice) - # 2. Resolve all PromoteOffsetsNodes to concrete local tables - plan = plan.bottom_up(resolve_promote_offsets) - - # 3. Wrap plan in a ResultNode to apply defer_order from bigframes.core import expression, rewrite output_cols = tuple((expression.DerefOp(id), id.name) for id in plan.ids) result_node = nodes.ResultNode( @@ -166,14 +108,12 @@ def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.B rewritten_plan = result_node.child - # 4. Apply outermost sorting if ordered if ordered and result_node.order_by and result_node.order_by.all_ordering_columns: rewritten_plan = nodes.OrderByNode( rewritten_plan, by=tuple(result_node.order_by.all_ordering_columns), ) - # 5. Project only the original output columns to preserve correct result schema original_ids = tuple(id for id in plan.ids) if rewritten_plan.ids != original_ids: rewritten_plan = nodes.SelectionNode( @@ -188,17 +128,6 @@ def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.B if substrait_plan_proto is None: return None - import google.protobuf.json_format as json_format - from substrait.plan_pb2 import Plan - plan_proto = Plan() - plan_proto.ParseFromString(substrait_plan_proto) - import os - import uuid - os.makedirs("/usr/local/google/home/tbergeron/src/google-cloud-python/packages/bigframes/scratch", exist_ok=True) - filename = f"/usr/local/google/home/tbergeron/src/google-cloud-python/packages/bigframes/scratch/plan_{rewritten_plan.__class__.__name__}_{uuid.uuid4().hex[:8]}.json" - with open(filename, "w") as f: - f.write(json_format.MessageToJson(plan_proto)) - tables = {} for node in rewritten_plan.unique_nodes(): if isinstance(node, nodes.ReadLocalNode): @@ -211,52 +140,9 @@ def resolve_promote_offsets(node: bigframe_node.BigFrameNode) -> bigframe_node.B table = pyarrow_utils.append_offsets(table, node.offsets_col.sql) tables[table_name] = table - pa_table = self._consumer.consume(substrait_plan_proto, tables) - - # Sanitize pa_table: replace inf/nan/is_inf with null for INT_DTYPE columns - import pyarrow.compute as pc - import bigframes.dtypes as dtypes - import pyarrow as pa - sanitized_columns = [] - for col_name in pa_table.column_names: - col_data = pa_table.column(col_name) - try: - expected_dtype = rewritten_plan.schema.get_type(col_name) - except ValueError: - expected_dtype = None - - if expected_dtype == dtypes.INT_DTYPE and pa.types.is_floating(col_data.type): - is_nan = pc.is_nan(col_data) - is_inf = pc.is_inf(col_data) - is_invalid = pc.or_(is_nan, is_inf) - null_val = pa.scalar(None, type=col_data.type) - col_data = pc.if_else(is_invalid, null_val, col_data) - sanitized_columns.append(col_data) - pa_table = pa.Table.from_arrays(sanitized_columns, names=pa_table.column_names) - - # Handle SliceNode post-processing - for node in rewritten_plan.unique_nodes(): - if isinstance(node, nodes.SliceNode): - is_simple = (node.start is None or node.start >= 0) and (node.stop is None or node.stop >= 0) and (node.step is None or node.step == 1) - if not is_simple: - df = pa_table.to_pandas() - df = df.iloc[node.start:node.stop:node.step] - pa_table = pa.Table.from_pandas(df, schema=pa_table.schema) - offset_cols = set() - for node in rewritten_plan.unique_nodes(): - if isinstance(node, nodes.PromoteOffsetsNode): - offset_cols.add(node.col_id.name) - - for col_name in pa_table.column_names: - if col_name in offset_cols: - idx = pa_table.column_names.index(col_name) - pa_table = pa_table.set_column(idx, col_name, pa.array(range(len(pa_table)), type=pa.int64())) - - import sys - sys.stderr.write(f"PA_TABLE ON EXECUTE:\n{pa_table.to_pandas()}\n") - sys.stderr.flush() + pa_table = await asyncio.to_thread(self._consumer.consume, substrait_plan_proto, tables) - if peek is not None: + if peek is not None: pa_table = pa_table.slice(0, peek) return executor.LocalExecuteResult( diff --git a/packages/bigframes/tests/system/small/engines/conftest.py b/packages/bigframes/tests/system/small/engines/conftest.py index 758b697f25e2..36897efac21f 100644 --- a/packages/bigframes/tests/system/small/engines/conftest.py +++ b/packages/bigframes/tests/system/small/engines/conftest.py @@ -26,6 +26,7 @@ local_scan_executor, polars_executor, semi_executor, + substrait_executor, ) CURRENT_DIR = pathlib.Path(__file__).parent @@ -81,9 +82,17 @@ def sqlglot_engine( ) -@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq", "bq-sqlglot"]) +@pytest.fixture(scope="session") +def substrait_datafusion_engine( +) -> semi_executor.SemiExecutor: + return substrait_executor.SubstraitExecutor( + consumer = substrait_executor.DataFusionSubstraitConsumer() + ) + + +@pytest.fixture(scope="session", params=["pyarrow", "polars", "bq", "bq-sqlglot", "substrait-datafusion"]) def engine( - request, pyarrow_engine, polars_engine, bq_engine, sqlglot_engine + request, pyarrow_engine, polars_engine, bq_engine, sqlglot_engine, substrait_datafusion_engine ) -> semi_executor.SemiExecutor: if request.param == "pyarrow": return pyarrow_engine @@ -93,6 +102,8 @@ def engine( return bq_engine if request.param == "bq-sqlglot": return sqlglot_engine + if request.param == "substrait-datafusion": + return substrait_datafusion_engine raise ValueError(f"Unrecognized param: {request.param}") diff --git a/packages/bigframes/tests/system/small/engines/test_aggregation.py b/packages/bigframes/tests/system/small/engines/test_aggregation.py index e6e4ac571578..19ff0ea8f6c9 100644 --- a/packages/bigframes/tests/system/small/engines/test_aggregation.py +++ b/packages/bigframes/tests/system/small/engines/test_aggregation.py @@ -55,7 +55,7 @@ def apply_agg_to_all_valid( return new_arr -@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot", "substrait-datafusion"], indirect=True) def test_engines_aggregate_post_filter_size( scalars_array_value: array_value.ArrayValue, engine,