Skip to content

Commit daaed67

Browse files
refactor(bigframes): make executor internals async (#17093)
1 parent: b6bb63e; commit: daaed67

22 files changed

Lines changed: 419 additions & 213 deletions

packages/bigframes/bigframes/core/events.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,8 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
18+
import concurrent.futures
1719
import dataclasses
1820
import datetime
1921
import threading
@@ -68,6 +70,9 @@ class Publisher:
6870
def __init__(self):
6971
self._subscribers_lock = threading.Lock()
7072
self._subscribers: Set[Subscriber] = set()
73+
self._executor: concurrent.futures.Executor = (
74+
concurrent.futures.ThreadPoolExecutor()
75+
)
7176

7277
def subscribe(self, callback: Callable[[Event], None]) -> Subscriber:
7378
# TODO(b/448176657): figure out how to handle subscribers/publishers in
@@ -86,6 +91,16 @@ def publish(self, event: Event):
8691
for subscriber in self._subscribers:
8792
subscriber(event)
8893

94+
async def publish_async(self, event: Event):
95+
with self._subscribers_lock:
96+
subscribers_snapshot = list(self._subscribers)
97+
loop = asyncio.get_running_loop()
98+
tasks = [
99+
loop.run_in_executor(self._executor, subscriber, event)
100+
for subscriber in subscribers_snapshot
101+
]
102+
return await asyncio.gather(*tasks, return_exceptions=True)
103+
89104

90105
class Event:
91106
pass

packages/bigframes/bigframes/core/indexes/base.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -293,7 +293,8 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
293293

294294
count_scalar = (
295295
self._block.session._executor.execute(
296-
count_result, ex_spec.ExecutionSpec(promise_under_10gb=True)
296+
count_result,
297+
ex_spec.ExecutionSpec(promise_under_10gb=True),
297298
)
298299
.batches()
299300
.to_py_scalar()
@@ -308,7 +309,8 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
308309
position_result = filtered_block._expr.aggregate([(min_agg, "position")])
309310
position_scalar = (
310311
self._block.session._executor.execute(
311-
position_result, ex_spec.ExecutionSpec(promise_under_10gb=True)
312+
position_result,
313+
ex_spec.ExecutionSpec(promise_under_10gb=True),
312314
)
313315
.batches()
314316
.to_py_scalar()

packages/bigframes/bigframes/core/local_data.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,11 @@ def iter_array(
248248
elif dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
249249
if duration_type == "int":
250250
yield from map(
251-
lambda x: ((x.days * 3600 * 24) + x.seconds) * 1_000_000
252-
+ x.microseconds
253-
if x is not None
254-
else x,
251+
lambda x: (
252+
((x.days * 3600 * 24) + x.seconds) * 1_000_000 + x.microseconds
253+
if x is not None
254+
else x
255+
),
255256
values,
256257
)
257258
else:

packages/bigframes/bigframes/formatting_helpers.py

Lines changed: 68 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import google.cloud.bigquery as bigquery
2727
import humanize
2828

29+
import bigframes._config
30+
2931
if TYPE_CHECKING:
3032
import bigframes.core.events
3133

@@ -137,78 +139,85 @@ def repr_query_job_html(query_job: Optional[bigquery.QueryJob]):
137139
current_display_id: Optional[str] = None
138140

139141

140-
def progress_callback(
141-
event: bigframes.core.events.Event,
142-
):
143-
"""Displays a progress bar while the query is running"""
144-
global current_display_id
142+
def create_progress_callback():
143+
# bind potentially thread-local config to the callback so that it uses the user thread
144+
# config even if callback is invoked from a worker thread.
145+
display_opts = bigframes._config.options.display
145146

146-
try:
147-
import bigframes._config
148-
import bigframes.core.events
149-
except ImportError:
150-
# Since this gets called from __del__, skip if the import fails to avoid
151-
# ImportError: sys.meta_path is None, Python is likely shutting down.
152-
# This will allow cleanup to continue.
153-
return
147+
def progress_callback(
148+
event: bigframes.core.events.Event,
149+
):
150+
"""Displays a progress bar while the query is running"""
151+
global current_display_id
154152

155-
progress_bar = bigframes._config.options.display.progress_bar
153+
try:
154+
import bigframes._config
155+
import bigframes.core.events
156+
except ImportError:
157+
# Since this gets called from __del__, skip if the import fails to avoid
158+
# ImportError: sys.meta_path is None, Python is likely shutting down.
159+
# This will allow cleanup to continue.
160+
return
156161

157-
if progress_bar == "auto":
158-
progress_bar = "notebook" if in_ipython() else "terminal"
162+
progress_bar = display_opts.progress_bar
163+
164+
if progress_bar == "auto":
165+
progress_bar = "notebook" if in_ipython() else "terminal"
159166

160-
if progress_bar == "notebook":
161-
import IPython.display as display
167+
if progress_bar == "notebook":
168+
import IPython.display as display
162169

163-
display_html = None
170+
display_html = None
164171

165-
if isinstance(event, bigframes.core.events.ExecutionStarted):
166-
# Start a new context for progress output.
167-
current_display_id = None
172+
if isinstance(event, bigframes.core.events.ExecutionStarted):
173+
# Start a new context for progress output.
174+
current_display_id = None
168175

169-
elif isinstance(event, bigframes.core.events.BigQuerySentEvent):
170-
display_html = render_bqquery_sent_event_html(event)
176+
elif isinstance(event, bigframes.core.events.BigQuerySentEvent):
177+
display_html = render_bqquery_sent_event_html(event)
171178

172-
elif isinstance(event, bigframes.core.events.BigQueryRetryEvent):
173-
display_html = render_bqquery_retry_event_html(event)
179+
elif isinstance(event, bigframes.core.events.BigQueryRetryEvent):
180+
display_html = render_bqquery_retry_event_html(event)
174181

175-
elif isinstance(event, bigframes.core.events.BigQueryReceivedEvent):
176-
display_html = render_bqquery_received_event_html(event)
182+
elif isinstance(event, bigframes.core.events.BigQueryReceivedEvent):
183+
display_html = render_bqquery_received_event_html(event)
177184

178-
elif isinstance(event, bigframes.core.events.BigQueryFinishedEvent):
179-
display_html = render_bqquery_finished_event_html(event)
185+
elif isinstance(event, bigframes.core.events.BigQueryFinishedEvent):
186+
display_html = render_bqquery_finished_event_html(event)
180187

181-
elif isinstance(event, bigframes.core.events.SessionClosed):
182-
display_html = f"Session {event.session_id} closed."
188+
elif isinstance(event, bigframes.core.events.SessionClosed):
189+
display_html = f"Session {event.session_id} closed."
183190

184-
if display_html:
185-
if current_display_id:
186-
display.update_display(
187-
display.HTML(display_html),
188-
display_id=current_display_id,
189-
)
190-
else:
191-
current_display_id = str(random.random())
192-
display.display(
193-
display.HTML(display_html),
194-
display_id=current_display_id,
195-
)
191+
if display_html:
192+
if current_display_id:
193+
display.update_display(
194+
display.HTML(display_html),
195+
display_id=current_display_id,
196+
)
197+
else:
198+
current_display_id = str(random.random())
199+
display.display(
200+
display.HTML(display_html),
201+
display_id=current_display_id,
202+
)
196203

197-
elif progress_bar == "terminal":
198-
message = None
199-
200-
if isinstance(event, bigframes.core.events.BigQuerySentEvent):
201-
message = render_bqquery_sent_event_plaintext(event)
202-
print(message)
203-
elif isinstance(event, bigframes.core.events.BigQueryRetryEvent):
204-
message = render_bqquery_retry_event_plaintext(event)
205-
print(message)
206-
elif isinstance(event, bigframes.core.events.BigQueryReceivedEvent):
207-
message = render_bqquery_received_event_plaintext(event)
208-
print(message)
209-
elif isinstance(event, bigframes.core.events.BigQueryFinishedEvent):
210-
message = render_bqquery_finished_event_plaintext(event)
211-
print(message)
204+
elif progress_bar == "terminal":
205+
message = None
206+
207+
if isinstance(event, bigframes.core.events.BigQuerySentEvent):
208+
message = render_bqquery_sent_event_plaintext(event)
209+
print(message)
210+
elif isinstance(event, bigframes.core.events.BigQueryRetryEvent):
211+
message = render_bqquery_retry_event_plaintext(event)
212+
print(message)
213+
elif isinstance(event, bigframes.core.events.BigQueryReceivedEvent):
214+
message = render_bqquery_received_event_plaintext(event)
215+
print(message)
216+
elif isinstance(event, bigframes.core.events.BigQueryFinishedEvent):
217+
message = render_bqquery_finished_event_plaintext(event)
218+
print(message)
219+
220+
return progress_callback
212221

213222

214223
def wait_for_job(job: GenericJob, progress_bar: Optional[str] = None):

packages/bigframes/bigframes/operations/python_op_maps.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@
7171
abs: numeric_ops.abs_op,
7272
pow: numeric_ops.pow_op,
7373
### builtins -- iterable
74-
all: array_ops.ArrayReduceOp(aggregations.all_op),
75-
any: array_ops.ArrayReduceOp(aggregations.any_op),
76-
sum: array_ops.ArrayReduceOp(aggregations.sum_op),
77-
min: array_ops.ArrayReduceOp(aggregations.min_op),
78-
max: array_ops.ArrayReduceOp(aggregations.max_op),
74+
all: array_ops.ArrayReduceOp(aggregations.all_op), # type: ignore
75+
any: array_ops.ArrayReduceOp(aggregations.any_op), # type: ignore
76+
sum: array_ops.ArrayReduceOp(aggregations.sum_op), # type: ignore
77+
min: array_ops.ArrayReduceOp(aggregations.min_op), # type: ignore
78+
max: array_ops.ArrayReduceOp(aggregations.max_op), # type: ignore
7979
}
8080

8181

packages/bigframes/bigframes/session/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,9 @@ def __init__(
183183
# Publisher needs to be created before the other objects, especially
184184
# the executors, because they access it.
185185
self._publisher = bigframes.core.events.Publisher()
186-
self._publisher.subscribe(bigframes.formatting_helpers.progress_callback)
186+
self._publisher.subscribe(
187+
bigframes.formatting_helpers.create_progress_callback()
188+
)
187189

188190
if context is None:
189191
context = bigquery_options.BigQueryOptions()

packages/bigframes/bigframes/session/_io/bigquery/__init__.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,6 @@ def create_job_configs_labels(
5555
if job_configs_labels is None:
5656
job_configs_labels = {}
5757

58-
# If the user has labels they wish to set, make sure we set those first so
59-
# they are preserved.
60-
for key, value in bigframes.options.compute.extra_query_labels.items():
61-
job_configs_labels[key] = value
62-
6358
if api_methods and "bigframes-api" not in job_configs_labels:
6459
job_configs_labels["bigframes-api"] = api_methods[0]
6560
del api_methods[0]
@@ -230,7 +225,9 @@ def format_option(key: str, value: Union[bool, str]) -> str:
230225
return f"{key}={repr(value)}"
231226

232227

233-
def add_and_trim_labels(job_config, session=None):
228+
def add_and_trim_labels(
229+
job_config, session=None, extra_query_labels: Optional[Mapping[str, str]] = None
230+
):
234231
"""
235232
Add additional labels to the job configuration and trim the total number of labels
236233
to ensure they do not exceed MAX_LABELS_COUNT labels per job.

packages/bigframes/bigframes/session/bigquery_session.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,11 @@ def close(self):
119119
bfbqio.start_query_job_optional(
120120
self.bqclient,
121121
f"CALL BQ.ABORT_SESSION('{self._session_id}')",
122-
job_config=bigquery.QueryJobConfig(),
122+
# Assume this is being called in the user thread, so we can access
123+
# this thread-local config.
124+
job_config=bigquery.QueryJobConfig(
125+
labels=bigframes.options.compute.extra_query_labels
126+
),
123127
location=self.location,
124128
project=None,
125129
timeout=None,

0 commit comments

Comments (0)