Skip to content

Commit ad193e2

Browse files
Fix pickling errors in parallel workers (#137)
* Fix pickling errors by importing task modules in workers
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Document fix for pickling errors
* Fix typing, lint, and docs checks
* Allow global worker root with ruff ignore
* Add regression test for mark import loop
* Strip annotation locals from task
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Release 0.5.2

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f168e73 commit ad193e2

File tree

9 files changed

+188
-146
lines changed

9 files changed

+188
-146
lines changed

docs/source/changes.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
55
releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
66
[Anaconda.org](https://anaconda.org/conda-forge/pytask-parallel).
77

8-
## Unreleased
8+
## 0.5.2 - 2026-02-05
99

1010
- {pull}`129` drops support for Python 3.8 and 3.9 and adds support for Python 3.14.
1111
- {pull}`130` switches type checking to ty.
1212
- {pull}`131` updates pre-commit hooks.
1313
- {pull}`132` removes the tox configuration in favor of uv and just.
14+
- {pull}`137` fixes pickling errors in parallel workers when task modules contain
15+
non-picklable globals. Fixes {issue}`136`.
1416

1517
## 0.5.1 - 2025-03-09
1618

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ docs = [
3333
"matplotlib",
3434
"myst-parser",
3535
"nbsphinx",
36-
"sphinx",
36+
"sphinx<9",
3737
"sphinx-autobuild",
3838
"sphinx-click",
3939
"sphinx-copybutton",

requirements-dev.lock

Lines changed: 0 additions & 86 deletions
This file was deleted.

requirements.lock

Lines changed: 0 additions & 51 deletions
This file was deleted.

src/pytask_parallel/backends.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import os
6+
import sys
57
import warnings
68
from concurrent.futures import Executor
79
from concurrent.futures import Future
@@ -19,7 +21,46 @@
1921
if TYPE_CHECKING:
2022
from collections.abc import Callable
2123

22-
__all__ = ["ParallelBackend", "ParallelBackendRegistry", "WorkerType", "registry"]
24+
__all__ = [
25+
"ParallelBackend",
26+
"ParallelBackendRegistry",
27+
"WorkerType",
28+
"registry",
29+
"set_worker_root",
30+
]
31+
32+
_WORKER_ROOT: str | None = None
33+
34+
35+
def set_worker_root(path: os.PathLike[str] | str) -> None:
    """Configure the root path for worker processes.

    Spawned workers (notably on Windows) start with a clean interpreter and may not
    inherit the parent's import path. The root is therefore recorded in the
    module-level ``_WORKER_ROOT`` and prepended to both ``sys.path`` and
    ``PYTHONPATH`` so task modules are importable by reference, which avoids pickling
    module globals.

    The path is made absolute before it is stored. A relative root would otherwise
    be re-interpreted against whatever working directory a process has at the time
    the entry is used, which no longer matches the caller's directory once a worker
    changes into the root.

    ``path`` is the project root, given as a string or any ``os.PathLike`` object.
    Calling the function repeatedly with the same root does not add duplicate
    entries.

    """
    global _WORKER_ROOT  # noqa: PLW0603
    root = os.path.abspath(os.fspath(path))
    _WORKER_ROOT = root

    if root not in sys.path:
        sys.path.insert(0, root)

    # Ensure custom process backends can import task modules by reference. Empty
    # segments (unset variable, doubled separators) are ignored when checking
    # whether the root is already listed.
    entries = [p for p in os.environ.get("PYTHONPATH", "").split(os.pathsep) if p]
    if root not in entries:
        os.environ["PYTHONPATH"] = os.pathsep.join([root, *entries])
55+
56+
57+
def _configure_worker(root: str | None) -> None:
58+
"""Set cwd and sys.path for worker processes."""
59+
if not root:
60+
return
61+
os.chdir(root)
62+
if root not in sys.path:
63+
sys.path.insert(0, root)
2364

2465

2566
def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
@@ -75,12 +116,20 @@ def _get_dask_executor(n_workers: int) -> Executor:
75116

76117
def _get_loky_executor(n_workers: int) -> Executor:
77118
"""Get a loky executor."""
78-
return get_reusable_executor(max_workers=n_workers)
119+
return get_reusable_executor(
120+
max_workers=n_workers,
121+
initializer=_configure_worker,
122+
initargs=(_WORKER_ROOT,),
123+
)
79124

80125

81126
def _get_process_pool_executor(n_workers: int) -> Executor:
82127
"""Get a process pool executor."""
83-
return _CloudpickleProcessPoolExecutor(max_workers=n_workers)
128+
return _CloudpickleProcessPoolExecutor(
129+
max_workers=n_workers,
130+
initializer=_configure_worker,
131+
initargs=(_WORKER_ROOT,),
132+
)
84133

85134

86135
def _get_thread_pool_executor(n_workers: int) -> Executor:

src/pytask_parallel/execute.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,14 @@
2626

2727
from pytask_parallel.backends import WorkerType
2828
from pytask_parallel.backends import registry
29+
from pytask_parallel.backends import set_worker_root
2930
from pytask_parallel.typing import CarryOverPath
3031
from pytask_parallel.typing import is_coiled_function
3132
from pytask_parallel.utils import create_kwargs_for_task
3233
from pytask_parallel.utils import get_module
3334
from pytask_parallel.utils import parse_future_result
35+
from pytask_parallel.utils import should_pickle_module_by_value
36+
from pytask_parallel.utils import strip_annotation_locals
3437

3538
if TYPE_CHECKING:
3639
from concurrent.futures import Future
@@ -57,6 +60,7 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
5760

5861
# The executor can only be created after the collection to give users the
5962
# possibility to inject their own executors.
63+
set_worker_root(session.config["root"])
6064
session.config["_parallel_executor"] = registry.get_parallel_backend(
6165
session.config["parallel_backend"], n_workers=session.config["n_workers"]
6266
)
@@ -195,6 +199,7 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
195199
kwargs = create_kwargs_for_task(task, remote=remote)
196200

197201
if is_coiled_function(task):
202+
strip_annotation_locals(task)
198203
# Prevent circular import for coiled backend.
199204
from pytask_parallel.wrappers import ( # noqa: PLC0415
200205
rewrap_task_with_coiled_function,
@@ -208,7 +213,8 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
208213
# cloudpickle will pickle it with the function. See cloudpickle#417, pytask#373
209214
# and pytask#374.
210215
task_module = get_module(task.function, getattr(task, "path", None))
211-
cloudpickle.register_pickle_by_value(task_module)
216+
if should_pickle_module_by_value(task_module):
217+
cloudpickle.register_pickle_by_value(task_module)
212218

213219
return cast("Any", wrapper_func).submit(
214220
task=task,
@@ -221,6 +227,7 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
221227
)
222228

223229
if worker_type == WorkerType.PROCESSES:
230+
strip_annotation_locals(task)
224231
# Prevent circular import for loky backend.
225232
from pytask_parallel.wrappers import wrap_task_in_process # noqa: PLC0415
226233

@@ -230,7 +237,8 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
230237
# cloudpickle will pickle it with the function. See cloudpickle#417, pytask#373
231238
# and pytask#374.
232239
task_module = get_module(task.function, getattr(task, "path", None))
233-
cloudpickle.register_pickle_by_value(task_module)
240+
if should_pickle_module_by_value(task_module):
241+
cloudpickle.register_pickle_by_value(task_module)
234242

235243
return session.config["_parallel_executor"].submit(
236244
wrap_task_in_process,

src/pytask_parallel/utils.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from __future__ import annotations
44

5+
import importlib.util
56
import inspect
67
from functools import partial
8+
from pathlib import Path
79
from typing import TYPE_CHECKING
810
from typing import Any
911

@@ -20,7 +22,6 @@
2022
if TYPE_CHECKING:
2123
from collections.abc import Callable
2224
from concurrent.futures import Future
23-
from pathlib import Path
2425
from types import ModuleType
2526
from types import TracebackType
2627

@@ -39,6 +40,8 @@ class CoiledFunction: ...
3940
"create_kwargs_for_task",
4041
"get_module",
4142
"parse_future_result",
43+
"should_pickle_module_by_value",
44+
"strip_annotation_locals",
4245
]
4346

4447

@@ -150,3 +153,43 @@ def get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
150153
if path:
151154
return inspect.getmodule(func, path.as_posix()) # type: ignore[return-value]
152155
return inspect.getmodule(func) # type: ignore[return-value]
156+
157+
158+
def strip_annotation_locals(task: PTask) -> None:
    """Drop the stored annotation locals of a task function ahead of pickling.

    The snapshot of local variables is only required while annotations are
    evaluated during collection. If it stays attached for execution, pickling can
    fail whenever the snapshot holds objects that cannot be serialized (for
    instance when ``pytask.mark`` is imported in tasks generated in a loop).

    The function looks for a ``pytask_meta`` attribute on ``task.function``. If that
    object carries a non-``None`` ``annotation_locals`` attribute, the attribute is
    reset to ``None`` in place. Tasks without metadata, or without a snapshot, are
    left untouched.

    """
    metadata = getattr(task.function, "pytask_meta", None)
    if metadata is None:
        return
    if getattr(metadata, "annotation_locals", None) is None:
        # Nothing stored (or attribute absent): do not create or touch it.
        return
    metadata.annotation_locals = None
169+
170+
171+
def should_pickle_module_by_value(module: ModuleType) -> bool:
    """Decide whether a module has to be pickled by value.

    Pickling by value is reserved for modules that cannot be imported by name.
    Everything else is pickled by reference, which avoids serializing all module
    globals and the failures that come with non-picklable objects stored at module
    scope (e.g., closed file handles or locks).

    Returns ``True`` (pickle by value) when

    - the module has no name, is ``__main__``, or has no ``__file__``,
    - looking up its import spec fails or yields no spec or no origin,
    - the spec's origin resolves to a different file than the module's own
      ``__file__``, or the paths cannot be resolved.

    Returns ``False`` only when the name resolves to the very file the module was
    loaded from.

    """
    name = getattr(module, "__name__", None)
    file = getattr(module, "__file__", None)

    has_importable_identity = bool(name) and name != "__main__" and file is not None
    if not has_importable_identity:
        return True

    try:
        spec = importlib.util.find_spec(name)
    except (ImportError, ValueError, AttributeError):
        return True
    else:
        origin = spec.origin if spec is not None else None

    if origin is None:
        return True

    try:
        same_file = Path(origin).resolve() == Path(file).resolve()
    except OSError:
        return True
    return not same_file

src/pytask_parallel/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def _render_traceback_to_string(
217217
traceback = Traceback(exc_info, show_locals=show_locals)
218218
segments = console.render(cast("Any", traceback), options=console_options)
219219
text = "".join(segment.text for segment in segments)
220-
return (*exc_info[:2], text) # ty: ignore[invalid-return-type]
220+
return (*exc_info[:2], text)
221221

222222

223223
def _handle_function_products(

0 commit comments

Comments
 (0)