Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,22 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
version: 0.7.17
version: 0.9.26
enable-cache: true

- name: Run prek
run: |
uv python install
uv tool run prek --all-files

- name: Run ruff check
run: |
uv run ruff format --check
uv run ruff check

- name: Run default unit tests
run: uv run pytest
run: |
uv run pytest

- name: Run type checking
run: uv run pyright --stats
Expand All @@ -37,12 +43,16 @@ jobs:
matrix:
numpy_version: [==1.22.0, '']
torch_version: [==1.11.0, '']
python_version: ['3.10', '3.11', '3.12', '3.13']
python_version: ['3.10', '3.11', '3.12', '3.13', '3.14']
exclude:
- numpy_version: ==1.22.0
python_version: '3.12'
- numpy_version: ==1.22.0
python_version: '3.13'
- numpy_version: ==1.22.0
python_version: '3.14'
- torch_version: ==1.11.0
python_version: '3.14'

runs-on: ubuntu-latest
steps:
Expand All @@ -52,7 +62,7 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
version: 0.7.17
version: 0.9.26
python-version: ${{ matrix.python_version }}
enable-cache: true

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
version: 0.7.17
version: 0.9.26

- name: Install the project
run: uv sync --locked --all-extras --dev
Expand Down
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
default_language_version:
python: python3.14

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
Expand Down
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.10
3.14
8 changes: 8 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
"[json]": {
"editor.formatOnSave": false
},
"[python]": {
"editor.codeActionsOnSave": {
"source.organizeImports.ruff": "explicit"
},
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.formatOnSaveMode": "file"
},
"[yaml]": {
"editor.formatOnSave": false
},
Expand Down
6 changes: 5 additions & 1 deletion dltype/_lib/_dependency_utilities.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""Utilities to handle optional dependencies in dltype."""

from __future__ import annotations

import typing
from collections.abc import Callable
from functools import cache

if typing.TYPE_CHECKING:
from collections.abc import Callable

Ret = typing.TypeVar("Ret")
P = typing.ParamSpec("P")

Expand Down
2 changes: 1 addition & 1 deletion dltype/_lib/_dltype_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def assert_context(self) -> None:
tensor_name=tensor_context.tensor_arg_name,
)

self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = tensor_context.tensor.dtype # pyright: ignore[reportUnknownMemberType] (jax doesn't parametrize the numpy dtype correctly)
self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = tensor_context.tensor.dtype
expected_shape = tensor_context.get_expected_shape(
tensor_context.tensor,
)
Expand Down
6 changes: 5 additions & 1 deletion dltype/_lib/_errors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""Errors for the dltype library."""

from __future__ import annotations

import typing
from abc import ABC, abstractmethod
from collections import abc

from dltype._lib._dtypes import SUPPORTED_TENSOR_TYPES, DLtypeDtypeT

if typing.TYPE_CHECKING:
from collections import abc


class DLTypeError(TypeError, ABC):
"""An error raised when a type assertion is hit."""
Expand Down
4 changes: 2 additions & 2 deletions dltype/_lib/_tensor_type_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ def check(
tensor_name=tensor_name,
)

if self.DTYPES and tensor.dtype not in self.DTYPES: # pyright: ignore[reportUnknownMemberType] (jax doesn't parametrize the numpy dtype correctly)
if self.DTYPES and tensor.dtype not in self.DTYPES:
raise _errors.DLTypeDtypeError(
expected=self.DTYPES,
received={tensor.dtype}, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] (jax doesn't parametrize the numpy dtype correctly)
received={tensor.dtype},
tensor_name=tensor_name,
)

Expand Down
16 changes: 13 additions & 3 deletions dltype/tests/dltype_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# pyright: reportPrivateUsage=false, reportUnknownMemberType=false
"""Tests for common types used in deep learning."""

from __future__ import annotations

import re
import sys
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Annotated, Final, NamedTuple, TypeAlias
from typing import TYPE_CHECKING, Annotated, Final, NamedTuple, TypeAlias
from unittest.mock import patch

