Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
7166c49
moved weights to their own folder
bleugreen Dec 7, 2024
154f268
reorganizing, removing unused files
bleugreen Dec 7, 2024
2ffb661
gitignore
bleugreen Feb 22, 2026
f671837
[Build] Add test fixtures and helpers in conftest.py
bleugreen Feb 22, 2026
df7add8
[Build] Add Tier 1 pure unit tests in test_utils.py
bleugreen Feb 22, 2026
3b85f91
[Build] Add Tier 2 audio processing tests in test_audio_proc.py
bleugreen Feb 22, 2026
f95839b
[Build] Refactor test_model.py: add architecture tests, mark slow tes…
bleugreen Feb 22, 2026
28dd7ef
[Build] Add Tier 3+4 integration and edge case tests in test_predicto…
bleugreen Feb 22, 2026
c780b64
[Build] Add pytest slow marker config to pyproject.toml
bleugreen Feb 22, 2026
ac03e89
[Build] Fix test_predictor.py: inline helpers instead of importing fr…
bleugreen Feb 22, 2026
6512b53
[Build] Fix detrend tolerance and parameter count range in tests
bleugreen Feb 22, 2026
f1f5094
Fix import sorting in test files
bleugreen Feb 22, 2026
bf02dd6
[Build] Fix ruff lint errors in predictor.py
bleugreen Feb 22, 2026
b17ffa3
[Build] CI: cache pip and use CPU-only torch to avoid CUDA download
bleugreen Feb 22, 2026
ab0a852
[Build] Fix circular import: lazy-import batch_infer in predict_batch
bleugreen Feb 22, 2026
2398d47
[Build] Fix circular import: inline load_cnn_model in batch_infer.py
bleugreen Feb 22, 2026
41a1e67
[Build] CI: skip slow tests that require model weights
bleugreen Feb 22, 2026
f4b0c9f
[Build] CI: download model weights to run full test suite
bleugreen Feb 22, 2026
70baaa3
[Build] Fix torch.load weights_only for PyTorch 2.6+ compatibility
bleugreen Feb 22, 2026
f09d211
[Build] CI: copy model weights from repo instead of downloading
bleugreen Feb 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: pip

- name: Install dependencies
run: pip install -e ".[dev]" soundfile
run: |
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -e ".[dev]" soundfile

- name: Lint
run: ruff check src/

- name: Install model weights
run: |
mkdir -p ~/.local/share/deeprhythm
cp weights/deeprhythm-0.7.pth ~/.local/share/deeprhythm/

- name: Test
run: pytest
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ dist/
*.csv
*.pb
.workspace
.venv
.venv
.DS_Store
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,6 @@ select = ["E", "F", "I"]

[tool.pytest.ini_options]
testpaths = ["tests"]
markers = [
"slow: tests that require model weights (deselect with '-m \"not slow\"')",
]
28 changes: 9 additions & 19 deletions src/deeprhythm/audio_proc/bandfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,21 @@

def create_log_filter(num_bins, num_bands, device='cuda'):
"""
Create a logarithmically spaced filter matrix for audio processing.

This function generates a filter matrix with logarithmically spaced bands. The filters have
unity gain, meaning that the sum of the filter coefficients in each band is equal to one.
Create a logarithmically spaced filter matrix.

Parameters
----------
num_bins : int
The number of bins in the spectrogram (e.g., the number of frequency bins).
Number of frequency bins in the spectrogram
num_bands : int
The number of bands for the filter matrix. These bands are spaced logarithmically.
Number of logarithmically spaced bands
device : str, optional
The device on which the filter matrix will be created.
Target device for the filter matrix

Returns
-------
torch.Tensor
A tensor representing the filter matrix with shape (num_bands, num_bins). Each row
corresponds to a filter for a specific band.
Filter matrix of shape (num_bands, num_bins)
"""
log_bins = np.logspace(np.log10(1), np.log10(num_bins), num=num_bands+1, base=10.0) - 1
log_bins = np.unique(np.round(log_bins).astype(int))
Expand All @@ -38,25 +34,19 @@ def create_log_filter(num_bins, num_bands, device='cuda'):

def apply_log_filter(stft_output, filter_matrix):
"""
Apply the logarithmic filter matrix to the Short-Time Fourier Transform (STFT) output.

This function applies a precomputed logarithmic filter matrix to the STFT output of an audio signal
to reduce its dimensionality and to capture the energy in logarithmically spaced frequency bands.
Apply logarithmic filter matrix to STFT output.

Parameters
----------
stft_output : torch.Tensor
A tensor representing the STFT output with shape (batch_size, num_bins, num_frames), where
num_bins is the number of frequency bins and num_frames is the number of time frames.
STFT output of shape (batch_size, num_bins, num_frames)
filter_matrix : torch.Tensor
A tensor representing the logarithmic filter matrix with shape (num_bands, num_bins), where
num_bands is the number of logarithmically spaced frequency bands.
Filter matrix of shape (num_bands, num_bins)

Returns
-------
torch.Tensor
A tensor representing the filtered STFT output with shape (batch_size, num_bands, num_frames).
Each band contains the aggregated energy from the corresponding set of frequency bins.
Filtered output of shape (batch_size, num_bands, num_frames)
"""
stft_output_transposed = stft_output.transpose(1, 2)
filtered_output_transposed = torch.matmul(stft_output_transposed, filter_matrix.T)
Expand Down
6 changes: 3 additions & 3 deletions src/deeprhythm/audio_proc/onset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def onset_strength(
y=None, n_fft=2048, hop_length=512, lag=1, ref=None,
detrend=False, center=True, aggregate=None
detrend=False, center=True, aggregate=None
):
"""
Compute the onset strength of an audio signal or a spectrogram.
Expand Down Expand Up @@ -50,8 +50,8 @@ def onset_strength(

# Compute difference to reference, spaced by lag
onset_env = S[..., lag:] - ref[..., :-lag]
onset_env = torch.clamp(onset_env, min=0.0) # Discard negatives

onset_env = torch.clamp(onset_env, min=0.0)
if aggregate is None:
aggregate = torch.mean
if callable(aggregate):
Expand Down
15 changes: 12 additions & 3 deletions src/deeprhythm/batch_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,15 @@
import torch.multiprocessing as multiprocessing

from deeprhythm.audio_proc.hcqm import compute_hcqm, make_kernels
from deeprhythm.model.predictor import load_cnn_model
from deeprhythm.utils import AudioLoadError, AudioTooShortError, class_to_bpm, get_device, load_and_split_audio
from deeprhythm.model.frame_cnn import DeepRhythmModel
from deeprhythm.utils import (
AudioLoadError,
AudioTooShortError,
class_to_bpm,
get_device,
get_weights,
load_and_split_audio,
)

NUM_WORKERS = 8
NUM_BATCH = 128
Expand Down Expand Up @@ -112,7 +119,9 @@ def consume_and_process(
specs = make_kernels(len_audio, sr, device=device)
if not quiet:
print('made kernels')
model = load_cnn_model(device=device, quiet=quiet)
model = DeepRhythmModel()
model.load_state_dict(torch.load(get_weights(quiet=quiet), map_location=torch.device(device), weights_only=False))
model = model.to(device=device)
model.eval()
if not quiet:
print('loaded model')
Expand Down
Empty file.
Loading