Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 113 additions & 16 deletions packages/bigframes/bigframes/functions/_function_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def __init__(self):
# Lock to synchronize the update of the session artifacts
self._artifacts_lock = threading.Lock()

self._deployed_routines: set[bytes] = set()
self._deploying_routines: set[bytes] = set()

def _resolve_session(self, session: Optional[Session]) -> Session:
"""Resolves the BigFrames session."""
import bigframes.pandas as bpd
Expand Down Expand Up @@ -191,6 +194,81 @@ def _update_temp_artifacts(self, bqrf_routine: str, gcf_path: str):
with self._artifacts_lock:
self._temp_artifacts[bqrf_routine] = gcf_path

def deploy_undeployed_udf(
self,
session: Session,
bq_udf: udf_def.PythonUdf,
) -> udf_def.BigqueryUdf:
"""Deploys a UDF to BigQuery if not already deployed."""
udf_hash = bq_udf.stable_hash()
import time

bigquery_client = self._resolve_bigquery_client(session, None)
bq_connection_manager = session.bqconnectionmanager
dataset_ref = self._resolve_dataset_reference(session, bigquery_client, None)
bq_location, _ = _utils.get_remote_function_locations(bigquery_client.location)

managed_function_client = _function_client.FunctionClient(
dataset_ref.project,
bq_location,
dataset_ref.dataset_id,
bigquery_client,
bq_connection_manager,
session=session,
)

config = bq_udf.to_managed_function_config()
bq_function_name = _function_client.get_managed_function_name(
config, session.session_id
)
full_rf_name = (
managed_function_client.get_remote_function_fully_qualilfied_name(
bq_function_name
)
)
routine_ref = bigquery.RoutineReference.from_string(full_rf_name)

with self._artifacts_lock:
if udf_hash in self._deployed_routines:
return udf_def.BigqueryUdf(
routine_ref=routine_ref,
signature=bq_udf.signature,
)

while True:
with self._artifacts_lock:
if udf_hash in self._deployed_routines:
return udf_def.BigqueryUdf(
routine_ref=routine_ref,
signature=bq_udf.signature,
)

if udf_hash not in self._deploying_routines:
self._deploying_routines.add(udf_hash)
break

time.sleep(0.2)

try:
managed_function_client.provision_bq_managed_function(
name=bq_function_name,
config=config,
)
except Exception:
with self._artifacts_lock:
self._deploying_routines.discard(udf_hash)
raise

with self._artifacts_lock:
self._deploying_routines.discard(udf_hash)
self._deployed_routines.add(udf_hash)
self._temp_artifacts[full_rf_name] = ""

return udf_def.BigqueryUdf(
routine_ref=routine_ref,
signature=bq_udf.signature,
)

