Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and

## Unreleased

- [#820](https://github.com/pytask-dev/pytask/pull/820) fixes collection and node
display for remote `UPath`-backed nodes, while preserving correct handling of local
`file://` and `local://` `UPath`s across platforms.
- [#743](https://github.com/pytask-dev/pytask/pull/743) adds the `pytask.lock`
lockfile as the primary state backend with a portable format and documentation. When
no lockfile exists, pytask reads the legacy SQLite state and writes `pytask.lock`;
Expand Down
14 changes: 13 additions & 1 deletion src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from _pytask.outcomes import count_outcomes
from _pytask.path import find_case_sensitive_path
from _pytask.path import import_path
from _pytask.path import is_non_local_path
from _pytask.path import normalize_local_upath
from _pytask.path import shorten_path
from _pytask.pluginmanager import hookimpl
from _pytask.reports import CollectionReport
Expand Down Expand Up @@ -455,7 +457,14 @@ def pytask_collect_node( # noqa: C901, PLR0912
node.name = create_name_of_python_node(node_info)
return node

if isinstance(node, PPathNode) and not node.path.is_absolute():
if isinstance(node, PPathNode):
node.path = normalize_local_upath(node.path)

if (
isinstance(node, PPathNode)
and not is_non_local_path(node.path)
and not node.path.is_absolute()
):
node.path = path.joinpath(node.path)

# ``normpath`` removes ``../`` from the path which is necessary for the casing
Expand Down Expand Up @@ -487,6 +496,9 @@ def pytask_collect_node( # noqa: C901, PLR0912
node.name = create_name_of_python_node(node_info)
return node

if isinstance(node, UPath): # pragma: no cover
node = normalize_local_upath(node)

if isinstance(node, UPath): # pragma: no cover
if not node.protocol:
node = Path(node)
Expand Down
10 changes: 8 additions & 2 deletions src/_pytask/collect_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from _pytask.node_protocols import PTaskWithPath
from _pytask.outcomes import ExitCode
from _pytask.path import find_common_ancestor
from _pytask.path import is_non_local_path
from _pytask.path import normalize_local_upath
from _pytask.path import relative_to
from _pytask.pluginmanager import hookimpl
from _pytask.pluginmanager import storage
Expand Down Expand Up @@ -125,10 +127,14 @@ def _find_common_ancestor_of_all_nodes(
all_paths.append(task.path)
if show_nodes:
all_paths.extend(
x.path for x in tree_leaves(task.depends_on) if isinstance(x, PPathNode)
normalize_local_upath(x.path)
for x in tree_leaves(task.depends_on)
if isinstance(x, PPathNode) and not is_non_local_path(x.path)
)
all_paths.extend(
x.path for x in tree_leaves(task.produces) if isinstance(x, PPathNode)
normalize_local_upath(x.path)
for x in tree_leaves(task.produces)
if isinstance(x, PPathNode) and not is_non_local_path(x.path)
)

return find_common_ancestor(*all_paths, *paths)
Expand Down
35 changes: 35 additions & 0 deletions src/_pytask/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from types import ModuleType
from typing import TYPE_CHECKING

from upath import UPath

from _pytask._hashlib import file_digest
from _pytask.cache import Cache

Expand All @@ -25,11 +27,17 @@
"find_common_ancestor",
"hash_path",
"import_path",
"is_non_local_path",
"normalize_local_upath",
"relative_to",
"shorten_path",
]


_LOCAL_UPATH_PROTOCOLS = frozenset(("", "file", "local"))
_WINDOWS_DRIVE_PREFIX_LENGTH = 3


def relative_to(path: Path, source: Path, *, include_source: bool = True) -> Path:
"""Make a path relative to another path.

Expand All @@ -56,6 +64,27 @@ def relative_to(path: Path, source: Path, *, include_source: bool = True) -> Pat
return Path(source_name, path.relative_to(source))


def is_non_local_path(path: Path) -> bool:
"""Return whether a path points to a non-local `UPath` resource."""
return isinstance(path, UPath) and path.protocol not in _LOCAL_UPATH_PROTOCOLS


def normalize_local_upath(path: Path) -> Path:
"""Convert local `UPath` variants to a stdlib `Path`."""
if isinstance(path, UPath) and path.protocol in {"file", "local"}:
local_path = path.path
if (
sys.platform == "win32"
and local_path.startswith("/")
and len(local_path) >= _WINDOWS_DRIVE_PREFIX_LENGTH
and local_path[1].isalpha()
and local_path[2] == ":"
):
local_path = local_path[1:]
return Path(local_path)
return path


def find_closest_ancestor(
path: Path, potential_ancestors: Sequence[Path]
) -> Path | None:
Expand Down Expand Up @@ -432,6 +461,12 @@ def shorten_path(path: Path, paths: Sequence[Path]) -> str:
path from one path in ``session.config["paths"]`` to the node.

"""
if is_non_local_path(path):
return path.as_posix()

path = normalize_local_upath(path)
paths = [normalize_local_upath(p) for p in paths]

ancestor = find_closest_ancestor(path, paths)
if ancestor is None:
try:
Expand Down
55 changes: 55 additions & 0 deletions tests/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@
from pytask import CollectionOutcome
from pytask import ExitCode
from pytask import NodeInfo
from pytask import PickleNode
from pytask import Session
from pytask import Task
from pytask import build
from pytask import cli
from tests.conftest import noop


def _make_local_upath_uri(path: Path, protocol: str) -> str:
return f"{protocol}:///{path.as_posix().lstrip('/')}"


@pytest.mark.parametrize(
("depends_on", "produces"),
[
Expand Down Expand Up @@ -190,6 +195,56 @@ def test_pytask_collect_node(session, path, node_info, expected):
assert str(result.load()) == str(expected)


def test_pytask_collect_remote_path_node_keeps_uri_name():
upath = pytest.importorskip("upath")

session = Session.from_config(
{"check_casing_of_paths": False, "paths": (Path.cwd(),), "root": Path.cwd()}
)

result = pytask_collect_node(
session,
Path.cwd(),
NodeInfo(
arg_name="path",
path=(),
value=PickleNode(path=upath.UPath("s3://bucket/file.pkl")),
task_path=Path.cwd() / "task_example.py",
task_name="task_example",
),
)

assert isinstance(result, PPathNode)
assert result.name == "s3://bucket/file.pkl"


@pytest.mark.parametrize("protocol", ["file", "local"])
def test_pytask_collect_local_upath_protocol_node_is_shortened(tmp_path, protocol):
upath = pytest.importorskip("upath")

session = Session.from_config(
{"check_casing_of_paths": False, "paths": (tmp_path,), "root": tmp_path}
)

result = pytask_collect_node(
session,
tmp_path,
NodeInfo(
arg_name="path",
path=(),
value=PickleNode(
path=upath.UPath(_make_local_upath_uri(tmp_path / "file.pkl", protocol))
),
task_path=tmp_path / "task_example.py",
task_name="task_example",
),
)

assert isinstance(result, PPathNode)
assert result.path == tmp_path / "file.pkl"
assert result.name == f"{tmp_path.name}/file.pkl"


@pytest.mark.skipif(
sys.platform != "win32", reason="Only works on case-insensitive file systems."
)
Expand Down
57 changes: 57 additions & 0 deletions tests/test_collect_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from _pytask.node_protocols import PTaskWithPath


def _make_local_upath_uri(path: Path, protocol: str) -> str:
return f"{protocol}:///{path.as_posix().lstrip('/')}"


def test_collect_task(runner, tmp_path):
source = """
from pathlib import Path
Expand Down Expand Up @@ -396,6 +400,59 @@ def test_task_name_is_shortened(runner, tmp_path):
assert "a/b/task_example.py::task_example" not in result.output


def test_collect_task_with_remote_upath_node(runner, tmp_path):
pytest.importorskip("upath")

source = """
from pathlib import Path
from typing import Annotated

from upath import UPath

from pytask import PickleNode
from pytask import Product

def task_example(
data=PickleNode(path=UPath("s3://bucket/in.pkl")),
path: Annotated[Path, Product] = Path("out.txt"),
): ...
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))

result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()])

assert result.exit_code == ExitCode.OK
assert "s3://bucket/in.pkl" in result.output


@pytest.mark.parametrize("protocol", ["file", "local"])
def test_collect_task_with_local_upath_protocol_node(runner, tmp_path, protocol):
pytest.importorskip("upath")

uri = _make_local_upath_uri(tmp_path / "in.pkl", protocol)

source = f"""
from pathlib import Path
from typing import Annotated

from upath import UPath

from pytask import PickleNode
from pytask import Product

def task_example(
data=PickleNode(path=UPath("{uri}")),
path: Annotated[Path, Product] = Path("out.txt"),
): ...
"""
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))

