Skip to content

Commit 07f259b

Browse files
authored
Merge pull request #3 from CERBSim/fix_webgpu_in_vscode
fixes for running webgpu in vscode notebooks
2 parents aa39131 + 49f2945 commit 07f259b

2 files changed

Lines changed: 73 additions & 19 deletions

File tree

webgpu/jupyter.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ def _draw_scene(scene: Scene, width, height, id_):
9191
html_canvas.height = height
9292
gui_element = platform.js.document.getElementById(f"{id_}lilgui")
9393

94-
canvas = Canvas(utils.get_device(), html_canvas)
94+
# Lazily initialize the WebGPU device the first time we draw.
95+
canvas = Canvas(init_device_sync(), html_canvas)
9596
scene.gui = LilGUI(gui_element, scene)
9697
scene.init(canvas)
9798
scene.render()
@@ -136,7 +137,13 @@ def Draw(
136137
height = height if height is not None else 640
137138

138139
scene, id_ = _init_html(scene, width, height, flex)
139-
_draw_scene(scene, width, height, id_)
140+
141+
# In classic Jupyter we already have a websocket connection at import
142+
# time, so this callback runs immediately. In VS Code, outputs are only
143+
# processed once the cell has finished executing; using execute_when_init
144+
# ensures that drawing happens once the websocket connection is ready
145+
# instead of blocking the import.
146+
platform.execute_when_init(lambda js: _draw_scene(scene, width, height, id_))
140147
return scene
141148

142149

@@ -178,12 +185,26 @@ def Draw(
178185
js_code += f"\nwindow.pyodide_ready = init_pyodide('{webgpu_module_b64}');"
179186
display(Javascript(js_code))
180187
else:
181-
# Not exporting and not running in pyodide -> Start a websocket server and wait for the client to connect
188+
# Not exporting and not running in pyodide -> Start a websocket server
189+
# and wait for the client to connect.
190+
#
191+
# In VS Code notebooks, outputs are typically only processed once the
192+
# cell has completed execution. If we were to block here waiting for
193+
# the websocket connection, the JavaScript that establishes the
194+
# connection would never run, leading to a deadlock. We therefore
195+
# avoid blocking on the connection in that environment and instead
196+
# defer drawing until the link is ready via execute_when_init.
197+
198+
def _webgpu_js(server):
199+
js = _link_js_code + """
200+
const __is_vscode = (typeof location !== 'undefined' && location.protocol === 'vscode-webview:');
201+
const __webgpu_host = __is_vscode ? '127.0.0.1' : ((typeof location !== 'undefined' && location.hostname) || '127.0.0.1');
202+
WebsocketLink('ws://' + __webgpu_host + ':{port}');
203+
""".format(port=server.port)
204+
display(Javascript(js))
205+
206+
is_vscode = "VSCODE_PID" in os.environ
182207
platform.init(
183-
before_wait_for_connection=lambda server: display(
184-
Javascript(
185-
_link_js_code + f"WebsocketLink('ws://'+location.hostname+':{server.port}');"
186-
)
187-
)
208+
before_wait_for_connection=_webgpu_js,
209+
block_on_connection=not is_vscode,
188210
)
189-
device = init_device_sync()

webgpu/platform.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
from collections.abc import Mapping
11+
import threading
1112

1213
is_pyodide = False
1314
is_pyodide_main_thread = False
@@ -127,16 +128,37 @@ def _serialize_jsproxy(link, value):
127128

128129

129130
def execute_when_init(func):
131+
"""Register a callback to run once the JS side is ready.
132+
133+
If the platform has already been initialized, the callback is executed
134+
immediately. Otherwise it is queued and executed from ``init`` once the
135+
websocket connection has been established and ``js`` is set.
136+
"""
137+
130138
if js is not None:
131139
func(js)
132140
else:
133141
_funcs_after_init.append(func)
134142

135143

136-
def init(before_wait_for_connection=None):
144+
def init(before_wait_for_connection=None, block_on_connection: bool = True):
145+
"""Initialize the websocket link to the browser.
146+
147+
In the default (classic Jupyter) mode, this blocks until the browser has
148+
connected via websocket so that ``js`` is ready to use.
149+
150+
In environments like VS Code notebooks, outputs are typically only
151+
processed once the cell has finished executing. In that situation calling
152+
``init`` with ``block_on_connection=False`` avoids a deadlock by moving the
153+
blocking ``wait_for_connection`` part to a background thread. Code that
154+
depends on ``js`` should use :func:`execute_when_init` so it runs once the
155+
connection is ready.
156+
"""
157+
137158
global js, create_proxy, destroy_proxy, websocket_server, link
138159
if is_pyodide or js is not None:
139160
return
161+
140162
websocket_server = WebsocketLinkServer()
141163
create_proxy = websocket_server.create_proxy
142164
destroy_proxy = websocket_server.destroy_proxy
@@ -147,19 +169,30 @@ def init(before_wait_for_connection=None):
147169
if before_wait_for_connection:
148170
before_wait_for_connection(websocket_server)
149171

150-
websocket_server.wait_for_connection()
151-
js = websocket_server.get(None, None)
152-
153172
from .link.base import LinkBase
154173
from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject
155174

156-
LinkBase.register_serializer(BaseWebGPUHandle, lambda _, v: v.handle)
157-
LinkBase.register_serializer(BaseWebGPUObject, lambda _, v: v.__dict__ or None)
175+
def _finish_init():
176+
websocket_server.wait_for_connection()
177+
js_local = websocket_server.get(None, None)
158178

159-
websocket_server._start_handling_messages.set()
160-
for func in _funcs_after_init:
161-
func(js)
162-
_funcs_after_init.clear()
179+
LinkBase.register_serializer(BaseWebGPUHandle, lambda _, v: v.handle)
180+
LinkBase.register_serializer(
181+
BaseWebGPUObject, lambda _, v: v.__dict__ or None
182+
)
183+
184+
# Publish js and run any deferred callbacks.
185+
globals()["js"] = js_local
186+
websocket_server._start_handling_messages.set()
187+
for func in _funcs_after_init:
188+
func(js_local)
189+
_funcs_after_init.clear()
190+
191+
if block_on_connection:
192+
_finish_init()
193+
else:
194+
thread = threading.Thread(target=_finish_init, daemon=True)
195+
thread.start()
163196

164197

165198
def init_pyodide(link_):

0 commit comments

Comments
 (0)