From 13261a1a77e7d9ff2952add1dd3db498b9f22e42 Mon Sep 17 00:00:00 2001 From: David Langerman Date: Sat, 31 Jan 2026 12:42:30 -0500 Subject: [PATCH] Allow disabling dltype via decorators and environment variables --- .github/workflows/pr.yaml | 2 +- .github/workflows/publish.yaml | 5 +++++ README.md | 13 +++++++++++++ dltype/_lib/_constants.py | 32 ++++++++++++++++++++++++++++++- dltype/_lib/_core.py | 26 +++++++++++++++++++++---- dltype/tests/dltype_test.py | 35 ++++++++++++++++++++++++++++++++++ pyproject.toml | 5 +++-- uv.lock | 27 +++++++++++++++++++++++++- 8 files changed, 136 insertions(+), 9 deletions(-) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index a149a32..fdb0fa5 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -24,7 +24,7 @@ jobs: - name: Run prek run: | uv python install - uv tool run prek --all-files + uv tool run prek --all-files --skip no-commit-to-branch - name: Run ruff check run: | diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 42cb592..bb19f3e 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -26,6 +26,11 @@ jobs: - name: Install the project run: uv sync --locked --all-extras --dev + - name: Run prek + run: | + uv python install + uv tool run prek --all-files --skip no-commit-to-branch + - name: Run tests run: | uv run pytest diff --git a/README.md b/README.md index 91b210a..77db506 100644 --- a/README.md +++ b/README.md @@ -279,6 +279,19 @@ def free_function(tensor: FloatTensor["batch dim1"]) -> None: - `BoolTensor`: For boolean tensors - `TensorTypeBase`: Base class for any tensor which does not enforce any specific datatype, feel free to add custom validation logic by overriding the `check` method. +## Disabling checks + +For some scenarios such as benchmarking tight loops, it may be desirable to disable type checking. +To do this, you may provide `enabled=False` to any @dltyped decorator. +If you would like to disable checking for an entire project you may use the `DLTYPE_DISABLE=1` environment variable. + +> [!NOTE] +The environment variable is only checked once at import time. + +## Debugging + +If you run into issues with a dltyped decorator and would like to see detailed stack traces you may turn on debug mode via the environment variable `DLTYPE_DEBUG_MODE=1`. + ## Limitations - In the current implementation, _every_ call will be checked, the performance overhead on most systems should be negligible (OTOO microseconds). diff --git a/dltype/_lib/_constants.py b/dltype/_lib/_constants.py index 27321ed..5d256fd 100644 --- a/dltype/_lib/_constants.py +++ b/dltype/_lib/_constants.py @@ -1,8 +1,38 @@ """Constants related to the dltype library.""" import typing +import warnings + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class _Env(BaseSettings): + """Environment variables controlling dltype behavior.""" + + model_config = SettingsConfigDict( + frozen=True, + env_prefix="DLTYPE_", + case_sensitive=False, + ) + + DISABLE: bool = False + """Disable all dltype checking.""" + + DEBUG_MODE: bool = False + """If true, set debug mode enabled for debugging library issues.""" + # Constants +__env = _Env() PYDANTIC_INFO_KEY: typing.Final = "__dltype__" -DEBUG_MODE: typing.Final = False +DEBUG_MODE: typing.Final = __env.DEBUG_MODE MAX_ACCEPTABLE_EVALUATION_TIME_NS: typing.Final = int(5e9) # 5ms +GLOBAL_DISABLE: typing.Final = __env.DISABLE + + +if GLOBAL_DISABLE: + warnings.warn( + "DLType disabled via environment variable, all decorated functions not explicitly marked with enable=True will be turned off.", + UserWarning, + stacklevel=1, + ) diff --git a/dltype/_lib/_core.py b/dltype/_lib/_core.py index 2976c0a..368eede 100644 --- a/dltype/_lib/_core.py +++ b/dltype/_lib/_core.py @@ -167,6 +167,8 @@ def _resolve_value( def dltyped( # noqa: C901, PLR0915 scope_provider: DLTypeScopeProvider | Literal["self"] | None = None, + *, + enabled: bool = not _constants.GLOBAL_DISABLE, ) -> Callable[[Callable[P, R]], Callable[P, R]]: """ Apply type checking to the decorated function. @@ -174,6 +176,7 @@ def dltyped( # noqa: C901, PLR0915 Args: scope_provider: An optional scope provider to use for type checking, if None, no scope provider is used, if 'self' is used, the first argument of the function is expected to be a DLTypeScopeProvider and the function must be a method. + enabled: if set to false, perform no type checking. Returns: A wrapper function with type checking @@ -181,7 +184,7 @@ def dltyped( # noqa: C901, PLR0915 """ def _inner_dltyped(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901, PLR0915 - if _dependency_utilities.is_torch_scripting(): + if _dependency_utilities.is_torch_scripting() or not enabled: # jit script doesn't support annotated type hints at all, we have no choice but to skip the type checking return func @@ -302,16 +305,25 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901 NT = TypeVar("NT", bound=NamedTuple) -def dltyped_namedtuple() -> Callable[[type[NT]], type[NT]]: +def dltyped_namedtuple( + *, + enabled: bool = not _constants.GLOBAL_DISABLE, +) -> Callable[[type[NT]], type[NT]]: """ Apply type checking to a NamedTuple class. + Args: + enabled: if set to false, perform no type checking. + Returns: A modified NamedTuple class with type checking on construction """ def _inner_dltyped_namedtuple(cls: type[NT]) -> type[NT]: + if not enabled: + return cls + # NOTE: NamedTuple isn't actually a class, it's a factory function that returns a new class so we can't use issubclass here if not ( isinstance(cls, type) and hasattr(cls, "_fields") and issubclass(cls, tuple) # pyright: ignore[reportUnnecessaryIsInstance] @@ -366,20 +378,26 @@ def validated_new(cls_inner: type[NT], *args: Any, **kwargs: Any) -> NT: # noqa DataclassT = TypeVar("DataclassT") -def dltyped_dataclass() -> Callable[[type[DataclassT]], type[DataclassT]]: +def dltyped_dataclass( + *, + enabled: bool = not _constants.GLOBAL_DISABLE, +) -> Callable[[type[DataclassT]], type[DataclassT]]: """ Apply type checking to a dataclass. This will validate all fields with DLType annotations during object construction. Works with both regular and frozen dataclasses. + Args: + enabled: if set to false, perform no type checking. + Returns: A modified dataclass with type checking on initialization """ def _inner_dltyped_dataclass(cls: type[DataclassT]) -> type[DataclassT]: - if _dependency_utilities.is_torch_scripting(): + if _dependency_utilities.is_torch_scripting() or not enabled: return cls # check that we are a dataclass, raise an error if not diff --git a/dltype/tests/dltype_test.py b/dltype/tests/dltype_test.py index 7d3325a..99a3a9c 100644 --- a/dltype/tests/dltype_test.py +++ b/dltype/tests/dltype_test.py @@ -1647,3 +1647,38 @@ class JaxNamedTuple(NamedTuple): with pytest.raises(dltype.DLTypeDtypeError): JaxNamedTuple(arr1=jax.numpy.zeros((4, 3, 2)), arr2=jax.numpy.zeros((4, 3), dtype=np.uint8)) + + +@pytest.mark.parametrize("enabled", [True, False]) +def test_disabling(enabled: bool) -> None: + @dltype.dltyped(enabled=enabled) + def checked(arg: Annotated[NPFloatArrayT, dltype.FloatTensor["b c h w"]]) -> None: + pass + + @dltype.dltyped_dataclass(enabled=enabled) + @dataclass + class Checked: + arg: Annotated[NPFloatArrayT, dltype.FloatTensor["b c h w"]] + + @dltype.dltyped_namedtuple(enabled=enabled) + class CheckedNT(NamedTuple): + arg: Annotated[NPFloatArrayT, dltype.FloatTensor["b c h w"]] + + bad_arr = np.zeros((1, 2, 3), dtype=np.float32) + good_arr = np.zeros((1, 2, 3, 4), dtype=np.float32) + + checked(good_arr) + Checked(arg=good_arr) + CheckedNT(arg=good_arr) + + if enabled: + with pytest.raises(dltype.DLTypeNDimsError): + checked(bad_arr) + with pytest.raises(dltype.DLTypeNDimsError): + Checked(arg=bad_arr) + with pytest.raises(dltype.DLTypeNDimsError): + CheckedNT(arg=bad_arr) + else: + checked(bad_arr) + Checked(arg=bad_arr) + CheckedNT(arg=bad_arr) diff --git a/pyproject.toml b/pyproject.toml index ecfda41..973a740 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ dev = [ [project] dependencies = [ 'pydantic>=2.0;python_version<"3.14"', - 'pydantic>=2.12;python_version>="3.14"' + 'pydantic>=2.12;python_version>="3.14"', + "pydantic-settings>=2.12.0" ] description = "An extremely lightweight typing library for torch tensors or numpy arrays. Supports runtime shape checking and data type validation." keywords = ["pytorch", "numpy", "shape check", "type check"] @@ -24,7 +25,7 @@ license-files = ["LICENSE"] name = "dltype" readme = "README.md" requires-python = ">=3.10" -version = "0.9.1" +version = "0.10.0" [project.optional-dependencies] jax = ["jax>=0.6.2"] diff --git a/uv.lock b/uv.lock index 45371f2..049da9b 100644 --- a/uv.lock +++ b/uv.lock @@ -158,10 +158,11 @@ wheels = [ [[package]] name = "dltype" -version = "0.9.1" +version = "0.10.0" source = { virtual = "." } dependencies = [ { name = "pydantic" }, + { name = "pydantic-settings" }, ] [package.optional-dependencies] @@ -199,6 +200,7 @@ requires-dist = [ { name = "numpy", marker = "extra == 'numpy'" }, { name = "pydantic", marker = "python_full_version < '3.14'", specifier = ">=2.0" }, { name = "pydantic", marker = "python_full_version >= '3.14'", specifier = ">=2.12" }, + { name = "pydantic-settings", specifier = ">=2.12.0" }, { name = "torch", marker = "extra == 'torch'", specifier = ">=1.11.0" }, ] provides-extras = ["jax", "numpy", "torch"] @@ -1088,6 +1090,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" }, ] +[[package]] +name = "pydantic-settings" +version = "2.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -1142,6 +1158,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, ] +[[package]] +name = "python-dotenv" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, +] + [[package]] name = "ruff" version = "0.14.14"