import jax
Expand All @@ -20,6 +22,9 @@

import dltype

if TYPE_CHECKING:
from collections.abc import Callable

np_rand = np.random.RandomState(42).rand
NPFloatArrayT: TypeAlias = npt.NDArray[np.float32 | np.float64]
NPIntArrayT: TypeAlias = npt.NDArray[np.int32 | np.uint16 | np.uint32 | np.uint8]
Expand Down Expand Up @@ -582,6 +587,10 @@ def forward(
warnings.simplefilter("ignore", TracerWarning)
torch.jit.trace(_DummyModule(), torch.rand(1, 2, 3, 4))

if sys.version_info.minor >= 14:
# torch doesn't support script in 3.14
return

scripted_module = torch.jit.script(_DummyModule())

scripted_module(torch.rand(1, 2, 3, 4))
Expand Down Expand Up @@ -650,6 +659,7 @@ def other_bad_function(
def test_bad_dimension_name() -> None:
with pytest.raises(SyntaxError):

@dltype.dltyped()
def bad_function( # pyright: ignore[reportUnusedFunction]
tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b?"]],
) -> None:
Expand Down Expand Up @@ -1405,7 +1415,7 @@ def signed_vs_unsigned(

np.testing.assert_allclose(
signed_vs_unsigned(
np.array([6], dtype=np.int32), # pyright: ignore[reportUnknownArgumentType]
np.array([6], dtype=np.int32),
np.array([8], dtype=np.uint32),
).numpy(),
np.array([48], dtype=np.uint8),
Expand Down
7 changes: 6 additions & 1 deletion dltype/tests/interop_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Test that dltype can operate with either numpy or torch installed."""

from __future__ import annotations

import sys
from collections.abc import Iterator
from importlib import reload
from typing import TYPE_CHECKING
from unittest.mock import patch

import numpy as np
Expand All @@ -11,6 +13,9 @@

import dltype

if TYPE_CHECKING:
from collections.abc import Iterator


@pytest.fixture(autouse=True)
def clear_cached_available_fns() -> None:
Expand Down
1 change: 1 addition & 0 deletions dltype/tests/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_parse_expression(
("isqrt(*)", {}),
("dim+", {}),
("dim%", {}),
("dim?", {}),
],
)
def test_parse_invalid_expression(expression: str, scope: dict[str, int]) -> None:
Expand Down
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@ dev = [
]

[project]
dependencies = ["pydantic>=2.0"]
dependencies = [
'pydantic>=2.0;python_version<"3.14"',
'pydantic>=2.12;python_version>="3.14"'
]
description = "An extremely lightweight typing library for torch tensors or numpy arrays. Supports runtime shape checking and data type validation."
keywords = ["pytorch", "numpy", "shape check", "type check"]
license = "Apache-2.0"
license-files = ["LICENSE"]
name = "dltype"
readme = "README.md"
requires-python = ">=3.10"
version = "0.9.0"
version = "0.9.1"

[project.optional-dependencies]
jax = ["jax>=0.6.2"]
Expand All @@ -45,7 +48,7 @@ addopts = "--cov=dltype --cov-report lcov:lcov.info --cov-report html"
[tool.ruff]
indent-width = 4
line-length = 110
target-version = "py310"
target-version = "py314"

[tool.ruff.format]
docstring-code-format = true
Expand Down
16 changes: 7 additions & 9 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@

# This script sets up the environment for the project

curl -LsSf https://astral.sh/uv/install.sh | sh
if ! command -v uv >/dev/null 2>&1
then
echo "uv not found, installing..."
curl -LsSf https://astral.sh/uv/install.sh | sh
fi

uv python install
uv tool install prek
uv sync

if ! command -v pre-commit >/dev/null 2>&1
then
echo "WARNING: pre-commit not found, please install it for a better dev experience"
echo "pip install pre-commit --break-system-packages"
echo "pre-commit install --install-hooks"
else
pre-commit install --install-hooks
fi
prek install --install-hooks
Loading