diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e6f2f7c..782ceca2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and ## Unreleased +- [#820](https://github.com/pytask-dev/pytask/pull/820) fixes collection and node + display for remote `UPath`-backed nodes, while preserving correct handling of local + `file://` and `local://` `UPath`s across platforms. - [#743](https://github.com/pytask-dev/pytask/pull/743) adds the `pytask.lock` lockfile as the primary state backend with a portable format and documentation. When no lockfile exists, pytask reads the legacy SQLite state and writes `pytask.lock`; diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index 109833bc..4a09c9b0 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -42,6 +42,8 @@ from _pytask.outcomes import count_outcomes from _pytask.path import find_case_sensitive_path from _pytask.path import import_path +from _pytask.path import is_non_local_path +from _pytask.path import normalize_local_upath from _pytask.path import shorten_path from _pytask.pluginmanager import hookimpl from _pytask.reports import CollectionReport @@ -455,7 +457,14 @@ def pytask_collect_node( # noqa: C901, PLR0912 node.name = create_name_of_python_node(node_info) return node - if isinstance(node, PPathNode) and not node.path.is_absolute(): + if isinstance(node, PPathNode): + node.path = normalize_local_upath(node.path) + + if ( + isinstance(node, PPathNode) + and not is_non_local_path(node.path) + and not node.path.is_absolute() + ): node.path = path.joinpath(node.path) # ``normpath`` removes ``../`` from the path which is necessary for the casing @@ -487,6 +496,9 @@ def pytask_collect_node( # noqa: C901, PLR0912 node.name = create_name_of_python_node(node_info) return node + if isinstance(node, UPath): # pragma: no cover + node = normalize_local_upath(node) + if isinstance(node, UPath): # pragma: no cover if not node.protocol: node = Path(node) diff --git a/src/_pytask/collect_command.py b/src/_pytask/collect_command.py index bd75eecc..4ad86229 100644 --- a/src/_pytask/collect_command.py +++ b/src/_pytask/collect_command.py @@ -30,6 +30,8 @@ from _pytask.node_protocols import PTaskWithPath from _pytask.outcomes import ExitCode from _pytask.path import find_common_ancestor +from _pytask.path import is_non_local_path +from _pytask.path import normalize_local_upath from _pytask.path import relative_to from _pytask.pluginmanager import hookimpl from _pytask.pluginmanager import storage @@ -125,10 +127,14 @@ def _find_common_ancestor_of_all_nodes( all_paths.append(task.path) if show_nodes: all_paths.extend( - x.path for x in tree_leaves(task.depends_on) if isinstance(x, PPathNode) + normalize_local_upath(x.path) + for x in tree_leaves(task.depends_on) + if isinstance(x, PPathNode) and not is_non_local_path(x.path) ) all_paths.extend( - x.path for x in tree_leaves(task.produces) if isinstance(x, PPathNode) + normalize_local_upath(x.path) + for x in tree_leaves(task.produces) + if isinstance(x, PPathNode) and not is_non_local_path(x.path) ) return find_common_ancestor(*all_paths, *paths) diff --git a/src/_pytask/path.py b/src/_pytask/path.py index a70c16d5..5619c04f 100644 --- a/src/_pytask/path.py +++ b/src/_pytask/path.py @@ -13,6 +13,8 @@ from types import ModuleType from typing import TYPE_CHECKING +from upath import UPath + from _pytask._hashlib import file_digest from _pytask.cache import Cache @@ -25,11 +27,17 @@ "find_common_ancestor", "hash_path", "import_path", + "is_non_local_path", + "normalize_local_upath", "relative_to", "shorten_path", ] +_LOCAL_UPATH_PROTOCOLS = frozenset(("", "file", "local")) +_WINDOWS_DRIVE_PREFIX_LENGTH = 3 + + def relative_to(path: Path, source: Path, *, include_source: bool = True) -> Path: """Make a path relative to another path. @@ -56,6 +64,27 @@ def relative_to(path: Path, source: Path, *, include_source: bool = True) -> Pat return Path(source_name, path.relative_to(source)) +def is_non_local_path(path: Path) -> bool: + """Return whether a path points to a non-local `UPath` resource.""" + return isinstance(path, UPath) and path.protocol not in _LOCAL_UPATH_PROTOCOLS + + +def normalize_local_upath(path: Path) -> Path: + """Convert local `UPath` variants to a stdlib `Path`.""" + if isinstance(path, UPath) and path.protocol in {"file", "local"}: + local_path = path.path + if ( + sys.platform == "win32" + and local_path.startswith("/") + and len(local_path) >= _WINDOWS_DRIVE_PREFIX_LENGTH + and local_path[1].isalpha() + and local_path[2] == ":" + ): + local_path = local_path[1:] + return Path(local_path) + return path + + def find_closest_ancestor( path: Path, potential_ancestors: Sequence[Path] ) -> Path | None: @@ -432,6 +461,12 @@ def shorten_path(path: Path, paths: Sequence[Path]) -> str: path from one path in ``session.config["paths"]`` to the node. """ + if is_non_local_path(path): + return path.as_posix() + + path = normalize_local_upath(path) + paths = [normalize_local_upath(p) for p in paths] + ancestor = find_closest_ancestor(path, paths) if ancestor is None: try: diff --git a/tests/test_collect.py b/tests/test_collect.py index 7e08c2d4..29f80c59 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -14,6 +14,7 @@ from pytask import CollectionOutcome from pytask import ExitCode from pytask import NodeInfo +from pytask import PickleNode from pytask import Session from pytask import Task from pytask import build @@ -21,6 +22,10 @@ from tests.conftest import noop +def _make_local_upath_uri(path: Path, protocol: str) -> str: + return f"{protocol}:///{path.as_posix().lstrip('/')}" + + @pytest.mark.parametrize( ("depends_on", "produces"), [ @@ -190,6 +195,56 @@ def test_pytask_collect_node(session, path, node_info, expected): assert str(result.load()) == str(expected) +def test_pytask_collect_remote_path_node_keeps_uri_name(): + upath = pytest.importorskip("upath") + + session = Session.from_config( + {"check_casing_of_paths": False, "paths": (Path.cwd(),), "root": Path.cwd()} + ) + + result = pytask_collect_node( + session, + Path.cwd(), + NodeInfo( + arg_name="path", + path=(), + value=PickleNode(path=upath.UPath("s3://bucket/file.pkl")), + task_path=Path.cwd() / "task_example.py", + task_name="task_example", + ), + ) + + assert isinstance(result, PPathNode) + assert result.name == "s3://bucket/file.pkl" + + +@pytest.mark.parametrize("protocol", ["file", "local"]) +def test_pytask_collect_local_upath_protocol_node_is_shortened(tmp_path, protocol): + upath = pytest.importorskip("upath") + + session = Session.from_config( + {"check_casing_of_paths": False, "paths": (tmp_path,), "root": tmp_path} + ) + + result = pytask_collect_node( + session, + tmp_path, + NodeInfo( + arg_name="path", + path=(), + value=PickleNode( + path=upath.UPath(_make_local_upath_uri(tmp_path / "file.pkl", protocol)) + ), + task_path=tmp_path / "task_example.py", + task_name="task_example", + ), + ) + + assert isinstance(result, PPathNode) + assert result.path == tmp_path / "file.pkl" + assert result.name == f"{tmp_path.name}/file.pkl" + + @pytest.mark.skipif( sys.platform != "win32", reason="Only works on case-insensitive file systems." ) diff --git a/tests/test_collect_command.py b/tests/test_collect_command.py index 532d6b97..f84cc98c 100644 --- a/tests/test_collect_command.py +++ b/tests/test_collect_command.py @@ -21,6 +21,10 @@ from _pytask.node_protocols import PTaskWithPath +def _make_local_upath_uri(path: Path, protocol: str) -> str: + return f"{protocol}:///{path.as_posix().lstrip('/')}" + + def test_collect_task(runner, tmp_path): source = """ from pathlib import Path @@ -396,6 +400,59 @@ def test_task_name_is_shortened(runner, tmp_path): assert "a/b/task_example.py::task_example" not in result.output +def test_collect_task_with_remote_upath_node(runner, tmp_path): + pytest.importorskip("upath") + + source = """ + from pathlib import Path + from typing import Annotated + + from upath import UPath + + from pytask import PickleNode + from pytask import Product + + def task_example( + data=PickleNode(path=UPath("s3://bucket/in.pkl")), + path: Annotated[Path, Product] = Path("out.txt"), + ): ... + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + + assert result.exit_code == ExitCode.OK + assert "s3://bucket/in.pkl" in result.output + + +@pytest.mark.parametrize("protocol", ["file", "local"]) +def test_collect_task_with_local_upath_protocol_node(runner, tmp_path, protocol): + pytest.importorskip("upath") + + uri = _make_local_upath_uri(tmp_path / "in.pkl", protocol) + + source = f""" + from pathlib import Path + from typing import Annotated + + from upath import UPath + + from pytask import PickleNode + from pytask import Product + + def task_example( + data=PickleNode(path=UPath("{uri}")), + path: Annotated[Path, Product] = Path("out.txt"), + ): ... + """ + tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()]) + + assert result.exit_code == ExitCode.OK + assert f"{tmp_path.name}/in.pkl" in result.output + + def test_python_node_is_collected(runner, tmp_path): source = """ from pytask import Product diff --git a/tests/test_path.py b/tests/test_path.py index 4d72932f..ba490149 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -19,13 +19,20 @@ from _pytask.path import find_case_sensitive_path from _pytask.path import find_closest_ancestor from _pytask.path import find_common_ancestor +from _pytask.path import is_non_local_path +from _pytask.path import normalize_local_upath from _pytask.path import relative_to +from _pytask.path import shorten_path from pytask.path import import_path if TYPE_CHECKING: from collections.abc import Generator +def _make_local_upath_uri(path: Path, protocol: str) -> str: + return f"{protocol}:///{path.as_posix().lstrip('/')}" + + @pytest.mark.parametrize( ("path", "source", "include_source", "expected"), [ @@ -110,6 +117,36 @@ def test_find_common_ancestor(path_1, path_2, expectation, expected): assert result == expected +def test_shorten_path_keeps_non_local_uri(): + upath = pytest.importorskip("upath") + + path = upath.UPath("s3://bucket/file.pkl") + + assert shorten_path(path, [Path.cwd()]) == "s3://bucket/file.pkl" + + +@pytest.mark.parametrize("protocol", ["file", "local"]) +def test_shorten_path_treats_local_upath_protocols_as_local(tmp_path, protocol): + upath = pytest.importorskip("upath") + + path = upath.UPath(_make_local_upath_uri(tmp_path / "file.pkl", protocol)) + + assert not is_non_local_path(path) + assert shorten_path(path, [tmp_path]) == f"{tmp_path.name}/file.pkl" + + +@pytest.mark.parametrize("protocol", ["file", "local"]) +def test_normalize_local_upath_strips_windows_drive_prefix(monkeypatch, protocol): + upath = pytest.importorskip("upath") + + monkeypatch.setattr(sys, "platform", "win32") + path = upath.UPath(f"{protocol}:///C:/tmp/file.pkl") + + result = normalize_local_upath(path) + + assert result.as_posix() == "C:/tmp/file.pkl" + + @pytest.mark.skipif(sys.platform != "win32", reason="Only works on Windows.") @pytest.mark.parametrize( ("path", "existing_paths", "expected"),