diff --git a/src/osekit/public_api/export_analysis.py b/src/osekit/public_api/export_analysis.py index 05de0859..c435743e 100644 --- a/src/osekit/public_api/export_analysis.py +++ b/src/osekit/public_api/export_analysis.py @@ -6,13 +6,17 @@ import logging import os from pathlib import Path +from typing import TYPE_CHECKING from osekit import config, setup_logging from osekit.config import global_logging_context as glc from osekit.core_api.audio_dataset import AudioDataset -from osekit.core_api.spectro_dataset import SpectroDataset from osekit.public_api.analysis import AnalysisType from osekit.public_api.dataset import Dataset +from osekit.utils.deserialization import deserialize_spectro_or_ltas_dataset + +if TYPE_CHECKING: + from osekit.core_api.spectro_dataset import SpectroDataset def write_analysis( @@ -314,7 +318,7 @@ def main() -> None: else None ) sds = ( - SpectroDataset.from_json(Path(args.sds_json)) + deserialize_spectro_or_ltas_dataset(path=Path(args.sds_json)) if args.sds_json.lower() != "none" else None ) diff --git a/src/osekit/utils/deserialization.py b/src/osekit/utils/deserialization.py new file mode 100644 index 00000000..37957f24 --- /dev/null +++ b/src/osekit/utils/deserialization.py @@ -0,0 +1,25 @@ +from pathlib import Path + +from osekit.core_api.ltas_dataset import LTASDataset +from osekit.core_api.spectro_dataset import SpectroDataset + + +def deserialize_spectro_or_ltas_dataset(path: Path) -> SpectroDataset | LTASDataset: + """Return a ``LTASDataset`` or a ``SpectroDataset`` from the specified json file. + + Parameters + ---------- + path: Path + Path to the json file. + + Returns + ------- + SpectroDataset | LTASDataset + The deserialized ``LTASDataset`` if ``nb_time_bins`` is set to an integer, else + the deserialized ``SpectroDataset``. + + """ + try: + return LTASDataset.from_json(file=path) + except KeyError: + return SpectroDataset.from_json(file=path) diff --git a/tests/test_export_analysis.py b/tests/test_export_analysis.py index 7b606ee2..de3655f1 100644 --- a/tests/test_export_analysis.py +++ b/tests/test_export_analysis.py @@ -8,7 +8,6 @@ from osekit import config from osekit.core_api.audio_dataset import AudioDataset -from osekit.core_api.spectro_dataset import SpectroDataset from osekit.public_api import export_analysis from osekit.public_api.export_analysis import create_parser from osekit.utils.job import Job @@ -164,7 +163,11 @@ def mock_sds_json(path: Path) -> Path: calls["sds_json"] = path return path - monkeypatch.setattr(SpectroDataset, "from_json", mock_sds_json) + monkeypatch.setattr( + export_analysis, + "deserialize_spectro_or_ltas_dataset", + mock_sds_json, + ) def mock_write_analysis(*args: list, **kwargs: dict) -> None: for k, v in kwargs.items(): diff --git a/tests/test_utils.py b/tests/test_utils.py index a9b32de8..02f8bcda 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,12 +3,15 @@ import time from contextlib import nullcontext from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import pandas as pd import pytest from pandas import Timedelta +from osekit.core_api.ltas_dataset import LTASDataset +from osekit.core_api.spectro_dataset import SpectroDataset from osekit.utils.audio_utils import Normalization, normalize from osekit.utils.core_utils import ( file_indexes_per_batch, @@ -16,9 +19,13 @@ locked, nb_files_per_batch, ) +from osekit.utils.deserialization import deserialize_spectro_or_ltas_dataset from osekit.utils.formatting_utils import aplose2raven from osekit.utils.path_utils import is_absolute, move_tree +if TYPE_CHECKING: + from _pytest.monkeypatch import MonkeyPatch + @pytest.fixture def aplose_dataframe() -> pd.DataFrame: @@ -546,3 +553,24 @@ def test_normalization( ) -> None: normalized = normalize(values=values, normalization=normalization) assert np.array_equal(normalized, expected) + + +def test_deserialize_spectro_or_ltas_dataset(monkeypatch: MonkeyPatch) -> None: + def sds_deserialization(*args, **kwargs) -> str: # noqa: ANN002, ANN003 + return "SpectroDataset" + + def ltasds_deserialization(*args, **kwargs) -> str: # noqa: ANN002, ANN003 + return "LTASDataset" + + monkeypatch.setattr(SpectroDataset, "from_json", sds_deserialization) + monkeypatch.setattr(LTASDataset, "from_json", ltasds_deserialization) + + assert deserialize_spectro_or_ltas_dataset(path=Path()) == "LTASDataset" + + def raise_key_error(*args, **kwargs) -> None: # noqa: ANN002, ANN003 + msg = "'nb_time_bins'" + raise KeyError(msg) + + monkeypatch.setattr(LTASDataset, "from_json", raise_key_error) + + assert deserialize_spectro_or_ltas_dataset(path=Path()) == "SpectroDataset"