diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6c58fb4..2ccab0a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,12 +15,20 @@ jobs: - uses: actions/setup-python@v5 with: python-version: "3.10" + cache: pip - name: Install dependencies - run: pip install -e ".[dev]" soundfile + run: | + pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install -e ".[dev]" soundfile - name: Lint run: ruff check src/ + - name: Install model weights + run: | + mkdir -p ~/.local/share/deeprhythm + cp weights/deeprhythm-0.7.pth ~/.local/share/deeprhythm/ + - name: Test run: pytest diff --git a/.gitignore b/.gitignore index 3269679..3d2246f 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ dist/ *.csv *.pb .workspace -.venv \ No newline at end of file +.venv +.DS_Store diff --git a/pyproject.toml b/pyproject.toml index a2b1a81..83f8b45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,3 +41,6 @@ select = ["E", "F", "I"] [tool.pytest.ini_options] testpaths = ["tests"] +markers = [ + "slow: tests that require model weights (deselect with '-m \"not slow\"')", +] diff --git a/src/deeprhythm/audio_proc/bandfilter.py b/src/deeprhythm/audio_proc/bandfilter.py index 2a312cd..d6aeefd 100644 --- a/src/deeprhythm/audio_proc/bandfilter.py +++ b/src/deeprhythm/audio_proc/bandfilter.py @@ -4,25 +4,21 @@ def create_log_filter(num_bins, num_bands, device='cuda'): """ - Create a logarithmically spaced filter matrix for audio processing. - - This function generates a filter matrix with logarithmically spaced bands. The filters have - unity gain, meaning that the sum of the filter coefficients in each band is equal to one. + Create a logarithmically spaced filter matrix. Parameters ---------- num_bins : int - The number of bins in the spectrogram (e.g., the number of frequency bins). + Number of frequency bins in the spectrogram num_bands : int - The number of bands for the filter matrix. 
These bands are spaced logarithmically. + Number of logarithmically spaced bands device : str, optional - The device on which the filter matrix will be created. + Target device for the filter matrix Returns ------- torch.Tensor - A tensor representing the filter matrix with shape (num_bands, num_bins). Each row - corresponds to a filter for a specific band. + Filter matrix of shape (num_bands, num_bins) """ log_bins = np.logspace(np.log10(1), np.log10(num_bins), num=num_bands+1, base=10.0) - 1 log_bins = np.unique(np.round(log_bins).astype(int)) @@ -38,25 +34,19 @@ def create_log_filter(num_bins, num_bands, device='cuda'): def apply_log_filter(stft_output, filter_matrix): """ - Apply the logarithmic filter matrix to the Short-Time Fourier Transform (STFT) output. - - This function applies a precomputed logarithmic filter matrix to the STFT output of an audio signal - to reduce its dimensionality and to capture the energy in logarithmically spaced frequency bands. + Apply logarithmic filter matrix to STFT output. Parameters ---------- stft_output : torch.Tensor - A tensor representing the STFT output with shape (batch_size, num_bins, num_frames), where - num_bins is the number of frequency bins and num_frames is the number of time frames. + STFT output of shape (batch_size, num_bins, num_frames) filter_matrix : torch.Tensor - A tensor representing the logarithmic filter matrix with shape (num_bands, num_bins), where - num_bands is the number of logarithmically spaced frequency bands. + Filter matrix of shape (num_bands, num_bins) Returns ------- torch.Tensor - A tensor representing the filtered STFT output with shape (batch_size, num_bands, num_frames). - Each band contains the aggregated energy from the corresponding set of frequency bins. 
+ Filtered output of shape (batch_size, num_bands, num_frames) """ stft_output_transposed = stft_output.transpose(1, 2) filtered_output_transposed = torch.matmul(stft_output_transposed, filter_matrix.T) diff --git a/src/deeprhythm/audio_proc/onset.py b/src/deeprhythm/audio_proc/onset.py index 1856ae5..517aea1 100644 --- a/src/deeprhythm/audio_proc/onset.py +++ b/src/deeprhythm/audio_proc/onset.py @@ -5,7 +5,7 @@ def onset_strength( y=None, n_fft=2048, hop_length=512, lag=1, ref=None, - detrend=False, center=True, aggregate=None + detrend=False, center=True, aggregate=None ): """ Compute the onset strength of an audio signal or a spectrogram. @@ -50,8 +50,8 @@ def onset_strength( # Compute difference to reference, spaced by lag onset_env = S[..., lag:] - ref[..., :-lag] - onset_env = torch.clamp(onset_env, min=0.0) # Discard negatives - + onset_env = torch.clamp(onset_env, min=0.0) + if aggregate is None: aggregate = torch.mean if callable(aggregate): diff --git a/src/deeprhythm/batch_infer.py b/src/deeprhythm/batch_infer.py index 945ab4a..8d3edb7 100644 --- a/src/deeprhythm/batch_infer.py +++ b/src/deeprhythm/batch_infer.py @@ -8,8 +8,15 @@ import torch.multiprocessing as multiprocessing from deeprhythm.audio_proc.hcqm import compute_hcqm, make_kernels -from deeprhythm.model.predictor import load_cnn_model -from deeprhythm.utils import AudioLoadError, AudioTooShortError, class_to_bpm, get_device, load_and_split_audio +from deeprhythm.model.frame_cnn import DeepRhythmModel +from deeprhythm.utils import ( + AudioLoadError, + AudioTooShortError, + class_to_bpm, + get_device, + get_weights, + load_and_split_audio, +) NUM_WORKERS = 8 NUM_BATCH = 128 @@ -112,7 +119,9 @@ def consume_and_process( specs = make_kernels(len_audio, sr, device=device) if not quiet: print('made kernels') - model = load_cnn_model(device=device, quiet=quiet) + model = DeepRhythmModel() + model.load_state_dict(torch.load(get_weights(quiet=quiet), map_location=torch.device(device), 
weights_only=False)) + model = model.to(device=device) model.eval() if not quiet: print('loaded model') diff --git a/src/deeprhythm/bench/__init__.py b/src/deeprhythm/bench/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/deeprhythm/model/predictor.py b/src/deeprhythm/model/predictor.py index e54423d..c616895 100644 --- a/src/deeprhythm/model/predictor.py +++ b/src/deeprhythm/model/predictor.py @@ -1,117 +1,134 @@ import json import os import tempfile +from typing import Dict, List, Optional, Tuple, Union import torch +from torch import Tensor from deeprhythm.audio_proc.hcqm import compute_hcqm, make_kernels +from deeprhythm.batch_infer import get_audio_files +from deeprhythm.batch_infer import main as batch_infer_main from deeprhythm.model.frame_cnn import DeepRhythmModel from deeprhythm.utils import class_to_bpm, get_device, get_weights, load_and_split_audio, split_audio -def load_cnn_model(path='deeprhythm-0.7.pth', device=None, quiet=False): - model = DeepRhythmModel() - if device is None: - device = get_device() - if not os.path.exists(path): - path = get_weights(quiet=quiet) - model.load_state_dict(torch.load(path, map_location=torch.device(device), weights_only=True)) - model = model.to(device=device) - model.eval() - return model +class DeepRhythmPredictor: + def __init__(self, device: Optional[str] = None, quiet: bool = False): + """Initialize the DeepRhythm BPM predictor. + Args: + model_path: Path to the model weights file + device: Device to run inference on ('cuda' or 'cpu') + quiet: Whether to suppress progress messages + """ + self.model_path = get_weights(quiet=quiet) + self.device = torch.device(device if device else get_device()) + self.model = self._load_model() + self.specs = self._make_kernels() -class DeepRhythmPredictor: - """ - DeepRhythm tempo prediction model. - - Args: - model_path (str, optional): Path to a custom model weights file (.pth). 
- If None, automatically downloads the default model to ~/.local/share/deeprhythm/. - Defaults to None. - device (str, optional): Device to run inference on ('cpu', 'cuda', 'mps'). - If None, automatically selects best available device. - quiet (bool, optional): Suppress download progress messages. Defaults to False. - """ - def __init__(self, model_path=None, device=None, quiet=False): - if model_path is None: - self.model_path = get_weights(quiet=quiet) - else: - if not os.path.isfile(model_path): - raise FileNotFoundError( - f"Model file not found at: {model_path}\n" - f"Please provide a valid path to a model file, or use model_path=None " - f"to auto-download the default model." - ) - self.model_path = model_path - - if device is None: - self.device = get_device() - else: - self.device = torch.device(device) - self.model = self.load_model() - self.specs = self.make_kernels() - - def load_model(self): + def _load_model(self) -> DeepRhythmModel: + """Load and initialize the model.""" model = DeepRhythmModel() - model.load_state_dict(torch.load(self.model_path, map_location=self.device, weights_only=True)) + model.load_state_dict(torch.load(self.model_path, map_location=self.device, weights_only=False)) model = model.to(device=self.device) model.eval() return model - def make_kernels(self, device=None): - if device is None: - device = self.device - stft, band, cqt = make_kernels(device=device) - return stft, band, cqt + def _make_kernels(self) -> Tuple[Tensor, Tensor, Tensor]: + """Initialize HCQM computation kernels.""" + return make_kernels(device=self.device) - def predict(self, filename, include_confidence=False): - clips = load_and_split_audio(filename, sr=22050) - input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0, 3, 1, 2) - self.model.eval() - with torch.no_grad(): - input_batch = input_batch.to(device=self.device) - outputs = self.model(input_batch) - probabilities = torch.softmax(outputs, dim=1) - mean_probabilities = 
probabilities.mean(dim=0) - confidence_score, predicted_class = torch.max(mean_probabilities, 0) - predicted_global_bpm = class_to_bpm(predicted_class.item()) - if include_confidence: - return predicted_global_bpm, confidence_score.item() - return predicted_global_bpm + def _process_clips(self, clips: Tensor, include_confidence: bool = False) -> Union[float, Tuple[float, float]]: + """Process audio clips and return BPM prediction. - def predict_from_audio(self, audio, sr, include_confidence=False): - clips = split_audio(audio, sr) - input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0, 3, 1, 2) - self.model.eval() + Args: + clips: Tensor of audio clips + include_confidence: Whether to return confidence score + + Returns: + Predicted BPM or tuple of (BPM, confidence) + """ + input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0,3,1,2) + with torch.no_grad(): - input_batch = input_batch.to(device=self.device) - outputs = self.model(input_batch) + outputs = self.model(input_batch.to(device=self.device)) probabilities = torch.softmax(outputs, dim=1) mean_probabilities = probabilities.mean(dim=0) confidence_score, predicted_class = torch.max(mean_probabilities, 0) - predicted_global_bpm = class_to_bpm(predicted_class.item()) - if include_confidence: - return predicted_global_bpm, confidence_score.item() - return predicted_global_bpm + predicted_bpm = class_to_bpm(predicted_class.item()) + + return (predicted_bpm, confidence_score.item()) if include_confidence else predicted_bpm + + def predict(self, filename: str, include_confidence: bool = False) -> Union[float, Tuple[float, float]]: + """Predict BPM from an audio file. 
- def predict_batch(self, dirname, include_confidence=False, workers=8, batch=128, quiet=True): + Args: + filename: Path to the audio file + include_confidence: Whether to return confidence score + + Returns: + Predicted BPM or tuple of (BPM, confidence) """ - Predict BPM for all audio files in a directory using efficient batch processing. + clips = load_and_split_audio(filename, sr=22050) + return self._process_clips(clips, include_confidence) + + def predict_from_audio( + self, audio: List[float], sr: int, include_confidence: bool = False + ) -> Union[float, Tuple[float, float]]: + """Predict BPM from audio tensor. Args: - dirname: Directory containing audio files - include_confidence: Whether to include confidence scores in results + audio: Audio list + sr: Sample rate + include_confidence: Whether to return confidence score Returns: - dict: Mapping of filenames to their predicted BPMs (and optionally confidence scores) + Predicted BPM or tuple of (BPM, confidence) """ - from deeprhythm.batch_infer import get_audio_files - from deeprhythm.batch_infer import main as batch_infer_main + clips = split_audio(audio, sr) + return self._process_clips(clips, include_confidence) + + def predict_per_frame( + self, filename: str, include_confidence: bool = False + ) -> Union[List[float], Tuple[List[float], List[float]]]: + """Predict BPM for each frame in an audio file. 
+ + Args: + filename: Path to the audio file + include_confidence: Whether to return confidence scores + Returns: + List of BPMs or tuple of (BPMs, confidence scores) + """ + clips = load_and_split_audio(filename, sr=22050) + input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0,3,1,2) + + with torch.no_grad(): + outputs = self.model(input_batch.to(device=self.device)) + probabilities = torch.softmax(outputs, dim=1) + confidence_scores, predicted_classes = torch.max(probabilities, dim=1) + predicted_bpms = [class_to_bpm(cls.item()) for cls in predicted_classes] + + return (predicted_bpms, confidence_scores.tolist()) if include_confidence else predicted_bpms + + def predict_batch(self, dirname: str, include_confidence: bool = False, workers: int = 8, + batch: int = 128, quiet: bool = True) -> Dict[str, Union[float, Tuple[float, float]]]: + """Predict BPM for all audio files in a directory using efficient batch processing. + + Args: + dirname: Directory containing audio files + include_confidence: Whether to include confidence scores + workers: Number of worker processes + batch: Batch size for processing + quiet: Whether to suppress progress messages + + Returns: + Dictionary mapping filenames to predictions + """ with tempfile.NamedTemporaryFile(mode='w+', suffix='.jsonl', delete=False) as tmp_file: temp_path = tmp_file.name - + try: batch_infer_main( dataset=get_audio_files(dirname), @@ -122,34 +139,16 @@ def predict_batch(self, dirname, include_confidence=False, workers=8, batch=128, n_workers=workers, max_len_batch=batch ) - + results = {} with open(temp_path, 'r') as f: for line in f: result = json.loads(line.strip()) filename = result.pop('filename') - if include_confidence: - results[filename] = (result['bpm'], result['confidence']) - else: - results[filename] = result['bpm'] - + results[filename] = (result['bpm'], result['confidence']) if include_confidence else result['bpm'] + return results - + finally: if 
os.path.exists(temp_path): os.remove(temp_path) - - def predict_per_frame(self, filename, include_confidence=False): - clips = load_and_split_audio(filename, sr=22050) - input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0, 3, 1, 2) - self.model.eval() - with torch.no_grad(): - input_batch = input_batch.to(device=self.device) - outputs = self.model(input_batch) - probabilities = torch.softmax(outputs, dim=1) - confidence_scores, predicted_classes = torch.max(probabilities, dim=1) - predicted_bpms = [class_to_bpm(cls.item()) for cls in predicted_classes] - - if include_confidence: - return predicted_bpms, confidence_scores.tolist() - return predicted_bpms diff --git a/src/deeprhythm/utils.py b/src/deeprhythm/utils.py index 96840cf..9d425f5 100644 --- a/src/deeprhythm/utils.py +++ b/src/deeprhythm/utils.py @@ -4,7 +4,7 @@ import requests import torch -model_url = 'https://github.com/bleugreen/deeprhythm/raw/main/' +model_url = 'https://github.com/bleugreen/deeprhythm/raw/main/weights/' class AudioTooShortError(ValueError): diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..eea9a66 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,43 @@ +import numpy as np +import pytest +import soundfile as sf + + +def pytest_configure(config): + config.addinivalue_line("markers", "slow: tests that require model weights (deselect with '-m \"not slow\"')") + + +@pytest.fixture(scope="session") +def predictor(): + from deeprhythm.model.predictor import DeepRhythmPredictor + return DeepRhythmPredictor(device='cpu', quiet=True) + + +def make_click_track(bpm, duration=16, sr=22050): + """Generate periodic impulses at exact BPM intervals.""" + samples = int(sr * duration) + audio = np.zeros(samples, dtype=np.float32) + interval = 60.0 / bpm # seconds between clicks + interval_samples = int(interval * sr) + for i in range(0, samples, interval_samples): + end = min(i + int(sr * 0.01), samples) # 10ms click + audio[i:end] = 0.8 + 
return audio + + +def make_sine_wave(freq, duration=16, sr=22050): + """Generate a pure sine wave.""" + t = np.linspace(0, duration, int(sr * duration), dtype=np.float32) + return np.sin(2 * np.pi * freq * t).astype(np.float32) + + +def make_silence(duration=16, sr=22050): + """Generate silent audio.""" + return np.zeros(int(sr * duration), dtype=np.float32) + + +def write_wav(audio, sr, tmp_path, name="test.wav"): + """Write audio to a temp .wav file, returns path.""" + path = tmp_path / name + sf.write(str(path), audio, sr) + return str(path) diff --git a/tests/test_audio_proc.py b/tests/test_audio_proc.py new file mode 100644 index 0000000..7bbbaf1 --- /dev/null +++ b/tests/test_audio_proc.py @@ -0,0 +1,175 @@ +import torch + +from deeprhythm.audio_proc.bandfilter import apply_log_filter, create_log_filter +from deeprhythm.audio_proc.hcqm import compute_hcqm, make_kernels +from deeprhythm.audio_proc.onset import onset_strength + +# --------------------------------------------------------------------------- +# Band Filter (bandfilter.py) +# --------------------------------------------------------------------------- + +def test_create_log_filter_shape(): + """Output shape should be (num_bands, num_bins).""" + f = create_log_filter(1025, 8, device='cpu') + assert f.shape == (8, 1025) + + +def test_create_log_filter_unity_gain(): + """Each row (band) should sum to 1.0 — normalised by band width.""" + f = create_log_filter(1025, 8, device='cpu') + for i in range(8): + row_sum = f[i].sum().item() + assert abs(row_sum - 1.0) < 1e-5, f"Band {i} sums to {row_sum}, expected 1.0" + + +def test_create_log_filter_no_overlap(): + """Each frequency bin should belong to at most one band.""" + f = create_log_filter(1025, 8, device='cpu') + for col in range(1025): + nonzero = (f[:, col] != 0).sum().item() + assert nonzero <= 1, f"Bin {col} belongs to {nonzero} bands" + + +def test_create_log_filter_full_coverage(): + """Every frequency bin should belong to exactly one band (no 
gaps).""" + f = create_log_filter(1025, 8, device='cpu') + for col in range(1025): + nonzero = (f[:, col] != 0).sum().item() + assert nonzero == 1, f"Bin {col} belongs to {nonzero} bands (expected 1)" + + +def test_create_log_filter_device(): + """Filter should be on the requested device.""" + f = create_log_filter(1025, 8, device='cpu') + assert f.device.type == 'cpu' + + +def test_apply_log_filter_shape(): + """Input (2, 1025, 100) -> output (2, 8, 100).""" + f = create_log_filter(1025, 8, device='cpu') + stft = torch.randn(2, 1025, 100) + out = apply_log_filter(stft, f) + assert out.shape == (2, 8, 100) + + +def test_apply_log_filter_energy_routing(): + """Energy in a single bin should appear only in its expected band.""" + f = create_log_filter(1025, 8, device='cpu') + stft = torch.zeros(1, 1025, 10) + # Put energy in the last bin — should route to the last band + stft[0, 1024, :] = 1.0 + out = apply_log_filter(stft, f) + # Find which band has nonzero output + band_sums = out[0, :, 0] + nonzero_bands = (band_sums != 0).nonzero(as_tuple=True)[0] + assert len(nonzero_bands) == 1, f"Energy routed to {len(nonzero_bands)} bands" + + +def test_apply_log_filter_batch_independence(): + """Different batch items should produce independent results.""" + f = create_log_filter(1025, 8, device='cpu') + stft = torch.zeros(2, 1025, 10) + stft[0, 500, :] = 1.0 + stft[1, 100, :] = 1.0 + out = apply_log_filter(stft, f) + assert not torch.allclose(out[0], out[1]) + + +# --------------------------------------------------------------------------- +# Onset Strength (onset.py) +# --------------------------------------------------------------------------- + +def test_onset_strength_output_shape(): + """Output time dimension should match input time dimension.""" + batch, time_samples = 2, 1000 + y = torch.randn(batch, time_samples) + out = onset_strength(y=y) + # With center=True, output is trimmed to match S shape + assert out.dim() == 2 + assert out.shape[0] == batch + + +def 
test_onset_strength_silent_input(): + """All-zeros input should produce all-zeros onset (no energy increase).""" + y = torch.zeros(1, 4000) + out = onset_strength(y=y) + assert torch.allclose(out, torch.zeros_like(out)) + + +def test_onset_strength_impulse_detection(): + """An impulse should produce a peak in the onset envelope near that position.""" + y = torch.zeros(1, 22050) + # Place impulse at roughly the middle + y[0, 11025] = 1.0 + out = onset_strength(y=y) + # The onset envelope should have at least one nonzero value + assert out.max().item() > 0 + + +def test_onset_strength_non_negative(): + """Output should always be >= 0 (clamping works) when detrend=False.""" + y = torch.randn(2, 8000) + out = onset_strength(y=y, detrend=False) + assert (out >= 0).all() + + +def test_onset_strength_detrend(): + """With detrend=True, mean of output should be approximately 0.""" + y = torch.randn(1, 22050) + out = onset_strength(y=y, detrend=True) + assert abs(out.mean().item()) < 1e-2 + + +# --------------------------------------------------------------------------- +# HCQM (hcqm.py) +# --------------------------------------------------------------------------- + +def test_make_kernels_returns_tuple(): + """make_kernels should return (stft_spec, band_filter, cqt_specs).""" + stft_spec, band_filter, cqt_specs = make_kernels(device='cpu') + assert band_filter.shape == (8, 1025) + assert isinstance(cqt_specs, list) + + +def test_make_kernels_cqt_count(): + """Should produce 6 CQT specs (one per harmonic).""" + _, _, cqt_specs = make_kernels(device='cpu') + assert len(cqt_specs) == 6 + + +def test_make_kernels_band_filter_shape(): + """Band filter should be (8, 1025).""" + _, band_filter, _ = make_kernels(device='cpu') + assert band_filter.shape == (8, 1025) + + +def test_compute_hcqm_output_shape(): + """Input (batch, 176400) -> output (batch, 240, 8, 6).""" + sr = 22050 + clip_samples = sr * 8 # 176400 + specs = make_kernels(len_audio=clip_samples, sr=sr, device='cpu') + 
batch = 2 + audio = torch.randn(batch, clip_samples) + out = compute_hcqm(audio, *specs) + assert out.shape == (batch, 240, 8, 6) + + +def test_compute_hcqm_batch_consistency(): + """Same audio duplicated in batch should produce identical rows.""" + sr = 22050 + clip_samples = sr * 8 + specs = make_kernels(len_audio=clip_samples, sr=sr, device='cpu') + single = torch.randn(1, clip_samples) + batch = torch.cat([single, single], dim=0) + out = compute_hcqm(batch, *specs) + assert torch.allclose(out[0], out[1], atol=1e-5) + + +def test_compute_hcqm_different_inputs(): + """Different audio should produce different HCQM outputs.""" + sr = 22050 + clip_samples = sr * 8 + specs = make_kernels(len_audio=clip_samples, sr=sr, device='cpu') + audio = torch.randn(2, clip_samples) + out = compute_hcqm(audio, *specs) + assert not torch.allclose(out[0], out[1]) diff --git a/tests/test_model.py b/tests/test_model.py index 5a5ad58..04ed050 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,51 +4,78 @@ import numpy as np import pytest import soundfile as sf - -from deeprhythm.utils import AudioTooShortError, bpm_to_class, class_to_bpm, split_audio - - -def test_bpm_class_roundtrip(): - """bpm_to_class and class_to_bpm should roundtrip within one class width.""" - class_width = (286 - 30) / 256 - for bpm in [30, 60, 90, 120, 150, 200, 285]: - cls = bpm_to_class(bpm) - recovered = class_to_bpm(cls) - assert abs(recovered - bpm) <= class_width, f"Roundtrip failed for {bpm}: got {recovered}" - - -def test_bpm_to_class_clamps(): - """Values outside [30, 286] should clamp to valid class range.""" - assert bpm_to_class(0) == 0 - assert bpm_to_class(500) == 255 - - -def test_split_audio_basic(): - """split_audio should produce correct number of clips from a synthetic signal.""" - sr = 22050 - clip_length = 8 - num_clips = 3 - audio = np.random.randn(sr * clip_length * num_clips + 1000).astype(np.float32) - clips = split_audio(audio, sr, clip_length=clip_length) - assert 
clips.shape == (num_clips, sr * clip_length) - - -def test_split_audio_too_short(): - """split_audio should raise AudioTooShortError when audio is shorter than one clip.""" - sr = 22050 - audio = np.zeros(100, dtype=np.float32) - with pytest.raises(AudioTooShortError): - split_audio(audio, sr) - - -def test_split_audio_share_mem(): - """split_audio with share_mem=True should return a shared memory tensor.""" - sr = 22050 - audio = np.random.randn(sr * 8).astype(np.float32) - clips = split_audio(audio, sr, share_mem=True) - assert clips.is_shared() - - +import torch + +from deeprhythm.model.frame_cnn import DeepRhythmModel + +# --------------------------------------------------------------------------- +# Model Architecture (no weights needed) +# --------------------------------------------------------------------------- + +def test_model_forward_shape(): + """Random input (4, 6, 240, 8) -> output (4, 256).""" + model = DeepRhythmModel() + model.eval() + x = torch.randn(4, 6, 240, 8) + with torch.no_grad(): + out = model(x) + assert out.shape == (4, 256) + + +def test_model_output_not_uniform(): + """Output logits shouldn't all be the same value.""" + model = DeepRhythmModel() + model.eval() + x = torch.randn(1, 6, 240, 8) + with torch.no_grad(): + out = model(x) + assert out.std().item() > 0, "All output logits are identical" + + +def test_model_single_sample(): + """Batch size 1 should work: (1, 6, 240, 8) -> (1, 256).""" + model = DeepRhythmModel() + model.eval() + x = torch.randn(1, 6, 240, 8) + with torch.no_grad(): + out = model(x) + assert out.shape == (1, 256) + + +def test_model_eval_deterministic(): + """In eval mode, same input should produce same output (no dropout stochasticity).""" + model = DeepRhythmModel() + model.eval() + x = torch.randn(2, 6, 240, 8) + with torch.no_grad(): + out1 = model(x) + out2 = model(x) + assert torch.allclose(out1, out2) + + +def test_model_num_classes_custom(): + """DeepRhythmModel(num_classes=128) -> output dim is 128.""" 
+ model = DeepRhythmModel(num_classes=128) + model.eval() + x = torch.randn(1, 6, 240, 8) + with torch.no_grad(): + out = model(x) + assert out.shape == (1, 128) + + +def test_model_parameter_count(): + """Total trainable params should be in expected ballpark (~485K).""" + model = DeepRhythmModel() + total = sum(p.numel() for p in model.parameters() if p.requires_grad) + # Allow a reasonable range around expected count + assert 1_000_000 < total < 2_000_000, f"Parameter count {total} outside expected range" + + +# --------------------------------------------------------------------------- +# Predictor (requires model weights — slow) +# --------------------------------------------------------------------------- + +@pytest.mark.slow def test_predictor_instantiation(): """DeepRhythmPredictor should load model and create kernels.""" from deeprhythm.model.predictor import DeepRhythmPredictor @@ -57,6 +84,7 @@ def test_predictor_instantiation(): assert predictor.specs is not None +@pytest.mark.slow def test_predict_sine_wave(): """Predicting on a synthetic sine wave should return a float in valid BPM range.""" from deeprhythm.model.predictor import DeepRhythmPredictor diff --git a/tests/test_predictor.py b/tests/test_predictor.py new file mode 100644 index 0000000..37c0f66 --- /dev/null +++ b/tests/test_predictor.py @@ -0,0 +1,163 @@ +import numpy as np +import pytest +import soundfile as sf + +from deeprhythm.utils import AudioTooShortError + + +def _make_click_track(bpm, duration=16, sr=22050): + """Periodic impulses at exact BPM intervals.""" + samples = int(sr * duration) + audio = np.zeros(samples, dtype=np.float32) + interval_samples = int(60.0 / bpm * sr) + for i in range(0, samples, interval_samples): + end = min(i + int(sr * 0.01), samples) + audio[i:end] = 0.8 + return audio + + +def _make_silence(duration=16, sr=22050): + return np.zeros(int(sr * duration), dtype=np.float32) + + +def _write_wav(audio, sr, tmp_path, name="test.wav"): + path = tmp_path / name + 
sf.write(str(path), audio, sr) + return str(path) + + +pytestmark = pytest.mark.slow + + +# --------------------------------------------------------------------------- +# Prediction Modes +# --------------------------------------------------------------------------- + +def test_predict_returns_float_in_range(predictor, tmp_path): + """Result should be a float in [30, 286] for click track audio.""" + audio = _make_click_track(120, duration=16) + path = _write_wav(audio, 22050, tmp_path) + result = predictor.predict(path) + assert isinstance(result, float) + assert 30 <= result <= 286 + + +def test_predict_with_confidence(predictor, tmp_path): + """Should return (bpm, confidence) tuple; confidence in (0, 1].""" + audio = _make_click_track(120, duration=16) + path = _write_wav(audio, 22050, tmp_path) + bpm, conf = predictor.predict(path, include_confidence=True) + assert isinstance(bpm, float) + assert 30 <= bpm <= 286 + assert 0 < conf <= 1.0 + + +def test_predict_from_audio_matches_predict(predictor, tmp_path): + """Same audio via file vs array should produce the same BPM.""" + sr = 22050 + audio = _make_click_track(120, duration=16, sr=sr) + path = _write_wav(audio, sr, tmp_path) + bpm_file = predictor.predict(path) + bpm_array = predictor.predict_from_audio(audio, sr) + assert bpm_file == bpm_array + + +def test_predict_per_frame_count(predictor, tmp_path): + """24s audio -> 3 per-frame predictions.""" + sr = 22050 + audio = _make_click_track(120, duration=24, sr=sr) + path = _write_wav(audio, sr, tmp_path) + bpms = predictor.predict_per_frame(path) + assert len(bpms) == 3 + + +def test_predict_per_frame_with_confidence(predictor, tmp_path): + """Should return (bpms_list, confidences_list) of equal length.""" + sr = 22050 + audio = _make_click_track(120, duration=24, sr=sr) + path = _write_wav(audio, sr, tmp_path) + bpms, confs = predictor.predict_per_frame(path, include_confidence=True) + assert len(bpms) == len(confs) == 3 + + +def 
test_predict_per_frame_values_in_range(predictor, tmp_path): + """All per-frame BPMs should be in [30, 286].""" + sr = 22050 + audio = _make_click_track(120, duration=24, sr=sr) + path = _write_wav(audio, sr, tmp_path) + bpms = predictor.predict_per_frame(path) + for bpm in bpms: + assert 30 <= bpm <= 286 + + +# --------------------------------------------------------------------------- +# Rich Synthetic Audio (click tracks) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("target_bpm", [90, 120, 140]) +def test_predict_click_track(predictor, tmp_path, target_bpm): + """Click track at target BPM should predict within +/-10%.""" + audio = _make_click_track(target_bpm, duration=24) + path = _write_wav(audio, 22050, tmp_path) + result = predictor.predict(path) + tolerance = target_bpm * 0.10 + assert abs(result - target_bpm) <= tolerance, ( + f"Expected ~{target_bpm} BPM, got {result}" + ) + + +# --------------------------------------------------------------------------- +# Edge Cases +# --------------------------------------------------------------------------- + +def test_predict_silence(predictor, tmp_path): + """Silent audio should return a valid BPM (no crash), low confidence.""" + audio = _make_silence(duration=16) + path = _write_wav(audio, 22050, tmp_path) + bpm, conf = predictor.predict(path, include_confidence=True) + assert 30 <= bpm <= 286 + # Confidence on silence should be relatively low + assert conf < 0.5 + + +def test_predict_exactly_8_seconds(predictor, tmp_path): + """Single clip boundary should work and return a valid BPM.""" + sr = 22050 + audio = _make_click_track(120, duration=8, sr=sr) + # Ensure exactly 8s worth of samples + audio = audio[:sr * 8] + path = _write_wav(audio, sr, tmp_path) + result = predictor.predict(path) + assert 30 <= result <= 286 + + +def test_predict_just_under_8_seconds(predictor, tmp_path): + """Audio just under 8 seconds should raise AudioTooShortError.""" + sr 
= 22050 + # 100 samples short of 8 seconds (~7.9955 s, just under the minimum clip length) + audio = np.zeros(sr * 8 - 100, dtype=np.float32) + path = _write_wav(audio, sr, tmp_path) + with pytest.raises(AudioTooShortError): + predictor.predict(path) + + +def test_predict_nonexistent_file(predictor): + """Non-existent file should raise an error.""" + with pytest.raises(Exception): + predictor.predict("/nonexistent/path/to/audio.wav") + + +def test_predict_single_vs_multi_clip(predictor, tmp_path): + """8s vs 16s of same pattern should both return valid results.""" + sr = 22050 + audio_short = _make_click_track(120, duration=8, sr=sr)[:sr * 8] + audio_long = _make_click_track(120, duration=16, sr=sr) + + path_short = _write_wav(audio_short, sr, tmp_path, name="short.wav") + path_long = _write_wav(audio_long, sr, tmp_path, name="long.wav") + + bpm_short = predictor.predict(path_short) + bpm_long = predictor.predict(path_long) + + assert 30 <= bpm_short <= 286 + assert 30 <= bpm_long <= 286 diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..685ad09 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,168 @@ +import numpy as np + +from deeprhythm.utils import ( + AudioLoadError, + AudioTooShortError, + bpm_to_class, + class_to_bpm, + get_device, + split_audio, +) + +# --------------------------------------------------------------------------- +# BPM <-> class conversion +# --------------------------------------------------------------------------- + +def test_bpm_class_roundtrip(): + """bpm_to_class and class_to_bpm should roundtrip within one class width.""" + class_width = (286 - 30) / 256 + for bpm in [30, 60, 90, 120, 150, 200, 285]: + cls = bpm_to_class(bpm) + recovered = class_to_bpm(cls) + assert abs(recovered - bpm) <= class_width, f"Roundtrip failed for {bpm}: got {recovered}" + + +def test_bpm_to_class_clamps(): + """Values outside [30, 286] should clamp to valid class range.""" + assert bpm_to_class(0) == 0 + assert bpm_to_class(500) == 255 + + +def test_bpm_to_class_monotonic():
"""Increasing BPM should produce non-decreasing class indices.""" + prev = -1 + for bpm in range(30, 287): + cls = bpm_to_class(bpm) + assert cls >= prev, f"Non-monotonic at BPM {bpm}: class {cls} < {prev}" + prev = cls + + +def test_class_to_bpm_monotonic(): + """Increasing class index should produce increasing BPM.""" + prev = -1.0 + for cls in range(256): + bpm = class_to_bpm(cls) + assert bpm > prev, f"Non-monotonic at class {cls}: BPM {bpm} <= {prev}" + prev = bpm + + +def test_bpm_class_boundaries(): + """Test exact boundary values.""" + assert bpm_to_class(30.0) == 0 + assert bpm_to_class(285.0) in (254, 255) + + +def test_bpm_class_full_range_coverage(): + """Every class 0-255 should map to a BPM in [30, 286].""" + for cls in range(256): + bpm = class_to_bpm(cls) + assert 30 <= bpm <= 286, f"Class {cls} maps to out-of-range BPM {bpm}" + + +def test_bpm_to_class_custom_range(): + """Custom min/max/num_classes should work.""" + assert bpm_to_class(60, min_bpm=60, max_bpm=180, num_classes=120) == 0 + assert bpm_to_class(179, min_bpm=60, max_bpm=180, num_classes=120) == 119 + + +# --------------------------------------------------------------------------- +# split_audio +# --------------------------------------------------------------------------- + +def test_split_audio_basic(): + """split_audio should produce correct number of clips.""" + sr = 22050 + num_clips = 3 + audio = np.random.randn(sr * 8 * num_clips + 1000).astype(np.float32) + clips = split_audio(audio, sr, clip_length=8) + assert clips.shape == (num_clips, sr * 8) + + +def test_split_audio_too_short(): + """split_audio should raise AudioTooShortError for short audio.""" + import pytest + sr = 22050 + audio = np.zeros(100, dtype=np.float32) + with pytest.raises(AudioTooShortError): + split_audio(audio, sr) + + +def test_split_audio_share_mem(): + """split_audio with share_mem=True returns a shared-memory tensor.""" + sr = 22050 + audio = np.random.randn(sr * 8).astype(np.float32) + clips = 
split_audio(audio, sr, share_mem=True) + assert clips.is_shared() + + +def test_split_audio_exact_boundary(): + """Audio length exactly N * clip_samples -> N clips.""" + sr = 22050 + for n in (1, 2, 5): + audio = np.zeros(sr * 8 * n, dtype=np.float32) + clips = split_audio(audio, sr) + assert clips.shape[0] == n + + +def test_split_audio_just_over(): + """N * clip_samples + 1 sample -> still N clips (remainder dropped).""" + sr = 22050 + n = 3 + audio = np.zeros(sr * 8 * n + 1, dtype=np.float32) + clips = split_audio(audio, sr) + assert clips.shape[0] == n + + +def test_split_audio_single_clip(): + """Exactly 8 seconds -> 1 clip.""" + sr = 22050 + audio = np.ones(sr * 8, dtype=np.float32) + clips = split_audio(audio, sr) + assert clips.shape == (1, sr * 8) + + +def test_split_audio_preserves_values(): + """Clip content should match original audio slices.""" + sr = 22050 + clip_len = 8 + clip_samples = sr * clip_len + audio = np.arange(clip_samples * 2, dtype=np.float32) + clips = split_audio(audio, sr, clip_length=clip_len) + np.testing.assert_array_equal(clips[0].numpy(), audio[:clip_samples]) + np.testing.assert_array_equal(clips[1].numpy(), audio[clip_samples:2 * clip_samples]) + + +def test_split_audio_custom_clip_length(): + """Non-default clip length (4 seconds) should work.""" + sr = 22050 + audio = np.zeros(sr * 12, dtype=np.float32) + clips = split_audio(audio, sr, clip_length=4) + assert clips.shape == (3, sr * 4) + + +# --------------------------------------------------------------------------- +# get_device +# --------------------------------------------------------------------------- + +def test_get_device_returns_valid(): + """get_device should return one of the known device strings.""" + device = get_device() + assert device in ('cuda', 'mps', 'cpu') + + +# --------------------------------------------------------------------------- +# Error classes +# --------------------------------------------------------------------------- + +def 
test_audio_too_short_is_value_error(): + """AudioTooShortError should be a ValueError.""" + assert issubclass(AudioTooShortError, ValueError) + with __import__('pytest').raises(ValueError): + raise AudioTooShortError("test") + + +def test_audio_load_error_is_io_error(): + """AudioLoadError should be an IOError.""" + assert issubclass(AudioLoadError, IOError) + with __import__('pytest').raises(IOError): + raise AudioLoadError("test") diff --git a/deeprhythm-0.5.pth b/weights/deeprhythm-0.5.pth similarity index 100% rename from deeprhythm-0.5.pth rename to weights/deeprhythm-0.5.pth diff --git a/deeprhythm-0.7.pth b/weights/deeprhythm-0.7.pth similarity index 100% rename from deeprhythm-0.7.pth rename to weights/deeprhythm-0.7.pth