result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()])

assert result.exit_code == ExitCode.OK
assert f"{tmp_path.name}/in.pkl" in result.output


def test_python_node_is_collected(runner, tmp_path):
source = """
from pytask import Product
Expand Down
37 changes: 37 additions & 0 deletions tests/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,20 @@
from _pytask.path import find_case_sensitive_path
from _pytask.path import find_closest_ancestor
from _pytask.path import find_common_ancestor
from _pytask.path import is_non_local_path
from _pytask.path import normalize_local_upath
from _pytask.path import relative_to
from _pytask.path import shorten_path
from pytask.path import import_path

if TYPE_CHECKING:
from collections.abc import Generator


def _make_local_upath_uri(path: Path, protocol: str) -> str:
return f"{protocol}:///{path.as_posix().lstrip('/')}"


@pytest.mark.parametrize(
("path", "source", "include_source", "expected"),
[
Expand Down Expand Up @@ -110,6 +117,36 @@ def test_find_common_ancestor(path_1, path_2, expectation, expected):
assert result == expected


def test_shorten_path_keeps_non_local_uri():
upath = pytest.importorskip("upath")

path = upath.UPath("s3://bucket/file.pkl")

assert shorten_path(path, [Path.cwd()]) == "s3://bucket/file.pkl"


@pytest.mark.parametrize("protocol", ["file", "local"])
def test_shorten_path_treats_local_upath_protocols_as_local(tmp_path, protocol):
upath = pytest.importorskip("upath")

path = upath.UPath(_make_local_upath_uri(tmp_path / "file.pkl", protocol))

assert not is_non_local_path(path)
assert shorten_path(path, [tmp_path]) == f"{tmp_path.name}/file.pkl"


@pytest.mark.parametrize("protocol", ["file", "local"])
def test_normalize_local_upath_strips_windows_drive_prefix(monkeypatch, protocol):
upath = pytest.importorskip("upath")

monkeypatch.setattr(sys, "platform", "win32")
path = upath.UPath(f"{protocol}:///C:/tmp/file.pkl")

result = normalize_local_upath(path)

assert result.as_posix() == "C:/tmp/file.pkl"


@pytest.mark.skipif(sys.platform != "win32", reason="Only works on Windows.")
@pytest.mark.parametrize(
("path", "existing_paths", "expected"),
Expand Down
Loading