Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Run prek
run: |
uv python install
uv tool run prek --all-files
uv tool run prek --all-files --skip no-commit-to-branch

- name: Run ruff check
run: |
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ jobs:
- name: Install the project
run: uv sync --locked --all-extras --dev

- name: Run prek
run: |
uv python install
uv tool run prek --all-files --skip no-commit-to-branch

- name: Run tests
run: |
uv run pytest
Expand Down
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,19 @@ def free_function(tensor: FloatTensor["batch dim1"]) -> None:
- `BoolTensor`: For boolean tensors
- `TensorTypeBase`: Base class for any tensor which does not enforce any specific datatype, feel free to add custom validation logic by overriding the `check` method.

## Disabling checks

For some scenarios such as benchmarking tight loops, it may be desirable to disable type checking.
To do this, you may provide `enabled=False` to any @dltyped decorator.
If you would like to disable checking for an entire project you may use the `DLTYPE_DISABLE=1` environment variable.

> [!NOTE]
The environment variable is only checked once at import time.

## Debugging

If you run into issues with a dltyped decorator and would like to see detailed stack traces you may turn on debug mode via the environment variable `DLTYPE_DEBUG_MODE=1`.

## Limitations

- In the current implementation, _every_ call will be checked, the performance overhead on most systems should be negligible (OTOO microseconds).
Expand Down
32 changes: 31 additions & 1 deletion dltype/_lib/_constants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,38 @@
"""Constants related to the dltype library."""

import typing
import warnings

from pydantic_settings import BaseSettings, SettingsConfigDict


class _Env(BaseSettings):
"""Environment variables controlling dltype behavior."""

model_config = SettingsConfigDict(
frozen=True,
env_prefix="DLTYPE_",
case_sensitive=False,
)

DISABLE: bool = False
"""Disable all dltype checking."""

DEBUG_MODE: bool = False
"""If true, set debug mode enabled for debugging library issues."""


# Constants
__env = _Env()
PYDANTIC_INFO_KEY: typing.Final = "__dltype__"
DEBUG_MODE: typing.Final = False
DEBUG_MODE: typing.Final = __env.DEBUG_MODE
MAX_ACCEPTABLE_EVALUATION_TIME_NS: typing.Final = int(5e9) # 5ms
GLOBAL_DISABLE: typing.Final = __env.DISABLE