def clean_up(
self,
bqclient: bigquery.Client,
Expand Down Expand Up @@ -679,6 +757,8 @@ def udf(
max_batching_rows: Optional[int] = None,
container_cpu: Optional[float] = None,
container_memory: Optional[str] = None,
*,
_force_deploy: bool = False,
):
"""Decorator to turn a Python user defined function (udf) into a
BigQuery managed function.
Expand Down Expand Up @@ -835,27 +915,46 @@ def wrapper(func):
capture_references=False,
)

bq_function_name = managed_function_client.provision_bq_managed_function(
name=name,
config=config,
)
full_rf_name = (
managed_function_client.get_remote_function_fully_qualilfied_name(
bq_function_name
)
requirements = udf_def.RuntimeRequirements(
container_cpu=container_cpu,
container_memory=container_memory,
bq_connection_id=bq_connection_id,
max_batching_rows=max_batching_rows,
packages=tuple(packages) if packages else (),
)

udf_definition = udf_def.BigqueryUdf(
routine_ref=bigquery.RoutineReference.from_string(full_rf_name),
signature=udf_sig,
)
if (
not name and not _force_deploy
): # session-owned resource - deferred deployment
udf_definition = udf_def.PythonUdf(
signature=udf_sig,
code=code_def,
requirements=requirements,
)
else:
bq_function_name = (
managed_function_client.provision_bq_managed_function(
name=name,
config=config,
)
)
full_rf_name = (
managed_function_client.get_remote_function_fully_qualilfied_name(
bq_function_name
)
)
udf_definition = udf_def.BigqueryUdf(
routine_ref=bigquery.RoutineReference.from_string(full_rf_name),
signature=udf_sig,
)

if udf_sig.is_row_processor:
msg = bfe.format_message("input_types=Series is in preview.")
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)

if not name: # session-owned resource - will be cleaned up automatically
self._update_temp_artifacts(full_rf_name, "")
if _force_deploy:
self._update_temp_artifacts(full_rf_name, "")
return bq_functions.UdfRoutine(func=func, _udf_def=udf_definition)

# user-managed permanent resource - will not be cleaned up automatically
Expand Down Expand Up @@ -888,9 +987,7 @@ def deploy_udf(
A wrapped Python user defined function, usable in
:meth:`~bigframes.series.Series.apply`.
"""
# TODO(tswast): If we update udf to defer deployment, update this method
# to deploy immediately.
return self.udf(**kwargs)(func)
return self.udf(_force_deploy=True, **kwargs)(func)


def _resolve_signature(
Expand Down
8 changes: 4 additions & 4 deletions packages/bigframes/bigframes/functions/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import dataclasses
import logging
from typing import TYPE_CHECKING, Callable, Optional, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Callable, Optional, Protocol, Union, runtime_checkable

import google.api_core.exceptions
from google.cloud import bigquery
Expand Down Expand Up @@ -162,7 +162,7 @@ class Udf(Protocol):
"""

@property
def udf_def(self) -> udf_def.BigqueryUdf: ...
def udf_def(self) -> Union[udf_def.BigqueryUdf, udf_def.PythonUdf]: ...


class BigqueryCallableRoutine:
Expand Down Expand Up @@ -242,11 +242,11 @@ class UdfRoutine:
func: Callable
# Try not to depend on this, bq managed function creation will be deferred later
# And this ref will be replaced with requirements rather to support lazy creation
_udf_def: udf_def.BigqueryUdf
_udf_def: Union[udf_def.BigqueryUdf, udf_def.PythonUdf]

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

@property
def udf_def(self) -> udf_def.BigqueryUdf:
def udf_def(self) -> Union[udf_def.BigqueryUdf, udf_def.PythonUdf]:
return self._udf_def
55 changes: 55 additions & 0 deletions packages/bigframes/bigframes/functions/udf_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,30 @@ def stable_hash(self) -> bytes:
return hash_val.digest()


@dataclasses.dataclass(frozen=True)
class RuntimeRequirements:
container_cpu: Optional[float] = None
container_memory: Optional[str] = None
bq_connection_id: Optional[str] = None
max_batching_rows: Optional[int] = None
packages: tuple[str, ...] = ()

def stable_hash(self) -> bytes:
hash_val = google_crc32c.Checksum()
if self.container_cpu is not None:
hash_val.update(str(self.container_cpu).encode())
if self.container_memory is not None:
hash_val.update(str(self.container_memory).encode())
if self.bq_connection_id is not None:
hash_val.update(str(self.bq_connection_id).encode())
if self.max_batching_rows is not None:
hash_val.update(str(self.max_batching_rows).encode())
if self.packages:
for p in sorted(self.packages):
hash_val.update(p.encode())
return hash_val.digest()


@dataclasses.dataclass(frozen=True)
class BigqueryUdf:
"""
Expand Down Expand Up @@ -398,6 +422,37 @@ def from_routine(
return cls(routine.reference, signature=signature)


@dataclasses.dataclass(frozen=True)
class PythonUdf:
"""
Represents user-requested Python UDF semantics, including the code and runtime requirements.
"""

signature: UdfSignature
code: CodeDef
requirements: RuntimeRequirements = dataclasses.field(
default_factory=RuntimeRequirements
)

def stable_hash(self) -> bytes:
hash_val = google_crc32c.Checksum()
hash_val.update(self.code.stable_hash())
hash_val.update(self.signature.stable_hash())
hash_val.update(self.requirements.stable_hash())
return hash_val.digest()

def to_managed_function_config(self) -> ManagedFunctionConfig:
return ManagedFunctionConfig(
code=self.code,
signature=self.signature,
max_batching_rows=self.requirements.max_batching_rows,
container_cpu=self.requirements.container_cpu,
container_memory=self.requirements.container_memory,
bq_connection_id=self.requirements.bq_connection_id,
capture_references=False,
)


@dataclasses.dataclass(frozen=True)
class CodeDef:
# Produced by cloudpickle, not compatible across python versions
Expand Down
2 changes: 1 addition & 1 deletion packages/bigframes/bigframes/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def udf(
output_type: Optional[type] = None,
dataset: str,
bigquery_connection: Optional[str] = None,
name: str,
name: Optional[str] = None,
packages: Optional[Sequence[str]] = None,
max_batching_rows: Optional[int] = None,
container_cpu: Optional[float] = None,
Expand Down
2 changes: 1 addition & 1 deletion packages/bigframes/bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1958,7 +1958,7 @@ def udf(
output_type: Optional[type] = None,
dataset: str,
bigquery_connection: Optional[str] = None,
name: str,
name: Optional[str] = None,
packages: Optional[Sequence[str]] = None,
max_batching_rows: Optional[int] = None,
container_cpu: Optional[float] = None,
Expand Down
78 changes: 78 additions & 0 deletions packages/bigframes/bigframes/session/bq_caching_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,12 +511,90 @@ def _prepare_plan_simplify(self, plan: nodes.BigFrameNode) -> nodes.BigFrameNode
plan = plan.top_down(rewrite.fold_row_counts)
return plan

async def _deploy_undeployed_udfs(
self, plan: nodes.BigFrameNode
) -> nodes.BigFrameNode:
import dataclasses

import bigframes.core.expression as expression
import bigframes.functions.udf_def as udf_def
import bigframes.operations as ops

undeployed_udfs: list[udf_def.PythonUdf] = []
for node in plan.unique_nodes():
for expr in node._node_expressions:
for sub_expr in expr.walk():
if isinstance(sub_expr, expression.OpExpression):
op = sub_expr.op
if isinstance(
op,
(
ops.RemoteFunctionOp,
ops.BinaryRemoteFunctionOp,
ops.NaryRemoteFunctionOp,
),
):
func_def = op.function_def
if isinstance(func_def, udf_def.PythonUdf):
undeployed_udfs.append(func_def)

if not undeployed_udfs:
return plan

# Deduplicate while preserving order
seen = set()
unique_undeployed_udfs = []
for udf in undeployed_udfs:
if udf not in seen:
seen.add(udf)
unique_undeployed_udfs.append(udf)

session = self.loader._session
deployed_mapping: dict[udf_def.PythonUdf, udf_def.BigqueryUdf] = {}
for udf in unique_undeployed_udfs:
deployed_udf = await asyncio.to_thread(
session._function_session.deploy_undeployed_udf,
session,
udf,
)
deployed_mapping[udf] = deployed_udf
Comment on lines +554 to +560
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

UDFs are currently deployed sequentially. Since each deployment involves network calls to BigQuery and resource provisioning, this can significantly delay query execution when multiple UDFs are used in a single plan. Parallelizing these deployments using asyncio.gather would improve performance.

        # Deploy UDFs in parallel to improve performance
        tasks = [
            asyncio.to_thread(
                session._function_session.deploy_undeployed_udf,
                session,
                udf,
            )
            for udf in unique_undeployed_udfs
        ]
        results = await asyncio.gather(*tasks)
        deployed_mapping = dict(zip(unique_undeployed_udfs, results))


# Now rewrite the plan using bottom_up to substitute the UDF definitions!
def replace_in_expr(expr: expression.Expression) -> expression.Expression:
def replace_step(e: expression.Expression) -> expression.Expression:
if isinstance(e, expression.OpExpression):
op = e.op
if isinstance(
op,
(
ops.RemoteFunctionOp,
ops.BinaryRemoteFunctionOp,
ops.NaryRemoteFunctionOp,
),
):
func_def = op.function_def
if func_def in deployed_mapping:
new_func_def = deployed_mapping[func_def]
new_op = dataclasses.replace(op, function_def=new_func_def)
return dataclasses.replace(e, op=new_op)
return e

return expr.bottom_up(replace_step)

def replace_in_node(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
if hasattr(node, "transform_exprs"):
return node.transform_exprs(replace_in_expr)
return node

return plan.bottom_up(replace_in_node)

async def _prepare_plan_bq_execution(
self,
plan: nodes.BigFrameNode,
compute_options: Optional[ex_spec.BqComputeOptions] = None,
) -> nodes.BigFrameNode:
"""Prepare the plan for BigQuery execution by caching subtrees and uploading large local sources."""
plan = await self._deploy_undeployed_udfs(plan)
if compute_options is not None and compute_options.enable_multi_query_execution:
await self._simplify_with_caching(plan, compute_options=compute_options)
plan = self._prepare_plan_simplify(plan)
Expand Down
Loading
Loading