Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/osekit/public_api/export_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING

from osekit import config, setup_logging
from osekit.config import global_logging_context as glc
from osekit.core_api.audio_dataset import AudioDataset
from osekit.core_api.spectro_dataset import SpectroDataset
from osekit.public_api.analysis import AnalysisType
from osekit.public_api.dataset import Dataset
from osekit.utils.deserialization import deserialize_spectro_or_ltas_dataset

if TYPE_CHECKING:
from osekit.core_api.spectro_dataset import SpectroDataset


def write_analysis(
Expand Down Expand Up @@ -314,7 +318,7 @@ def main() -> None:
else None
)
sds = (
SpectroDataset.from_json(Path(args.sds_json))
deserialize_spectro_or_ltas_dataset(path=Path(args.sds_json))
if args.sds_json.lower() != "none"
else None
)
Expand Down
25 changes: 25 additions & 0 deletions src/osekit/utils/deserialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from pathlib import Path

from osekit.core_api.ltas_dataset import LTASDataset
from osekit.core_api.spectro_dataset import SpectroDataset


def deserialize_spectro_or_ltas_dataset(path: Path) -> SpectroDataset | LTASDataset:
"""Return a ``LTASDataset`` or a ``SpectroDataset`` from the specified json file.

Parameters
----------
path: Path
Path to the json file.

Returns
-------
SpectroDataset | LTASDataset
The deserialized ``LTASDataset`` if ``nb_time_bins`` is set to an integer, else
the deserialized ``SpectroDataset``.

"""
try:
return LTASDataset.from_json(file=path)
except KeyError:
return SpectroDataset.from_json(file=path)
7 changes: 5 additions & 2 deletions tests/test_export_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from osekit import config
from osekit.core_api.audio_dataset import AudioDataset
from osekit.core_api.spectro_dataset import SpectroDataset
from osekit.public_api import export_analysis
from osekit.public_api.export_analysis import create_parser
from osekit.utils.job import Job
Expand Down Expand Up @@ -164,7 +163,11 @@ def mock_sds_json(path: Path) -> Path:
calls["sds_json"] = path
return path

monkeypatch.setattr(SpectroDataset, "from_json", mock_sds_json)
monkeypatch.setattr(
export_analysis,
"deserialize_spectro_or_ltas_dataset",
mock_sds_json,
)

def mock_write_analysis(*args: list, **kwargs: dict) -> None:
for k, v in kwargs.items():
Expand Down
28 changes: 28 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,29 @@
import time
from contextlib import nullcontext
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import pytest
from pandas import Timedelta

from osekit.core_api.ltas_dataset import LTASDataset
from osekit.core_api.spectro_dataset import SpectroDataset
from osekit.utils.audio_utils import Normalization, normalize
from osekit.utils.core_utils import (
file_indexes_per_batch,
get_closest_value_index,
locked,
nb_files_per_batch,
)
from osekit.utils.deserialization import deserialize_spectro_or_ltas_dataset
from osekit.utils.formatting_utils import aplose2raven
from osekit.utils.path_utils import is_absolute, move_tree

if TYPE_CHECKING:
from _pytest.monkeypatch import MonkeyPatch


@pytest.fixture
def aplose_dataframe() -> pd.DataFrame:
Expand Down Expand Up @@ -546,3 +553,24 @@ def test_normalization(
) -> None:
normalized = normalize(values=values, normalization=normalization)
assert np.array_equal(normalized, expected)


def test_deserialize_spectro_or_ltas_dataset(monkeypatch: MonkeyPatch) -> None:
def sds_deserialization(*args, **kwargs) -> str: # noqa: ANN002, ANN003
return "SpectroDataset"

def ltasds_deserialization(*args, **kwargs) -> str: # noqa: ANN002, ANN003
return "LTASDataset"

monkeypatch.setattr(SpectroDataset, "from_json", sds_deserialization)
monkeypatch.setattr(LTASDataset, "from_json", ltasds_deserialization)

assert deserialize_spectro_or_ltas_dataset(path=Path()) == "LTASDataset"

def raise_key_error(*args, **kwargs) -> None: # noqa: ANN002, ANN003
msg = "'nb_time_bins'"
raise KeyError(msg)

monkeypatch.setattr(LTASDataset, "from_json", raise_key_error)

assert deserialize_spectro_or_ltas_dataset(path=Path()) == "SpectroDataset"