Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
9bcb8e5
[Build] Add shared constants module
bleugreen Feb 22, 2026
adfd8a4
[Build] Fix predictor.py bugs: device logic, truthiness check, dead c…
bleugreen Feb 22, 2026
ea31cb5
[Build] Extract shared beat extraction pipeline into postproc/__init_…
bleugreen Feb 22, 2026
cb3dccd
[Build] Use extract_beat_times in predictor.py
bleugreen Feb 22, 2026
89c044f
[Build] Deduplicate val.py with extract_beat_times and _report_scores…
bleugreen Feb 22, 2026
652be7b
[Build] Fix train.py: import, device handling, running_loss bug, remo…
bleugreen Feb 22, 2026
5a4e0a6
[Build] Fix dataset.py: broken import, auto-detect device, add type h…
bleugreen Feb 22, 2026
aa8158a
[Build] Minor fixes: rename relu->activation, remove debug prints, de…
bleugreen Feb 22, 2026
3cf4a9e
[Build] Add type hints to hmm.py
bleugreen Feb 22, 2026
f4fbbe1
[Build] Add type hints to get_device.py
bleugreen Feb 22, 2026
6dd1d5e
[Build] Add type hints to get_weights.py
bleugreen Feb 22, 2026
56d2757
[Build] Add type hints to log_filter.py
bleugreen Feb 22, 2026
bf0b71c
[Build] Fix test imports to use phasefinder. prefix
bleugreen Feb 22, 2026
9b9a66f
[Build] Set up pytest: add tests directory with weight loading test
bleugreen Feb 22, 2026
cdeab42
[Build] Fix packaging: remove duplicate librosa, add dev deps, pytest…
bleugreen Feb 22, 2026
2a9d576
[Build] Add CI/CD workflow with lint and test jobs
bleugreen Feb 22, 2026
5044df0
[Build] Fix ruff lint: explicit re-exports, import sorting, trailing …
bleugreen Feb 22, 2026
a830482
Apply ruff auto-fixes: import sorting, trailing whitespace, unused im…
bleugreen Feb 22, 2026
01e45f9
Apply ruff formatting
bleugreen Feb 22, 2026
33a2fe5
[Build] Skip weight loading test when model files unavailable
bleugreen Feb 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
name: CI

# Runs on every push/PR targeting main.
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  # Static analysis: ruff lint rules plus a formatting check.
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - run: pip install ruff
      - run: ruff check .
      - run: ruff format --check .

  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"

      # CPU wheel index keeps the install small; CI runners have no GPU.
      - name: Install CPU-only PyTorch
        run: pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu

      - name: Install package with dev dependencies
        run: pip install -e ".[dev]"

      # Cache downloaded model weights so reruns skip the download.
      - name: Cache model weights
        uses: actions/cache@v4
        with:
          path: ~/.local/share/phasefinder
          key: phasefinder-weights-v1

      - name: Run tests
        run: pytest tests/ -v
5 changes: 3 additions & 2 deletions phasefinder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .predictor import Phasefinder
from .predictor import Phasefinder as Phasefinder


def version_info():
return {
"name": "phasefinder",
"version": "0.0.2",
"description": "A beat estimation model that predicts metric position as rotational phase.",
"author": "bleugreen",
"license": "AGPL-3.0"
"license": "AGPL-3.0",
}
26 changes: 17 additions & 9 deletions phasefinder/audio/log_filter.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
import torch
from typing import Union

import numpy as np
import torch

def create_log_filter(num_bins, num_bands, device='cuda'):
log_bins = np.logspace(np.log10(1), np.log10(num_bins), num=num_bands+1, base=10) - 1

def create_log_filter(
num_bins: int,
num_bands: int,
device: Union[str, torch.device] = "cuda",
) -> torch.Tensor:
log_bins = np.logspace(np.log10(1), np.log10(num_bins), num=num_bands + 1, base=10) - 1
log_bins = np.floor(log_bins).astype(int)
log_bins[-1] = num_bins
log_bins[-1] = num_bins
if len(np.unique(log_bins)) < len(log_bins):
for i in range(1, len(log_bins)):
if log_bins[i] <= log_bins[i-1]:
log_bins[i] = log_bins[i-1] + 1
if log_bins[i] <= log_bins[i - 1]:
log_bins[i] = log_bins[i - 1] + 1
filter_matrix = torch.zeros(num_bands, num_bins, device=device)
for i in range(num_bands):
start_bin = log_bins[i]
end_bin = log_bins[i+1] if i < num_bands - 1 else num_bins
end_bin = log_bins[i + 1] if i < num_bands - 1 else num_bins
filter_matrix[i, start_bin:end_bin] = 1 / max(1, (end_bin - start_bin))

