13 changes: 8 additions & 5 deletions README.md
@@ -1,6 +1,6 @@
# DL Type (Deep Learning Type Library)

-This typing library is intended to replace jaxtyping for runtime type checking of torch tensors and numpy arrays.
+This typing library is intended to replace jaxtyping for runtime type checking of torch tensors, JAX arrays, and numpy arrays.

In particular, we support two functions that beartype/jaxtyping do not:

@@ -23,7 +23,7 @@ pip3 install dltype
```

> [!NOTE]
-> dltype does not depend explicitly on torch or numpy, but you must have at least one of them installed at import time otherwise the import will fail.
+> dltype does not depend explicitly on torch, jax, or numpy, but you must have at least one of them installed at import time; otherwise the import will fail.

## Usage

@@ -195,18 +195,21 @@ def returned_tuple_func() -> tuple[Annotated[torch.Tensor, dltype.UInt8Tensor["b
    return torch.zeros(1, 3, 1080, 1920, dtype=torch.uint8), 8
```

-## Numpy and Tensor Mixing
+## NumPy, JAX, and Tensor Mixing

```python
from typing import Annotated

+import jax
import torch
import numpy as np
from dltype import FloatTensor, dltyped

@dltyped()
def transform_tensors(
-    points: Annotated[np.ndarray, FloatTensor["N 3"]]
-    transform: Annotated[torch.Tensor, FloatTensor["3 3"]]
+    points: Annotated[np.ndarray, FloatTensor["N 3"]],
+    transform: Annotated[torch.Tensor, FloatTensor["3 3"]],
+    aux: Annotated[jax.Array, FloatTensor["N"]],
) -> Annotated[torch.Tensor, FloatTensor["N 3"]]:
    return torch.from_numpy(points) @ transform
```
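A call consistent with the updated signature might look like the following (a minimal sketch; shapes are arbitrary, and note that `aux` is validated against the shared `N` but unused by the function body):

```python
import jax.numpy as jnp
import numpy as np
import torch

points = np.zeros((5, 3), dtype=np.float32)  # N = 5
transform = torch.eye(3)                     # 3x3 float32 transform
aux = jnp.zeros((5,), dtype=jnp.float32)     # must share N with points

out = transform_tensors(points, transform, aux)  # torch.Tensor of shape (5, 3)
```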
2 changes: 1 addition & 1 deletion dltype/_lib/_core.py
@@ -210,7 +210,7 @@ def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]:  # noqa: C901, PLR09
        return func

    @wraps(func)
-    @_dependency_utilities.torch_jit_unused  # pyright: ignore[reportUnknownMemberType]
+    @_dependency_utilities.torch_jit_unused
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:  # noqa: C901
        __tracebackhide__ = not _constants.DEBUG_MODE
        nonlocal signature
13 changes: 12 additions & 1 deletion dltype/_lib/_dependency_utilities.py
@@ -18,7 +18,7 @@ def _empty_wrapper(fn: Callable[P, Ret]) -> Callable[P, Ret]:
    import torch

    # re-export for compatibility
-    torch_jit_unused = torch.jit.unused  # pyright: ignore[reportUnknownVariableType, reportUnknownMemberType]
+    torch_jit_unused = torch.jit.unused
except ImportError:
    torch_jit_unused = _empty_wrapper
    torch = None
@@ -29,6 +29,11 @@ def _empty_wrapper(fn: Callable[P, Ret]) -> Callable[P, Ret]:
except ImportError:
    np = None

+try:
+    import jax
+except ImportError:
+    jax = None


@cache
def is_torch_available() -> bool:
@@ -42,6 +47,12 @@ def is_numpy_available() -> bool:
    return np is not None


+@cache
+def is_jax_available() -> bool:
+    """Check if jax is available."""
+    return jax is not None


@cache
def is_np_float128_available() -> bool:
    float_128_available = False
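Because these availability probes are wrapped in `functools.cache`, the first call pins the answer for the life of the process; code that later patches `jax`, `torch`, or `np` to `None` (as the interop tests below do) must reset the cache. A minimal sketch:

```python
from dltype._lib import _dependency_utilities as deps

deps.is_jax_available()              # result is computed once and cached
deps.is_jax_available.cache_clear()  # required after patching deps.jax in tests
```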
2 changes: 1 addition & 1 deletion dltype/_lib/_dltype_context.py
@@ -140,7 +140,7 @@ def assert_context(self) -> None:
                    tensor_name=tensor_context.tensor_arg_name,
                )

-            self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = tensor_context.tensor.dtype
+            self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = tensor_context.tensor.dtype  # pyright: ignore[reportUnknownMemberType] (jax doesn't parametrize the numpy dtype correctly)
            expected_shape = tensor_context.get_expected_shape(
                tensor_context.tensor,
            )
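The suppression comment reflects how jax types its arrays: at runtime `jax.Array.dtype` is an ordinary numpy dtype, but the static annotation is not parametrized, so pyright sees the attribute as partially unknown. A quick illustration (assuming jax is installed):

```python
import jax.numpy as jnp

arr = jnp.zeros((2, 3), dtype=jnp.float32)
print(arr.dtype)  # float32 -- a plain numpy dtype at runtime
```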
32 changes: 26 additions & 6 deletions dltype/_lib/_dtypes.py
@@ -2,32 +2,52 @@

import typing

-from dltype._lib import (
-    _dependency_utilities as _deps,
-)
+from dltype._lib import _dependency_utilities as _deps

# NOTE: the order of these is important, pyright assumes the last branch is taken
# so we get proper union type hint checking
-if _deps.is_numpy_available() and not _deps.is_torch_available():
+if _deps.is_numpy_available() and not _deps.is_torch_available() and not _deps.is_jax_available():
+    # numpy is here, jax and torch are not
    import numpy as np
    import numpy.typing as npt

    DLtypeTensorT: typing.TypeAlias = npt.NDArray[typing.Any]  # pyright: ignore[reportRedeclaration]
    DLtypeDtypeT: typing.TypeAlias = npt.DTypeLike  # pyright: ignore[reportRedeclaration]
    SUPPORTED_TENSOR_TYPES: typing.Final = {np.ndarray}
-elif _deps.is_torch_available() and not _deps.is_numpy_available():
+elif not _deps.is_numpy_available() and _deps.is_torch_available():
+    # If numpy is not available, jax cannot be installed, so we have only torch.
    import torch

    DLtypeTensorT: typing.TypeAlias = torch.Tensor  # pyright: ignore[reportRedeclaration]
    DLtypeDtypeT: typing.TypeAlias = torch.dtype  # pyright: ignore[reportRedeclaration]
    SUPPORTED_TENSOR_TYPES: typing.Final = {torch.Tensor}  # pyright: ignore[reportGeneralTypeIssues, reportConstantRedefinition]
-elif _deps.is_numpy_available() and _deps.is_torch_available():
+elif _deps.is_numpy_available() and not _deps.is_torch_available() and _deps.is_jax_available():
+    # we have numpy and jax but not torch
+    import jax
+    import numpy as np
+    import numpy.typing as npt
+
+    DLtypeTensorT: typing.TypeAlias = jax.Array | npt.NDArray[typing.Any]  # pyright: ignore[reportRedeclaration]
+    DLtypeDtypeT: typing.TypeAlias = npt.DTypeLike  # pyright: ignore[reportRedeclaration]
+    SUPPORTED_TENSOR_TYPES: typing.Final = {np.ndarray, jax.Array}  # pyright: ignore[reportGeneralTypeIssues, reportConstantRedefinition]
+elif _deps.is_numpy_available() and _deps.is_torch_available() and not _deps.is_jax_available():
+    # we have torch and numpy but not jax
    import numpy as np
    import numpy.typing as npt
    import torch

    DLtypeTensorT: typing.TypeAlias = torch.Tensor | npt.NDArray[typing.Any]  # pyright: ignore[reportRedeclaration]
    DLtypeDtypeT: typing.TypeAlias = torch.dtype | npt.DTypeLike  # pyright: ignore[reportRedeclaration]
    SUPPORTED_TENSOR_TYPES: typing.Final = {torch.Tensor, np.ndarray}  # pyright: ignore[reportGeneralTypeIssues, reportConstantRedefinition]
+elif _deps.is_numpy_available() and _deps.is_torch_available() and _deps.is_jax_available():
+    # we have all three
+    import jax
+    import numpy as np
+    import numpy.typing as npt
+    import torch
+
+    DLtypeTensorT: typing.TypeAlias = jax.Array | npt.NDArray[typing.Any] | torch.Tensor  # pyright: ignore[reportRedeclaration]
+    DLtypeDtypeT: typing.TypeAlias = npt.DTypeLike | torch.dtype  # pyright: ignore[reportRedeclaration]
+    SUPPORTED_TENSOR_TYPES: typing.Final = {np.ndarray, jax.Array, torch.Tensor}  # pyright: ignore[reportGeneralTypeIssues, reportConstantRedefinition]
else:
    _deps.raise_for_missing_dependency()
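The net effect of the branching: `DLtypeTensorT` is the union of whichever array types imported successfully, and `SUPPORTED_TENSOR_TYPES` carries the same information as a runtime set. A minimal sketch of how that set can be consumed (`_is_supported` is a hypothetical helper, not part of the library):

```python
from dltype._lib import _dtypes

def _is_supported(obj: object) -> bool:
    # Membership in SUPPORTED_TENSOR_TYPES mirrors the isinstance checks
    # dltype applies to annotated arguments at call time.
    return isinstance(obj, tuple(_dtypes.SUPPORTED_TENSOR_TYPES))
```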
4 changes: 2 additions & 2 deletions dltype/_lib/_tensor_type_base.py
@@ -184,10 +184,10 @@ def check(
                tensor_name=tensor_name,
            )

-        if self.DTYPES and tensor.dtype not in self.DTYPES:
+        if self.DTYPES and tensor.dtype not in self.DTYPES:  # pyright: ignore[reportUnknownMemberType] (jax doesn't parametrize the numpy dtype correctly)
            raise _errors.DLTypeDtypeError(
                expected=self.DTYPES,
-                received={tensor.dtype},
+                received={tensor.dtype},  # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] (jax doesn't parametrize the numpy dtype correctly)
                tensor_name=tensor_name,
            )

3 changes: 2 additions & 1 deletion dltype/_lib/_universal_tensors.py
@@ -3,12 +3,13 @@

from dltype._lib import _tensor_type_base
from dltype._lib._dependency_utilities import (
+    is_jax_available,
    is_numpy_available,
    is_torch_available,
    raise_for_missing_dependency,
)

-if is_numpy_available():
+if is_numpy_available() or is_jax_available():
    from dltype._lib._numpy_tensors import (
        BoolTensor as NumPyBoolTensor,
    )
71 changes: 70 additions & 1 deletion dltype/tests/dltype_test.py
@@ -10,6 +10,7 @@
from typing import Annotated, Final, NamedTuple, TypeAlias
from unittest.mock import patch

+import jax
import numpy as np
import numpy.typing as npt
import pytest
@@ -537,18 +538,22 @@ def forward(
            f.name,
            input_names=["input"],
            output_names=["output"],
+            dynamo=False,
        )

        assert Path(f.name).exists()
        assert Path(f.name).stat().st_size > 0

-        with pytest.raises(dltype.DLTypeNDimsError):
+        with pytest.raises(
+            dltype.DLTypeNDimsError, match="Invalid number of dimensions, tensor=x expected ndims=4 actual=3"
+        ):
            torch.onnx.export(
                _DummyModule(),
                (torch.rand(1, 2, 3),),
                f.name,
                input_names=["input"],
                output_names=["output"],
+                dynamo=False,
            )


Expand Down Expand Up @@ -1568,3 +1573,67 @@ def func(

    with pytest.raises(dltype.DLTypeShapeError):
        func((torch.zeros(1, 1, 3), torch.zeros(3, 2, 1), 1))


+def test_jax() -> None:
+    @dltype.dltyped()
+    def func(
+        arr: Annotated[jax.Array, dltype.FloatTensor["1 2 3"]],
+    ) -> Annotated[jax.Array, dltype.FloatTensor["3 2 1"]]:
+        return arr.transpose(2, 1, 0)
+
+    func(jax.numpy.zeros((1, 2, 3), dtype=np.float32))
+
+    with pytest.raises(dltype.DLTypeShapeError):
+        func(jax.numpy.zeros((1, 2, 4), dtype=np.float32))
+
+    with pytest.raises(dltype.DLTypeDtypeError):
+        func(jax.numpy.zeros((1, 2, 3), dtype=np.int8))
+
+    @dltype.dltyped_dataclass()
+    @dataclass(frozen=True, slots=True)
+    class JaxDataclass:
+        arr1: Annotated[jax.Array, dltype.UInt8Tensor["*batch chann feat"]]
+        arr2: Annotated[jax.Array, dltype.Float32Tensor["*batch chann"]]
+
+    JaxDataclass(
+        arr1=jax.numpy.zeros((4, 3, 2), dtype=np.uint8),
+        arr2=jax.numpy.zeros(
+            (4, 3),
+            dtype=np.float32,
+        ),
+    )
+
+    with pytest.raises(dltype.DLTypeShapeError):
+        JaxDataclass(
+            arr1=jax.numpy.zeros((4, 4, 2), dtype=np.uint8), arr2=jax.numpy.zeros((4, 3), dtype=np.float32)
+        )
+
+    JaxDataclass(
+        arr1=jax.numpy.zeros((8, 4, 3, 2), dtype=np.uint8), arr2=jax.numpy.zeros((8, 4, 3), dtype=np.float32)
+    )
+
+    class JaxBaseModel(BaseModel):
+        arr1: Annotated[jax.Array, dltype.FloatTensor["batch chann=3 feat"]]
+        arr2: Annotated[jax.Array, dltype.FloatTensor["batch chann=3"]]
+
+    JaxBaseModel(arr1=jax.numpy.zeros((1, 3, 2)), arr2=jax.numpy.zeros((1, 3)))
+
+    with pytest.raises(dltype.DLTypeShapeError):
+        JaxBaseModel(arr1=jax.numpy.zeros((4, 4, 2)), arr2=jax.numpy.zeros((4, 3)))
+
+    with pytest.raises(dltype.DLTypeDtypeError):
+        JaxBaseModel(arr1=jax.numpy.zeros((4, 3, 2)), arr2=jax.numpy.zeros((4, 3), dtype=np.uint8))
+
+    @dltype.dltyped_namedtuple()
+    class JaxNamedTuple(NamedTuple):
+        arr1: Annotated[jax.Array, dltype.FloatTensor["batch chann=3 feat"]]
+        arr2: Annotated[jax.Array, dltype.FloatTensor["batch chann=3"]]
+
+    JaxNamedTuple(arr1=jax.numpy.zeros((1, 3, 2)), arr2=jax.numpy.zeros((1, 3)))
+
+    with pytest.raises(dltype.DLTypeShapeError):
+        JaxNamedTuple(arr1=jax.numpy.zeros((4, 4, 2)), arr2=jax.numpy.zeros((4, 3)))
+
+    with pytest.raises(dltype.DLTypeDtypeError):
+        JaxNamedTuple(arr1=jax.numpy.zeros((4, 3, 2)), arr2=jax.numpy.zeros((4, 3), dtype=np.uint8))
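Note the variadic `*batch` in the dataclass test: it lets `arr1` and `arr2` share any number of leading batch dimensions, which is why both `(4, 3, 2)`/`(4, 3)` and `(8, 4, 3, 2)`/`(8, 4, 3)` pass while `(4, 4, 2)` against `(4, 3)` fails on the shared `chann` axis.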
23 changes: 22 additions & 1 deletion dltype/tests/interop_test.py
@@ -16,10 +16,11 @@
def clear_cached_available_fns() -> None:
    """Clear cached functions to ensure fresh imports."""
    # Clear the cache for the dependency utilities
-    from dltype._lib._dependency_utilities import is_numpy_available, is_torch_available
+    from dltype._lib._dependency_utilities import is_jax_available, is_numpy_available, is_torch_available

    is_torch_available.cache_clear()
    is_numpy_available.cache_clear()
+    is_jax_available.cache_clear()


@pytest.fixture(autouse=True)
@@ -46,6 +47,13 @@ def mock_missing_torch() -> Iterator[None]:
        yield


+@pytest.fixture
+def mock_missing_jax() -> Iterator[None]:
+    """Mock jax as missing."""
+    with patch("dltype._lib._dependency_utilities.jax", None):
+        yield


def test_dltype_imports_without_torch(mock_missing_torch: None) -> None:
    """Test that dltype can be imported without torch."""
    del sys.modules["torch"]
@@ -70,6 +78,18 @@ def test_dltype_imports_without_numpy(mock_missing_numpy: None) -> None:
    assert (torch.bool,) == reloaded_dltype.BoolTensor.DTYPES


+def test_dltype_imports_without_torch_with_jax(mock_missing_torch: None) -> None:
+    """Test that dltype works without torch but with jax and numpy."""
+    del sys.modules["torch"]
+
+    with pytest.raises(ImportError):
+        reload(torch)
+
+    reloaded_dltype = reload(dltype)
+
+    assert (np.bool,) == reloaded_dltype.BoolTensor.DTYPES


def test_dltype_imports_with_both() -> None:
    """Test that dltype can be imported with both torch and numpy."""
    reloaded_dltype = reload(dltype)
@@ -82,6 +102,7 @@ def test_dltype_imports_with_both() -> None:
def test_dltype_asserts_import_error_with_neither(
    mock_missing_numpy: None,
    mock_missing_torch: None,
+    mock_missing_jax: None,
) -> None:
    """Test that dltype raises ImportError if neither torch nor numpy is available."""

7 changes: 5 additions & 2 deletions pyproject.toml
@@ -7,7 +7,9 @@ dev = [
    "torch>=1.4.0",
    "setuptools>=60.0.0",
    "pyright>=1.1.407",
-    "pytest-cov>=7.0.0"
+    "pytest-cov>=7.0.0",
+    "jax>=0.6.2",
+    "onnxscript>=0.5.7"
]

[project]
@@ -19,9 +21,10 @@ license-files = ["LICENSE"]
name = "dltype"
readme = "README.md"
requires-python = ">=3.10"
-version = "0.8.0"
+version = "0.9.0"

[project.optional-dependencies]
+jax = ["jax>=0.6.2"]
numpy = ["numpy"]
torch = ["torch>=1.11.0"]

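With the new optional-dependency group, jax support installs via the package extra, e.g. `pip3 install "dltype[jax]"` (standard pip extra syntax); the core package still imports with any one of the three backends present.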