if GLOBAL_DISABLE:
warnings.warn(
"DLType disabled via environment variable, all decorated functions not explicitly marked with enable=True will be turned off.",
UserWarning,
stacklevel=1,
)
26 changes: 22 additions & 4 deletions dltype/_lib/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,24 @@ def _resolve_value(

def dltyped( # noqa: C901, PLR0915
scope_provider: DLTypeScopeProvider | Literal["self"] | None = None,
*,
enabled: bool = not _constants.GLOBAL_DISABLE,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Apply type checking to the decorated function.

Args:
scope_provider: An optional scope provider to use for type checking, if None, no scope provider is used, if 'self'
is used, the first argument of the function is expected to be a DLTypeScopeProvider and the function must be a method.
enabled: if set to false, perform no type checking.

Returns:
A wrapper function with type checking

"""

def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR0915
if _dependency_utilities.is_torch_scripting():
if _dependency_utilities.is_torch_scripting() or not enabled:
# jit script doesn't support annotated type hints at all, we have no choice but to skip the type checking
return func

Expand Down Expand Up @@ -302,16 +305,25 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901
NT = TypeVar("NT", bound=NamedTuple)


def dltyped_namedtuple() -> Callable[[type[NT]], type[NT]]:
def dltyped_namedtuple(
*,
enabled: bool = not _constants.GLOBAL_DISABLE,
) -> Callable[[type[NT]], type[NT]]:
"""
Apply type checking to a NamedTuple class.

Args:
enabled: if set to false, perform no type checking.

Returns:
A modified NamedTuple class with type checking on construction

"""

def _inner_dltyped_namedtuple(cls: type[NT]) -> type[NT]:
if not enabled:
return cls

# NOTE: NamedTuple isn't actually a class, it's a factory function that returns a new class so we can't use issubclass here
if not (
isinstance(cls, type) and hasattr(cls, "_fields") and issubclass(cls, tuple) # pyright: ignore[reportUnnecessaryIsInstance]
Expand Down Expand Up @@ -366,20 +378,26 @@ def validated_new(cls_inner: type[NT], *args: Any, **kwargs: Any) -> NT: # noqa
DataclassT = TypeVar("DataclassT")


def dltyped_dataclass() -> Callable[[type[DataclassT]], type[DataclassT]]:
def dltyped_dataclass(
*,
enabled: bool = not _constants.GLOBAL_DISABLE,
) -> Callable[[type[DataclassT]], type[DataclassT]]:
"""
Apply type checking to a dataclass.

This will validate all fields with DLType annotations during object construction.
Works with both regular and frozen dataclasses.

Args:
enabled: if set to false, perform no type checking.

Returns:
A modified dataclass with type checking on initialization

"""

def _inner_dltyped_dataclass(cls: type[DataclassT]) -> type[DataclassT]:
if _dependency_utilities.is_torch_scripting():
if _dependency_utilities.is_torch_scripting() or not enabled:
return cls

# check that we are a dataclass, raise an error if not
Expand Down
35 changes: 35 additions & 0 deletions dltype/tests/dltype_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,3 +1647,38 @@ class JaxNamedTuple(NamedTuple):

with pytest.raises(dltype.DLTypeDtypeError):
JaxNamedTuple(arr1=jax.numpy.zeros((4, 3, 2)), arr2=jax.numpy.zeros((4, 3), dtype=np.uint8))


@pytest.mark.parametrize("enabled", [True, False])
def test_disabling(enabled: bool) -> None:
@dltype.dltyped(enabled=enabled)
def checked(arg: Annotated[NPFloatArrayT, dltype.FloatTensor["b c h w"]]) -> None:
pass

@dltype.dltyped_dataclass(enabled=enabled)
@dataclass
class Checked:
arg: Annotated[NPFloatArrayT, dltype.FloatTensor["b c h w"]]

@dltype.dltyped_namedtuple(enabled=enabled)
class CheckedNT(NamedTuple):
arg: Annotated[NPFloatArrayT, dltype.FloatTensor["b c h w"]]

bad_arr = np.zeros((1, 2, 3), dtype=np.float32)
good_arr = np.zeros((1, 2, 3, 4), dtype=np.float32)

checked(good_arr)
Checked(arg=good_arr)
CheckedNT(arg=good_arr)

if enabled:
with pytest.raises(dltype.DLTypeNDimsError):
checked(bad_arr)
with pytest.raises(dltype.DLTypeNDimsError):
Checked(arg=bad_arr)
with pytest.raises(dltype.DLTypeNDimsError):
CheckedNT(arg=bad_arr)
else:
checked(bad_arr)
Checked(arg=bad_arr)
CheckedNT(arg=bad_arr)
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ dev = [
[project]
dependencies = [
'pydantic>=2.0;python_version<"3.14"',
'pydantic>=2.12;python_version>="3.14"'
'pydantic>=2.12;python_version>="3.14"',
"pydantic-settings>=2.12.0"
]
description = "An extremely lightweight typing library for torch tensors or numpy arrays. Supports runtime shape checking and data type validation."
keywords = ["pytorch", "numpy", "shape check", "type check"]
Expand All @@ -24,7 +25,7 @@ license-files = ["LICENSE"]
name = "dltype"
readme = "README.md"
requires-python = ">=3.10"
version = "0.9.1"
version = "0.10.0"

[project.optional-dependencies]
jax = ["jax>=0.6.2"]
Expand Down
27 changes: 26 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.