return filter_matrix

def apply_log_filter(stft_output, filter_matrix):

def apply_log_filter(stft_output: torch.Tensor, filter_matrix: torch.Tensor) -> torch.Tensor:
stft_output_transposed = stft_output.transpose(1, 2)
filtered_output_transposed = torch.matmul(stft_output_transposed, filter_matrix.T)
filtered_output = filtered_output_transposed.transpose(1, 2)
return filtered_output
return filtered_output
6 changes: 6 additions & 0 deletions phasefinder/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Shared audio/analysis constants for the phasefinder package."""

# Working sample rate (Hz) for analysis.
SAMPLE_RATE = 22050
# STFT size (samples) — presumably the FFT window length; confirm against the STFT call site.
N_FFT = 2048
# STFT hop length (samples).
HOP = 512
FRAME_RATE = SAMPLE_RATE / HOP  # ~43.066 frames/sec
# NOTE(review): units/meaning not visible here — looks like a frame-domain
# onset threshold used by beat extraction; verify against the pipeline.
BEAT_ONSET_THRESHOLD = 300
CLICK_SAMPLE_RATE = 44100  # higher rate for output audio quality
89 changes: 56 additions & 33 deletions phasefinder/dataset.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,93 @@
from torch.utils.data import Dataset
from typing import Optional

import h5py
import torch
from utils.one_hots import generate_blurred_one_hots_wrapped
from torch.utils.data import Dataset

from phasefinder.utils import get_device
from phasefinder.utils.one_hots import generate_blurred_one_hots_wrapped


class BeatDataset(Dataset):
def __init__(self, data_path, group, mode='both', items=None, device='cuda', phase_width=5):
def __init__(
self,
data_path: str,
group: str,
mode: str = "both",
items: Optional[list[str]] = None,
device: Optional[str] = None,
phase_width: int = 5,
):
"""
Initializes the dataset.
:param data_path: Path to the HDF5 file.
:param group: The group within the HDF5 file ('train', 'val', 'test').
:param mode: 'beat', 'downbeat', or 'both' (default: 'both').
:param items: List of items to include in the dataset (default: None).
Options: 'stft', 'phase', 'label', 'time', 'filepath', 'bpm'.
:param device: Device to use for tensors (default: 'cuda').
:param device: Device to use for tensors (default: None, auto-detect).
:param phase_width: Width for blurred one-hot encoding (default: 5).
"""
self.data_path = data_path
self.group = group
self.mode = mode
self.phase_width = phase_width
self.items = items if items is not None else ['stft', 'phase', 'label', 'time', 'filepath', 'bpm']
self.device = device
self.items = items if items is not None else ["stft", "phase", "label", "time", "filepath", "bpm"]
self.device = device if device is not None else get_device()

with h5py.File(self.data_path, 'r') as file:
with h5py.File(self.data_path, "r") as file:
self.keys = list(file[group].keys())

def __len__(self):
def __len__(self) -> int:
return len(self.keys)

def __getitem__(self, idx):
with h5py.File(self.data_path, 'r') as file:
def __getitem__(self, idx: int) -> tuple:
with h5py.File(self.data_path, "r") as file:
data = file[self.group][self.keys[idx]]
result = []

for item in self.items:
if item == 'stft':
spec = torch.from_numpy(data['stft'][...]).to(self.device)
if item == "stft":
spec = torch.from_numpy(data["stft"][...]).to(self.device)
result.append(spec)
elif item == 'phase':
if self.mode == 'beat' or self.mode == 'both':
beat_phase = torch.from_numpy(data['beat_phase'][...]).long()
beat_phase = generate_blurred_one_hots_wrapped(beat_phase, width=self.phase_width).to_dense().unsqueeze(0).to(self.device)
elif item == "phase":
if self.mode == "beat" or self.mode == "both":
beat_phase = torch.from_numpy(data["beat_phase"][...]).long()
beat_phase = (
generate_blurred_one_hots_wrapped(beat_phase, width=self.phase_width)
.to_dense()
.unsqueeze(0)
.to(self.device)
)
result.append(beat_phase)
if self.mode == 'downbeat' or self.mode == 'both':
downbeat_phase = torch.from_numpy(data['downbeat_phase'][...]).long()
downbeat_phase = generate_blurred_one_hots_wrapped(downbeat_phase, width=self.phase_width).to_dense().unsqueeze(0).to(self.device)
if self.mode == "downbeat" or self.mode == "both":
downbeat_phase = torch.from_numpy(data["downbeat_phase"][...]).long()
downbeat_phase = (
generate_blurred_one_hots_wrapped(downbeat_phase, width=self.phase_width)
.to_dense()
.unsqueeze(0)
.to(self.device)
)
result.append(downbeat_phase)
elif item == 'label':
if self.mode == 'beat' or self.mode == 'both':
beat_phase = torch.from_numpy(data['beat_phase'][...]).long().to(self.device)
elif item == "label":
if self.mode == "beat" or self.mode == "both":
beat_phase = torch.from_numpy(data["beat_phase"][...]).long().to(self.device)
result.append(beat_phase)
if self.mode == 'downbeat' or self.mode == 'both':
downbeat_phase = torch.from_numpy(data['downbeat_phase'][...]).long().to(self.device)
if self.mode == "downbeat" or self.mode == "both":
downbeat_phase = torch.from_numpy(data["downbeat_phase"][...]).long().to(self.device)
result.append(downbeat_phase)
elif item == 'time':
if self.mode == 'beat' or self.mode == 'both':
beats = torch.from_numpy(data.attrs['beats'][...]).float()
elif item == "time":
if self.mode == "beat" or self.mode == "both":
beats = torch.from_numpy(data.attrs["beats"][...]).float()
result.append(beats)
if self.mode == 'downbeat' or self.mode == 'both':
downbeats = torch.from_numpy(data.attrs['downbeats'][...]).float()
if self.mode == "downbeat" or self.mode == "both":
downbeats = torch.from_numpy(data.attrs["downbeats"][...]).float()
result.append(downbeats)
elif item == 'filepath':
filepath = data.attrs['filepath']
elif item == "filepath":
filepath = data.attrs["filepath"]
result.append(filepath)
elif item == 'bpm':
bpm = data.attrs['bpm']
elif item == "bpm":
bpm = data.attrs["bpm"]
result.append(bpm)

return tuple(result)
46 changes: 27 additions & 19 deletions phasefinder/infer.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
import argparse
import json

import librosa
import numpy as np
import soundfile as sf
import argparse
import json

from phasefinder.predictor import Phasefinder

if __name__ == '__main__':
if __name__ == "__main__":
pf = Phasefinder()

parser = argparse.ArgumentParser(description='Predict beats from an audio file.')
parser.add_argument('audio_path', type=str, help='Path to the audio file')
parser.add_argument('--bpm', action='store_true', help='Include BPM in the output')
parser.add_argument('--noclean', action='store_true', help='Don\'t apply cleaning function')
parser.add_argument('--format', type=str, choices=['times', 'click_track'], default='times', help='Output format: "times" for beat times or "click_track" for audio with click track')
parser.add_argument('--audio_output', type=str, default='output_with_clicks.wav', help='Path to save the output audio file with clicks')
parser.add_argument('--json_output', type=str, default='', help='Path to save the output json results')
parser = argparse.ArgumentParser(description="Predict beats from an audio file.")
parser.add_argument("audio_path", type=str, help="Path to the audio file")
parser.add_argument("--bpm", action="store_true", help="Include BPM in the output")
parser.add_argument("--noclean", action="store_true", help="Don't apply cleaning function")
parser.add_argument(
"--format",
type=str,
choices=["times", "click_track"],
default="times",
help='Output format: "times" for beat times or "click_track" for audio with click track',
)
parser.add_argument(
"--audio_output",
type=str,
default="output_with_clicks.wav",
help="Path to save the output audio file with clicks",
)
parser.add_argument("--json_output", type=str, default="", help="Path to save the output json results")

args = parser.parse_args()

Expand All @@ -26,23 +37,20 @@
else:
beat_times = pf.predict(audio_path, include_bpm=args.bpm, clean=not args.noclean)

if args.format == 'click_track':
if args.format == "click_track":
audio, sr = librosa.load(audio_path)
click_track = librosa.clicks(times=beat_times, sr=sr, length=len(audio))
audio_with_clicks = np.array([click_track, audio])
audio_with_clicks = np.vstack([click_track, audio]).T
sf.write(args.audio_output, audio_with_clicks, sr)
else:
if args.json_output != '':
output_data = {
'beat_times': beat_times.tolist()
}
if args.json_output != "":
output_data = {"beat_times": beat_times.tolist()}
if args.bpm:
output_data['bpm'] = bpm
output_data["bpm"] = bpm

with open(args.json_output, 'w') as json_file:
with open(args.json_output, "w") as json_file:
json.dump(output_data, json_file, indent=4)
else:
print(f"beats = {beat_times}")
if args.bpm:
print(f'bpm = {bpm}')
print(f"bpm = {bpm}")
4 changes: 2 additions & 2 deletions phasefinder/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .model_attn import PhasefinderModelAttn
from .model_noattn import PhasefinderModelNoattn
from .model_attn import PhasefinderModelAttn as PhasefinderModelAttn
from .model_noattn import PhasefinderModelNoattn as PhasefinderModelNoattn
25 changes: 14 additions & 11 deletions phasefinder/model/attention.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,50 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from phasefinder.model.pos_encoding import PositionalEncoding


class AttentionModule(nn.Module):
    """Multi-head self-attention block with additive positional encoding.

    Projects the input to queries, keys, and values, splits them into
    ``num_heads`` heads, applies scaled dot-product attention, merges the
    heads, and projects back to ``input_dim``.
    """

    def __init__(self, input_dim: int, num_heads: int = 4, dropout: float = 0.1):
        """
        :param input_dim: Feature dimension of the input. Should be divisible
            by ``num_heads`` — the per-head split uses integer division.
        :param num_heads: Number of attention heads (default: 4).
        :param dropout: Dropout probability, applied to the attention weights
            and passed to the positional encoding (default: 0.1).
        """
        super(AttentionModule, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads

        # Multi-head attention: Q/K/V projections all keep width input_dim.
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

        # Output projection after heads are merged.
        self.proj = nn.Linear(input_dim, input_dim)

        # Positional encoding added to the input before attention.
        self.pos_encoding = PositionalEncoding(input_dim, dropout=dropout)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: Input tensor of shape (batch, seq_len, input_dim).
        :return: Attended features with the same shape as ``x``.
        """
        batch_size, seq_len, _ = x.size()

        # Add positional encoding
        x = self.pos_encoding(x)

        # Compute Q, K, V and reshape to (batch, num_heads, seq_len, head_dim)
        q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores, scaled by sqrt(head_dim) to keep the
        # softmax input magnitude independent of the head size.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Apply attention, then merge heads back to (batch, seq_len, input_dim)
        context = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.input_dim)

        output = self.proj(context)
        return output
5 changes: 3 additions & 2 deletions phasefinder/model/decoder.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import torch.nn as nn


class BeatPhaseDecoder(nn.Module):
def __init__(self, num_tcn_outputs, num_classes):
super(BeatPhaseDecoder, self).__init__()
self.dropout1 = nn.Dropout(0.1)
self.dense1 = nn.Linear(num_tcn_outputs, 72)
self.relu = nn.ELU()
self.activation = nn.ELU()
self.dropout2 = nn.Dropout(0.1)
self.dense2 = nn.Linear(72, num_classes)
self.softmax = nn.LogSoftmax(dim=2)

def forward(self, x):
x = x.transpose(1, 2)
x = self.dropout1(x)
x = self.relu(self.dense1(x))
x = self.activation(self.dense1(x))
x = self.dropout2(x)
x = self.dense2(x)
x = self.softmax(x)
Expand Down
Loading