Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/osekit/public_api/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,19 @@ def __init__(
representing the maximum number of averaged time bins.

"""
self._validate_sample_rate(sample_rate=sample_rate, fft=fft)

self.analysis_type = analysis_type
self.begin = begin
self.end = end
self.data_duration = data_duration
self.mode = mode
self.overlap = overlap
self.fft = fft
self.sample_rate = sample_rate
self.name = name
self.normalization = normalization
self.subtype = subtype
self.fft = fft
self.v_lim = v_lim
self.colormap = colormap
self.scale = scale
Expand All @@ -179,3 +181,45 @@ def is_spectro(self) -> bool:
AnalysisType.WELCH,
)
)

@property
def sample_rate(self) -> float | None:
"""Return the sample rate of the analysis."""
return self._sample_rate

@sample_rate.setter
def sample_rate(self, value: float | None) -> None:
"""Set the sample rate of the analysis."""
if self.fft is not None and value is not None:
self.fft.fs = value
self._sample_rate = value

@property
def fft(self) -> ShortTimeFFT | None:
"""Return the FFT used in the analysis."""
return self._fft

@fft.setter
def fft(self, value: ShortTimeFFT | None) -> None:
"""Set the FFT used in the analysis."""
if hasattr(self, "_sample_rate"):
self._validate_sample_rate(sample_rate=self.sample_rate, fft=value)
self._fft = value

@staticmethod
def _validate_sample_rate(
sample_rate: float | None,
fft: ShortTimeFFT | None,
) -> None:
if sample_rate is None:
return
if fft is None:
return
if fft.fs == sample_rate:
return
msg = (
rf"The sample rate of the analysis ({sample_rate} Hz) "
rf"does not match the sampling frequency of the "
rf"fft ({fft.fs} Hz)"
)
raise ValueError(msg)
88 changes: 88 additions & 0 deletions tests/test_public_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import itertools
from contextlib import nullcontext, AbstractContextManager
from copy import deepcopy
from pathlib import Path

Expand Down Expand Up @@ -681,6 +682,93 @@ def test_analysis_is_spectro(analysis: Analysis, expected: bool) -> None:
assert analysis.is_spectro is expected


def test_analysis_constructor_rejects_mismatched_fs() -> None:
with pytest.raises(
ValueError,
match="does not match",
):
Analysis(
analysis_type=AnalysisType.SPECTROGRAM,
sample_rate=32_000,
fft=ShortTimeFFT(hamming(1024), 1024, 48_000),
)


def test_analysis_rejects_setting_fft_with_wrong_fs() -> None:
analysis = Analysis(
analysis_type=AnalysisType.SPECTROGRAM,
sample_rate=48_000,
fft=ShortTimeFFT(hamming(1024), 1024, 48_000),
)

with pytest.raises(
ValueError,
match="does not match",
):
analysis.fft = ShortTimeFFT(hamming(1024), 1024, 32_000)


def test_analysis_sample_rate_propagates_to_fft() -> None:
analysis = Analysis(
analysis_type=AnalysisType.SPECTROGRAM,
sample_rate=48_000,
fft=ShortTimeFFT(hamming(1024), 1024, 48_000),
)

new_samplerate = 32_000
analysis.sample_rate = new_samplerate

assert analysis.sample_rate == new_samplerate
assert analysis.fft.fs == new_samplerate


@pytest.mark.parametrize(
("sample_rate", "fft", "expected"),
[
pytest.param(
None,
None,
nullcontext(),
id="no_sample_rate_nor_fft_shouldnt_raise",
),
pytest.param(
None,
ShortTimeFFT(hamming(1024), 1024, 48_000),
nullcontext(),
id="no_sample_rate_shouldnt_raise",
),
pytest.param(
48_000,
None,
nullcontext(),
id="no_fft_shouldnt_raise",
),
pytest.param(
48_000,
ShortTimeFFT(hamming(1024), 1024, 48_000),
nullcontext(),
id="matching_sample_rate_and_fft_shouldnt_raise",
),
pytest.param(
32_000,
ShortTimeFFT(hamming(1024), 1024, 48_000),
pytest.raises(ValueError, match="does not match"),
id="mismatching_sample_rate_and_fft_raises",
),
],
)
def test_analysis_validate_sample_rate(
sample_rate: float | None,
fft: ShortTimeFFT | None,
expected: AbstractContextManager,
) -> None:
with expected:
Analysis(AnalysisType.AUDIO)._validate_sample_rate(
sample_rate=sample_rate,
fft=fft,
)


@pytest.mark.parametrize(
("audio_files", "instrument", "analysis", "expected_data"),
[
Expand Down