diff --git a/src/osekit/public_api/analysis.py b/src/osekit/public_api/analysis.py index 0a71ce81..2d51876c 100644 --- a/src/osekit/public_api/analysis.py +++ b/src/osekit/public_api/analysis.py @@ -148,17 +148,19 @@ def __init__( representing the maximum number of averaged time bins. """ + self._validate_sample_rate(sample_rate=sample_rate, fft=fft) + self.analysis_type = analysis_type self.begin = begin self.end = end self.data_duration = data_duration self.mode = mode self.overlap = overlap + self.fft = fft self.sample_rate = sample_rate self.name = name self.normalization = normalization self.subtype = subtype - self.fft = fft self.v_lim = v_lim self.colormap = colormap self.scale = scale @@ -179,3 +181,45 @@ def is_spectro(self) -> bool: AnalysisType.WELCH, ) ) + + @property + def sample_rate(self) -> float | None: + """Return the sample rate of the analysis.""" + return self._sample_rate + + @sample_rate.setter + def sample_rate(self, value: float | None) -> None: + """Set the sample rate of the analysis.""" + if self.fft is not None and value is not None: + self.fft.fs = value + self._sample_rate = value + + @property + def fft(self) -> ShortTimeFFT | None: + """Return the FFT used in the analysis.""" + return self._fft + + @fft.setter + def fft(self, value: ShortTimeFFT | None) -> None: + """Set the FFT used in the analysis.""" + if hasattr(self, "_sample_rate"): + self._validate_sample_rate(sample_rate=self.sample_rate, fft=value) + self._fft = value + + @staticmethod + def _validate_sample_rate( + sample_rate: float | None, + fft: ShortTimeFFT | None, + ) -> None: + if sample_rate is None: + return + if fft is None: + return + if fft.fs == sample_rate: + return + msg = ( + rf"The sample rate of the analysis ({sample_rate} Hz) " + rf"does not match the sampling frequency of the " + rf"fft ({fft.fs} Hz)" + ) + raise ValueError(msg) diff --git a/tests/test_public_api.py b/tests/test_public_api.py index 999ac4b7..1a11b5c5 100644 --- a/tests/test_public_api.py +++ b/tests/test_public_api.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +from contextlib import nullcontext, AbstractContextManager from copy import deepcopy from pathlib import Path @@ -681,6 +682,93 @@ def test_analysis_is_spectro(analysis: Analysis, expected: bool) -> None: assert analysis.is_spectro is expected +def test_analysis_constructor_rejects_mismatched_fs() -> None: + with pytest.raises( + ValueError, + match="does not match", + ): + Analysis( + analysis_type=AnalysisType.SPECTROGRAM, + sample_rate=32_000, + fft=ShortTimeFFT(hamming(1024), 1024, 48_000), + ) + + +def test_analysis_rejects_setting_fft_with_wrong_fs() -> None: + analysis = Analysis( + analysis_type=AnalysisType.SPECTROGRAM, + sample_rate=48_000, + fft=ShortTimeFFT(hamming(1024), 1024, 48_000), + ) + + with pytest.raises( + ValueError, + match="does not match", + ): + analysis.fft = ShortTimeFFT(hamming(1024), 1024, 32_000) + + +def test_analysis_sample_rate_propagates_to_fft() -> None: + analysis = Analysis( + analysis_type=AnalysisType.SPECTROGRAM, + sample_rate=48_000, + fft=ShortTimeFFT(hamming(1024), 1024, 48_000), + ) + + new_samplerate = 32_000 + analysis.sample_rate = new_samplerate + + assert analysis.sample_rate == new_samplerate + assert analysis.fft.fs == new_samplerate + + +@pytest.mark.parametrize( + ("sample_rate", "fft", "expected"), + [ + pytest.param( + None, + None, + nullcontext(), + id="no_sample_rate_nor_fft_shouldnt_raise", + ), + pytest.param( + None, + ShortTimeFFT(hamming(1024), 1024, 48_000), + nullcontext(), + id="no_sample_rate_shouldnt_raise", + ), + pytest.param( + 48_000, + None, + nullcontext(), + id="no_fft_shouldnt_raise", + ), + pytest.param( + 48_000, + ShortTimeFFT(hamming(1024), 1024, 48_000), + nullcontext(), + id="matching_sample_rate_and_fft_shouldnt_raise", + ), + pytest.param( + 32_000, + ShortTimeFFT(hamming(1024), 1024, 48_000), + pytest.raises(ValueError, match="does not match"), + id="mismatching_sample_rate_and_fft_raises", + ), + ], +) +def test_analysis_validate_sample_rate( + sample_rate: float | None, + fft: ShortTimeFFT | None, + expected: AbstractContextManager, +) -> None: + with expected: + Analysis(AnalysisType.AUDIO)._validate_sample_rate( + sample_rate=sample_rate, + fft=fft, + ) + + @pytest.mark.parametrize( ("audio_files", "instrument", "analysis", "expected_data"), [