From da2d60119221d7e9aef0f18114c7fd2bd5f1bc83 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 10 Feb 2026 11:15:09 +0100 Subject: [PATCH 1/3] add deserialize_spectro_or_ltas_dataset() helper function --- src/osekit/public_api/export_analysis.py | 3 ++- src/osekit/utils/core_utils.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/osekit/public_api/export_analysis.py b/src/osekit/public_api/export_analysis.py index 05de0859..709f2492 100644 --- a/src/osekit/public_api/export_analysis.py +++ b/src/osekit/public_api/export_analysis.py @@ -13,6 +13,7 @@ from osekit.core_api.spectro_dataset import SpectroDataset from osekit.public_api.analysis import AnalysisType from osekit.public_api.dataset import Dataset +from osekit.utils.core_utils import deserialize_spectro_or_ltas_dataset def write_analysis( @@ -314,7 +315,7 @@ def main() -> None: else None ) sds = ( - SpectroDataset.from_json(Path(args.sds_json)) + deserialize_spectro_or_ltas_dataset(path=Path(args.sds_json)) if args.sds_json.lower() != "none" else None ) diff --git a/src/osekit/utils/core_utils.py b/src/osekit/utils/core_utils.py index 3666c1ce..2a5d840f 100644 --- a/src/osekit/utils/core_utils.py +++ b/src/osekit/utils/core_utils.py @@ -9,6 +9,8 @@ from typing import TYPE_CHECKING from osekit.config import global_logging_context as glc +from osekit.core_api.ltas_dataset import LTASDataset +from osekit.core_api.spectro_dataset import SpectroDataset if TYPE_CHECKING: from pathlib import Path @@ -228,3 +230,24 @@ def get_closest_value_index(target: float, values: list[float]) -> int: (closest_lower_index, closest_upper_index), key=lambda i: abs(values[i] - target), ) + + +def deserialize_spectro_or_ltas_dataset(path: Path) -> SpectroDataset | LTASDataset: + """Return a ``LTASDataset`` or a ``SpectroDataset`` from the specified json file. + + Parameters + ---------- + path: Path + Path to the json file. + + Returns + ------- + SpectroDataset | LTASDataset + The deserialized ``LTASDataset`` if ``nb_time_bins`` is set to an integer, else + the deserialized ``SpectroDataset``. + + """ + try: + return LTASDataset.from_json(file=path) + except KeyError: + return SpectroDataset.from_json(file=path) From 75e894db899ff2f086feb5823137f49ed52f6478 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 10 Feb 2026 11:30:07 +0100 Subject: [PATCH 2/3] patch deserialize_spectro_or_ltas_dataset() function in export_analysis tests --- src/osekit/public_api/export_analysis.py | 7 +++++-- src/osekit/utils/core_utils.py | 23 ---------------------- src/osekit/utils/deserialization.py | 25 ++++++++++++++++++++++++ tests/test_export_analysis.py | 7 +++++-- 4 files changed, 35 insertions(+), 27 deletions(-) create mode 100644 src/osekit/utils/deserialization.py diff --git a/src/osekit/public_api/export_analysis.py b/src/osekit/public_api/export_analysis.py index 709f2492..c435743e 100644 --- a/src/osekit/public_api/export_analysis.py +++ b/src/osekit/public_api/export_analysis.py @@ -6,14 +6,17 @@ import logging import os from pathlib import Path +from typing import TYPE_CHECKING from osekit import config, setup_logging from osekit.config import global_logging_context as glc from osekit.core_api.audio_dataset import AudioDataset -from osekit.core_api.spectro_dataset import SpectroDataset from osekit.public_api.analysis import AnalysisType from osekit.public_api.dataset import Dataset -from osekit.utils.core_utils import deserialize_spectro_or_ltas_dataset +from osekit.utils.deserialization import deserialize_spectro_or_ltas_dataset + +if TYPE_CHECKING: + from osekit.core_api.spectro_dataset import SpectroDataset def write_analysis( diff --git a/src/osekit/utils/core_utils.py b/src/osekit/utils/core_utils.py index 2a5d840f..3666c1ce 100644 --- a/src/osekit/utils/core_utils.py +++ b/src/osekit/utils/core_utils.py @@ -9,8 +9,6 @@ from typing import TYPE_CHECKING from osekit.config import global_logging_context as glc -from osekit.core_api.ltas_dataset import LTASDataset -from osekit.core_api.spectro_dataset import SpectroDataset if TYPE_CHECKING: from pathlib import Path @@ -230,24 +228,3 @@ def get_closest_value_index(target: float, values: list[float]) -> int: (closest_lower_index, closest_upper_index), key=lambda i: abs(values[i] - target), ) - - -def deserialize_spectro_or_ltas_dataset(path: Path) -> SpectroDataset | LTASDataset: - """Return a ``LTASDataset`` or a ``SpectroDataset`` from the specified json file. - - Parameters - ---------- - path: Path - Path to the json file. - - Returns - ------- - SpectroDataset | LTASDataset - The deserialized ``LTASDataset`` if ``nb_time_bins`` is set to an integer, else - the deserialized ``SpectroDataset``. - - """ - try: - return LTASDataset.from_json(file=path) - except KeyError: - return SpectroDataset.from_json(file=path) diff --git a/src/osekit/utils/deserialization.py b/src/osekit/utils/deserialization.py new file mode 100644 index 00000000..37957f24 --- /dev/null +++ b/src/osekit/utils/deserialization.py @@ -0,0 +1,25 @@ +from pathlib import Path + +from osekit.core_api.ltas_dataset import LTASDataset +from osekit.core_api.spectro_dataset import SpectroDataset + + +def deserialize_spectro_or_ltas_dataset(path: Path) -> SpectroDataset | LTASDataset: + """Return a ``LTASDataset`` or a ``SpectroDataset`` from the specified json file. + + Parameters + ---------- + path: Path + Path to the json file. + + Returns + ------- + SpectroDataset | LTASDataset + The deserialized ``LTASDataset`` if ``nb_time_bins`` is set to an integer, else + the deserialized ``SpectroDataset``. + + """ + try: + return LTASDataset.from_json(file=path) + except KeyError: + return SpectroDataset.from_json(file=path) diff --git a/tests/test_export_analysis.py b/tests/test_export_analysis.py index 7b606ee2..de3655f1 100644 --- a/tests/test_export_analysis.py +++ b/tests/test_export_analysis.py @@ -8,7 +8,6 @@ from osekit import config from osekit.core_api.audio_dataset import AudioDataset -from osekit.core_api.spectro_dataset import SpectroDataset from osekit.public_api import export_analysis from osekit.public_api.export_analysis import create_parser from osekit.utils.job import Job @@ -164,7 +163,11 @@ def mock_sds_json(path: Path) -> Path: calls["sds_json"] = path return path - monkeypatch.setattr(SpectroDataset, "from_json", mock_sds_json) + monkeypatch.setattr( + export_analysis, + "deserialize_spectro_or_ltas_dataset", + mock_sds_json, + ) def mock_write_analysis(*args: list, **kwargs: dict) -> None: for k, v in kwargs.items(): From 9a64e640eaa9e71c17d62001cc4374379fc79825 Mon Sep 17 00:00:00 2001 From: Gautzilla Date: Tue, 10 Feb 2026 11:39:18 +0100 Subject: [PATCH 3/3] add deserialize_spectro_or_ltas_dataset() test --- tests/test_utils.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index a9b32de8..02f8bcda 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,12 +3,15 @@ import time from contextlib import nullcontext from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import pandas as pd import pytest from pandas import Timedelta +from osekit.core_api.ltas_dataset import LTASDataset +from osekit.core_api.spectro_dataset import SpectroDataset from osekit.utils.audio_utils import Normalization, normalize from osekit.utils.core_utils import ( file_indexes_per_batch, @@ -16,9 +19,13 @@ locked, nb_files_per_batch, ) +from osekit.utils.deserialization import deserialize_spectro_or_ltas_dataset from osekit.utils.formatting_utils import aplose2raven from osekit.utils.path_utils import is_absolute, move_tree +if TYPE_CHECKING: + from _pytest.monkeypatch import MonkeyPatch + @pytest.fixture def aplose_dataframe() -> pd.DataFrame: @@ -546,3 +553,24 @@ def test_normalization( ) -> None: normalized = normalize(values=values, normalization=normalization) assert np.array_equal(normalized, expected) + + +def test_deserialize_spectro_or_ltas_dataset(monkeypatch: MonkeyPatch) -> None: + def sds_deserialization(*args, **kwargs) -> str: # noqa: ANN002, ANN003 + return "SpectroDataset" + + def ltasds_deserialization(*args, **kwargs) -> str: # noqa: ANN002, ANN003 + return "LTASDataset" + + monkeypatch.setattr(SpectroDataset, "from_json", sds_deserialization) + monkeypatch.setattr(LTASDataset, "from_json", ltasds_deserialization) + + assert deserialize_spectro_or_ltas_dataset(path=Path()) == "LTASDataset" + + def raise_key_error(*args, **kwargs) -> None: # noqa: ANN002, ANN003 + msg = "'nb_time_bins'" + raise KeyError(msg) + + monkeypatch.setattr(LTASDataset, "from_json", raise_key_error) + + assert deserialize_spectro_or_ltas_dataset(path=Path()) == "SpectroDataset"