Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and

## Unreleased

- [#822](https://github.com/pytask-dev/pytask/pull/822) fixes unstable signatures
for remote `UPath`-backed `PathNode`s and `PickleNode`s so unchanged remote inputs
are no longer reported as missing from the state database on subsequent runs.
- [#820](https://github.com/pytask-dev/pytask/pull/820) fixes collection and node
display for remote `UPath`-backed nodes, while preserving correct handling of local
`file://` and `local://` `UPath`s across platforms.
Expand Down
24 changes: 22 additions & 2 deletions src/_pytask/_hashlib.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from __future__ import annotations

import hashlib
import os
import sys
from contextlib import suppress
from pathlib import Path
from typing import Any

from upath import UPath

_LOCAL_UPATH_PROTOCOLS = frozenset(("", "file", "local"))
_WINDOWS_DRIVE_PREFIX_LENGTH = 3


if sys.version_info >= (3, 11): # pragma: no cover
from hashlib import file_digest
Expand Down Expand Up @@ -227,8 +233,22 @@ def hash_value(value: Any) -> int | str:
return 0xFCA86420
if isinstance(value, (tuple, list)):
value = "".join(str(hash_value(i)) for i in value)
if isinstance(value, Path):
value = str(value)
if isinstance(value, UPath):
if value.protocol in _LOCAL_UPATH_PROTOCOLS:
local_path = value.path
if (
sys.platform == "win32"
and local_path.startswith("/")
and len(local_path) >= _WINDOWS_DRIVE_PREFIX_LENGTH
and local_path[1].isalpha()
and local_path[2] == ":"
):
local_path = local_path[1:]
value = os.fspath(Path(local_path))
else:
value = str(value)
elif isinstance(value, os.PathLike):
value = os.fspath(value)
if isinstance(value, str):
value = value.encode()
if isinstance(value, bytes):
Expand Down
5 changes: 1 addition & 4 deletions tests/test_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import cloudpickle
import pytest
import upath

from _pytask.collect import _find_shortest_uniquely_identifiable_name_for_tasks
from _pytask.collect import pytask_collect_node
Expand Down Expand Up @@ -196,8 +197,6 @@ def test_pytask_collect_node(session, path, node_info, expected):


def test_pytask_collect_remote_path_node_keeps_uri_name():
upath = pytest.importorskip("upath")

session = Session.from_config(
{"check_casing_of_paths": False, "paths": (Path.cwd(),), "root": Path.cwd()}
)
Expand All @@ -220,8 +219,6 @@ def test_pytask_collect_remote_path_node_keeps_uri_name():

@pytest.mark.parametrize("protocol", ["file", "local"])
def test_pytask_collect_local_upath_protocol_node_is_shortened(tmp_path, protocol):
upath = pytest.importorskip("upath")

session = Session.from_config(
{"check_casing_of_paths": False, "paths": (tmp_path,), "root": tmp_path}
)
Expand Down
4 changes: 0 additions & 4 deletions tests/test_collect_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,6 @@ def test_task_name_is_shortened(runner, tmp_path):


def test_collect_task_with_remote_upath_node(runner, tmp_path):
pytest.importorskip("upath")

source = """
from pathlib import Path
from typing import Annotated
Expand All @@ -427,8 +425,6 @@ def task_example(

@pytest.mark.parametrize("protocol", ["file", "local"])
def test_collect_task_with_local_upath_protocol_node(runner, tmp_path, protocol):
pytest.importorskip("upath")

uri = _make_local_upath_uri(tmp_path / "in.pkl", protocol)

source = f"""
Expand Down
28 changes: 28 additions & 0 deletions tests/test_hashlib.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from __future__ import annotations

import os
from pathlib import Path

import pytest
import upath

from _pytask._hashlib import hash_value


class RemotePathLike(os.PathLike[str]):
    """Minimal ``os.PathLike`` whose ``__fspath__`` returns a fixed string.

    It is neither a ``pathlib.Path`` nor a ``UPath``, so it exercises code that
    only relies on the ``os.PathLike`` protocol.
    """

    def __init__(self, value: str) -> None:
        # The string handed back verbatim by ``os.fspath``.
        self.value = value

    def __fspath__(self) -> str:
        return self.value


@pytest.mark.parametrize(
("value", "expected"),
[
Expand All @@ -24,8 +34,26 @@
Path("file.py"),
"48b38abeefb3ba2622b6d1534d36c1ffd9b4deebf2cd71e4af8a33723e734ada",
),
(
RemotePathLike("s3://bucket/file.pkl"),
"5bbedd1ab74242143481060b901083e77080661d97003b96e0cbae3a887ebce6",
),
],
)
def test_hash_value(value, expected):
hash_ = hash_value(value)
assert hash_ == expected


def test_hash_value_of_remote_upath():
    """A remote ``UPath`` hashes to a pinned, stable digest."""
    remote_path = upath.UPath("s3://bucket/file.pkl")

    # Pinned digest; it is the same value ``test_hash_value`` expects for an
    # ``os.PathLike`` wrapper around the identical ``s3://`` URI string.
    expected = "5bbedd1ab74242143481060b901083e77080661d97003b96e0cbae3a887ebce6"

    assert hash_value(remote_path) == expected


@pytest.mark.parametrize("protocol", ["file", "local"])
def test_hash_value_of_local_upath_matches_path(tmp_path, protocol):
    """A ``file``/``local`` protocol ``UPath`` hashes like the plain ``Path``."""
    path = tmp_path / "file.pkl"
    # Leading slashes are stripped before appending to ``protocol:///`` so an
    # absolute POSIX path does not add a fourth slash.
    uri = f"{protocol}:///{path.as_posix().lstrip('/')}"

    assert hash_value(upath.UPath(uri)) == hash_value(path)
26 changes: 26 additions & 0 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import hashlib
import pickle
import sys
from pathlib import Path
from typing import cast

import cloudpickle
import pytest
import upath

from pytask import NodeInfo
from pytask import PathNode
Expand Down Expand Up @@ -118,6 +121,29 @@ def test_hash_of_pickle_node(tmp_path, value, exists, expected):
assert state is expected


@pytest.mark.parametrize("node_cls", [PathNode, PickleNode])
def test_signature_of_remote_upath_node(node_cls):
    """The signature of a node on a remote ``UPath`` equals a pinned value."""
    remote_path = upath.UPath("s3://bucket/file.pkl")
    # ``cast`` only satisfies the type checker; the node receives the ``UPath``.
    node = node_cls(name="test", path=cast("Path", remote_path))

    # Expected signature: SHA-256 over the pinned hex digest, taken as bytes.
    expected = hashlib.sha256(
        b"5bbedd1ab74242143481060b901083e77080661d97003b96e0cbae3a887ebce6"
    ).hexdigest()

    assert node.signature == expected


@pytest.mark.parametrize("node_cls", [PathNode, PickleNode])
@pytest.mark.parametrize("protocol", ["file", "local"])
def test_signature_of_local_upath_node_matches_path(tmp_path, node_cls, protocol):
    """A node on a ``file``/``local`` ``UPath`` has the plain-``Path`` signature."""
    path = tmp_path / "file.pkl"
    # Leading slashes are stripped before appending to ``protocol:///`` so an
    # absolute POSIX path does not add a fourth slash.
    uri = f"{protocol}:///{path.as_posix().lstrip('/')}"

    path_node = node_cls(name="test", path=path)
    # ``cast`` only satisfies the type checker; the node receives the ``UPath``.
    upath_node = node_cls(name="test", path=cast("Path", upath.UPath(uri)))

    assert upath_node.signature == path_node.signature


@pytest.mark.parametrize(
("node", "protocol", "expected"),
[
Expand Down
7 changes: 1 addition & 6 deletions tests/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any

import pytest
import upath

from _pytask.path import _insert_missing_modules
from _pytask.path import _module_name_from_path
Expand Down Expand Up @@ -118,17 +119,13 @@ def test_find_common_ancestor(path_1, path_2, expectation, expected):


def test_shorten_path_keeps_non_local_uri():
    """``shorten_path`` returns a remote ``s3://`` URI unchanged."""
    remote_path = upath.UPath("s3://bucket/file.pkl")

    assert shorten_path(remote_path, [Path.cwd()]) == "s3://bucket/file.pkl"


@pytest.mark.parametrize("protocol", ["file", "local"])
def test_shorten_path_treats_local_upath_protocols_as_local(tmp_path, protocol):
upath = pytest.importorskip("upath")

path = upath.UPath(_make_local_upath_uri(tmp_path / "file.pkl", protocol))

assert not is_non_local_path(path)
Expand All @@ -137,8 +134,6 @@ def test_shorten_path_treats_local_upath_protocols_as_local(tmp_path, protocol):

@pytest.mark.parametrize("protocol", ["file", "local"])
def test_normalize_local_upath_strips_windows_drive_prefix(monkeypatch, protocol):
upath = pytest.importorskip("upath")

monkeypatch.setattr(sys, "platform", "win32")
path = upath.UPath(f"{protocol}:///C:/tmp/file.pkl")

Expand Down
Loading