diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..c23552e --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,42 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + - run: pip install ruff + - run: ruff check . + - run: ruff format --check . + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install CPU-only PyTorch + run: pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu + + - name: Install package with dev dependencies + run: pip install -e ".[dev]" + + - name: Cache model weights + uses: actions/cache@v4 + with: + path: ~/.local/share/phasefinder + key: phasefinder-weights-v1 + + - name: Run tests + run: pytest tests/ -v diff --git a/phasefinder/__init__.py b/phasefinder/__init__.py index 4ab48d7..f603a33 100644 --- a/phasefinder/__init__.py +++ b/phasefinder/__init__.py @@ -1,4 +1,5 @@ -from .predictor import Phasefinder +from .predictor import Phasefinder as Phasefinder + def version_info(): return { @@ -6,5 +7,5 @@ def version_info(): "version": "0.0.2", "description": "A beat estimation model that predicts metric position as rotational phase.", "author": "bleugreen", - "license": "AGPL-3.0" + "license": "AGPL-3.0", } diff --git a/phasefinder/audio/log_filter.py b/phasefinder/audio/log_filter.py index e34afd5..9bb2d58 100644 --- a/phasefinder/audio/log_filter.py +++ b/phasefinder/audio/log_filter.py @@ -1,24 +1,32 @@ -import torch +from typing import Union + import numpy as np +import torch -def create_log_filter(num_bins, num_bands, device='cuda'): - log_bins = np.logspace(np.log10(1), np.log10(num_bins), num=num_bands+1, base=10) - 1 + +def create_log_filter( + num_bins: int, + num_bands: int, + device: Union[str, torch.device] = "cuda", +) -> torch.Tensor: + log_bins = np.logspace(np.log10(1), np.log10(num_bins), num=num_bands + 1, base=10) - 1 log_bins = np.floor(log_bins).astype(int) - log_bins[-1] = num_bins + log_bins[-1] = num_bins if len(np.unique(log_bins)) < len(log_bins): for i in range(1, len(log_bins)): - if log_bins[i] <= log_bins[i-1]: - log_bins[i] = log_bins[i-1] + 1 + if log_bins[i] <= log_bins[i - 1]: + log_bins[i] = log_bins[i - 1] + 1 filter_matrix = torch.zeros(num_bands, num_bins, device=device) for i in range(num_bands): start_bin = log_bins[i] - end_bin = log_bins[i+1] if i < num_bands - 1 else num_bins + end_bin = log_bins[i + 1] if i < num_bands - 1 else num_bins filter_matrix[i, start_bin:end_bin] = 1 / max(1, (end_bin - start_bin)) return filter_matrix -def apply_log_filter(stft_output, filter_matrix): + +def apply_log_filter(stft_output: torch.Tensor, filter_matrix: torch.Tensor) -> torch.Tensor: stft_output_transposed = stft_output.transpose(1, 2) filtered_output_transposed = torch.matmul(stft_output_transposed, filter_matrix.T) filtered_output = filtered_output_transposed.transpose(1, 2) - return filtered_output \ No newline at end of file + return filtered_output diff --git a/phasefinder/constants.py b/phasefinder/constants.py new file mode 100644 index 0000000..6dffadb --- /dev/null +++ b/phasefinder/constants.py @@ -0,0 +1,6 @@ +SAMPLE_RATE = 22050 +N_FFT = 2048 +HOP = 512 +FRAME_RATE = SAMPLE_RATE / HOP # ~43.066 frames/sec +BEAT_ONSET_THRESHOLD = 300 +CLICK_SAMPLE_RATE = 44100 # higher rate for output audio quality diff --git a/phasefinder/dataset.py b/phasefinder/dataset.py index d4d50e4..36ec00e 100644 --- a/phasefinder/dataset.py +++ b/phasefinder/dataset.py @@ -1,11 +1,23 @@ -from torch.utils.data import Dataset +from typing import Optional + import h5py import torch -from utils.one_hots import generate_blurred_one_hots_wrapped +from torch.utils.data import Dataset + +from phasefinder.utils import get_device +from phasefinder.utils.one_hots import generate_blurred_one_hots_wrapped class BeatDataset(Dataset): - def __init__(self, data_path, group, mode='both', items=None, device='cuda', phase_width=5): + def __init__( + self, + data_path: str, + group: str, + mode: str = "both", + items: Optional[list[str]] = None, + device: Optional[str] = None, + phase_width: int = 5, + ): """ Initializes the dataset. :param data_path: Path to the HDF5 file. @@ -13,58 +25,69 @@ def __init__(self, data_path, group, mode='both', items=None, device='cuda', pha :param mode: 'beat', 'downbeat', or 'both' (default: 'both'). :param items: List of items to include in the dataset (default: None). Options: 'stft', 'phase', 'label', 'time', 'filepath', 'bpm'. - :param device: Device to use for tensors (default: 'cuda'). + :param device: Device to use for tensors (default: None, auto-detect). + :param phase_width: Width for blurred one-hot encoding (default: 5). """ self.data_path = data_path self.group = group self.mode = mode self.phase_width = phase_width - self.items = items if items is not None else ['stft', 'phase', 'label', 'time', 'filepath', 'bpm'] - self.device = device + self.items = items if items is not None else ["stft", "phase", "label", "time", "filepath", "bpm"] + self.device = device if device is not None else get_device() - with h5py.File(self.data_path, 'r') as file: + with h5py.File(self.data_path, "r") as file: self.keys = list(file[group].keys()) - def __len__(self): + def __len__(self) -> int: return len(self.keys) - def __getitem__(self, idx): - with h5py.File(self.data_path, 'r') as file: + def __getitem__(self, idx: int) -> tuple: + with h5py.File(self.data_path, "r") as file: data = file[self.group][self.keys[idx]] result = [] for item in self.items: - if item == 'stft': - spec = torch.from_numpy(data['stft'][...]).to(self.device) + if item == "stft": + spec = torch.from_numpy(data["stft"][...]).to(self.device) result.append(spec) - elif item == 'phase': - if self.mode == 'beat' or self.mode == 'both': - beat_phase = torch.from_numpy(data['beat_phase'][...]).long() - beat_phase = generate_blurred_one_hots_wrapped(beat_phase, width=self.phase_width).to_dense().unsqueeze(0).to(self.device) + elif item == "phase": + if self.mode == "beat" or self.mode == "both": + beat_phase = torch.from_numpy(data["beat_phase"][...]).long() + beat_phase = ( + generate_blurred_one_hots_wrapped(beat_phase, width=self.phase_width) + .to_dense() + .unsqueeze(0) + .to(self.device) + ) result.append(beat_phase) - if self.mode == 'downbeat' or self.mode == 'both': - downbeat_phase = torch.from_numpy(data['downbeat_phase'][...]).long() - downbeat_phase = generate_blurred_one_hots_wrapped(downbeat_phase, width=self.phase_width).to_dense().unsqueeze(0).to(self.device) + if self.mode == "downbeat" or self.mode == "both": + downbeat_phase = torch.from_numpy(data["downbeat_phase"][...]).long() + downbeat_phase = ( + generate_blurred_one_hots_wrapped(downbeat_phase, width=self.phase_width) + .to_dense() + .unsqueeze(0) + .to(self.device) + ) result.append(downbeat_phase) - elif item == 'label': - if self.mode == 'beat' or self.mode == 'both': - beat_phase = torch.from_numpy(data['beat_phase'][...]).long().to(self.device) + elif item == "label": + if self.mode == "beat" or self.mode == "both": + beat_phase = torch.from_numpy(data["beat_phase"][...]).long().to(self.device) result.append(beat_phase) - if self.mode == 'downbeat' or self.mode == 'both': - downbeat_phase = torch.from_numpy(data['downbeat_phase'][...]).long().to(self.device) + if self.mode == "downbeat" or self.mode == "both": + downbeat_phase = torch.from_numpy(data["downbeat_phase"][...]).long().to(self.device) result.append(downbeat_phase) - elif item == 'time': - if self.mode == 'beat' or self.mode == 'both': - beats = torch.from_numpy(data.attrs['beats'][...]).float() + elif item == "time": + if self.mode == "beat" or self.mode == "both": + beats = torch.from_numpy(data.attrs["beats"][...]).float() result.append(beats) - if self.mode == 'downbeat' or self.mode == 'both': - downbeats = torch.from_numpy(data.attrs['downbeats'][...]).float() + if self.mode == "downbeat" or self.mode == "both": + downbeats = torch.from_numpy(data.attrs["downbeats"][...]).float() result.append(downbeats) - elif item == 'filepath': - filepath = data.attrs['filepath'] + elif item == "filepath": + filepath = data.attrs["filepath"] result.append(filepath) - elif item == 'bpm': - bpm = data.attrs['bpm'] + elif item == "bpm": + bpm = data.attrs["bpm"] result.append(bpm) return tuple(result) diff --git a/phasefinder/infer.py b/phasefinder/infer.py index 5acbfd7..b31efbf 100644 --- a/phasefinder/infer.py +++ b/phasefinder/infer.py @@ -1,22 +1,33 @@ +import argparse +import json import librosa import numpy as np import soundfile as sf -import argparse -import json from phasefinder.predictor import Phasefinder -if __name__ == '__main__': +if __name__ == "__main__": pf = Phasefinder() - parser = argparse.ArgumentParser(description='Predict beats from an audio file.') - parser.add_argument('audio_path', type=str, help='Path to the audio file') - parser.add_argument('--bpm', action='store_true', help='Include BPM in the output') - parser.add_argument('--noclean', action='store_true', help='Don\'t apply cleaning function') - parser.add_argument('--format', type=str, choices=['times', 'click_track'], default='times', help='Output format: "times" for beat times or "click_track" for audio with click track') - parser.add_argument('--audio_output', type=str, default='output_with_clicks.wav', help='Path to save the output audio file with clicks') - parser.add_argument('--json_output', type=str, default='', help='Path to save the output json results') + parser = argparse.ArgumentParser(description="Predict beats from an audio file.") + parser.add_argument("audio_path", type=str, help="Path to the audio file") + parser.add_argument("--bpm", action="store_true", help="Include BPM in the output") + parser.add_argument("--noclean", action="store_true", help="Don't apply cleaning function") + parser.add_argument( + "--format", + type=str, + choices=["times", "click_track"], + default="times", + help='Output format: "times" for beat times or "click_track" for audio with click track', + ) + parser.add_argument( + "--audio_output", + type=str, + default="output_with_clicks.wav", + help="Path to save the output audio file with clicks", + ) + parser.add_argument("--json_output", type=str, default="", help="Path to save the output json results") args = parser.parse_args() @@ -26,23 +37,20 @@ else: beat_times = pf.predict(audio_path, include_bpm=args.bpm, clean=not args.noclean) - if args.format == 'click_track': + if args.format == "click_track": audio, sr = librosa.load(audio_path) click_track = librosa.clicks(times=beat_times, sr=sr, length=len(audio)) - audio_with_clicks = np.array([click_track, audio]) audio_with_clicks = np.vstack([click_track, audio]).T sf.write(args.audio_output, audio_with_clicks, sr) else: - if args.json_output != '': - output_data = { - 'beat_times': beat_times.tolist() - } + if args.json_output != "": + output_data = {"beat_times": beat_times.tolist()} if args.bpm: - output_data['bpm'] = bpm + output_data["bpm"] = bpm - with open(args.json_output, 'w') as json_file: + with open(args.json_output, "w") as json_file: json.dump(output_data, json_file, indent=4) else: print(f"beats = {beat_times}") if args.bpm: - print(f'bpm = {bpm}') + print(f"bpm = {bpm}") diff --git a/phasefinder/model/__init__.py b/phasefinder/model/__init__.py index cb1c9c0..eddfc4a 100644 --- a/phasefinder/model/__init__.py +++ b/phasefinder/model/__init__.py @@ -1,2 +1,2 @@ -from .model_attn import PhasefinderModelAttn -from .model_noattn import PhasefinderModelNoattn \ No newline at end of file +from .model_attn import PhasefinderModelAttn as PhasefinderModelAttn +from .model_noattn import PhasefinderModelNoattn as PhasefinderModelNoattn diff --git a/phasefinder/model/attention.py b/phasefinder/model/attention.py index b6bd25d..ca84c53 100644 --- a/phasefinder/model/attention.py +++ b/phasefinder/model/attention.py @@ -1,47 +1,50 @@ +import math + import torch import torch.nn as nn import torch.nn.functional as F -import math + from phasefinder.model.pos_encoding import PositionalEncoding + class AttentionModule(nn.Module): def __init__(self, input_dim, num_heads=4, dropout=0.1): super(AttentionModule, self).__init__() self.input_dim = input_dim self.num_heads = num_heads self.head_dim = input_dim // num_heads - + # Multi-head attention self.query = nn.Linear(input_dim, input_dim) self.key = nn.Linear(input_dim, input_dim) self.value = nn.Linear(input_dim, input_dim) - + # Output projection self.proj = nn.Linear(input_dim, input_dim) - + # Positional encoding self.pos_encoding = PositionalEncoding(input_dim, dropout=dropout) - + self.dropout = nn.Dropout(dropout) - + def forward(self, x): batch_size, seq_len, _ = x.size() - + # Add positional encoding x = self.pos_encoding(x) - + # Compute Q, K, V q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - + # Compute attention scores scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) attn = F.softmax(scores, dim=-1) attn = self.dropout(attn) - + # Apply attention context = torch.matmul(attn, v).transpose(1, 2).contiguous().view(batch_size, seq_len, self.input_dim) - + output = self.proj(context) return output diff --git a/phasefinder/model/decoder.py b/phasefinder/model/decoder.py index 12ab7ec..a8d1c87 100644 --- a/phasefinder/model/decoder.py +++ b/phasefinder/model/decoder.py @@ -1,11 +1,12 @@ import torch.nn as nn + class BeatPhaseDecoder(nn.Module): def __init__(self, num_tcn_outputs, num_classes): super(BeatPhaseDecoder, self).__init__() self.dropout1 = nn.Dropout(0.1) self.dense1 = nn.Linear(num_tcn_outputs, 72) - self.relu = nn.ELU() + self.activation = nn.ELU() self.dropout2 = nn.Dropout(0.1) self.dense2 = nn.Linear(72, num_classes) self.softmax = nn.LogSoftmax(dim=2) @@ -13,7 +14,7 @@ def __init__(self, num_tcn_outputs, num_classes): def forward(self, x): x = x.transpose(1, 2) x = self.dropout1(x) - x = self.relu(self.dense1(x)) + x = self.activation(self.dense1(x)) x = self.dropout2(x) x = self.dense2(x) x = self.softmax(x) diff --git a/phasefinder/model/feature1d.py b/phasefinder/model/feature1d.py index 467a61c..9b1ee64 100644 --- a/phasefinder/model/feature1d.py +++ b/phasefinder/model/feature1d.py @@ -1,16 +1,17 @@ import torch.nn as nn + class FeatureExtraction(nn.Module): def __init__(self, num_bands=81, num_channels=20): super(FeatureExtraction, self).__init__() self.conv1 = nn.Conv1d(num_bands, num_channels, kernel_size=3, stride=1, padding=1) - self.pool1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) + self.pool1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) self.elu1 = nn.ELU() self.dropout1 = nn.Dropout(0.1) self.conv2 = nn.Conv1d(num_channels, num_channels, kernel_size=3, stride=1, padding=1) - self.pool2 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) + self.pool2 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) self.elu2 = nn.ELU() self.dropout2 = nn.Dropout(0.1) @@ -21,17 +22,17 @@ def __init__(self, num_bands=81, num_channels=20): def forward(self, x): x = self.conv1(x) - x = self.pool1(x) + x = self.pool1(x) x = self.elu1(x) x = self.dropout1(x) x = self.conv2(x) - x = self.pool2(x) + x = self.pool2(x) x = self.elu2(x) x = self.dropout2(x) x = self.conv3(x) - x = self.pool3(x) + x = self.pool3(x) x = self.elu3(x) x = self.dropout3(x) diff --git a/phasefinder/model/model_attn.py b/phasefinder/model/model_attn.py index 783e42e..a4b114d 100644 --- a/phasefinder/model/model_attn.py +++ b/phasefinder/model/model_attn.py @@ -1,14 +1,15 @@ import torch.nn as nn from pytorch_tcn import TCN -from phasefinder.model.feature1d import FeatureExtraction -from phasefinder.model.decoder import BeatPhaseDecoder from phasefinder.model.attention import AttentionModule +from phasefinder.model.decoder import BeatPhaseDecoder +from phasefinder.model.feature1d import FeatureExtraction class PhasefinderModelAttn(nn.Module): - def __init__(self, num_bands=81, num_channels=36, num_classes=360, - kernel_size=5, dropout=0.1, num_tcn_layers=16, dilation=8): + def __init__( + self, num_bands=81, num_channels=36, num_classes=360, kernel_size=5, dropout=0.1, num_tcn_layers=16, dilation=8 + ): super(PhasefinderModelAttn, self).__init__() self.feature_extraction = FeatureExtraction(num_bands=num_bands, num_channels=num_channels) self.tcn_beat = TCN( @@ -18,10 +19,10 @@ def __init__(self, num_bands=81, num_channels=36, num_classes=360, dropout=dropout, causal=False, use_skip_connections=True, - kernel_initializer='kaiming_normal', - use_norm='layer_norm', - activation='relu', - dilation_reset=dilation + kernel_initializer="kaiming_normal", + use_norm="layer_norm", + activation="relu", + dilation_reset=dilation, ) self.attention = AttentionModule(num_channels) self.decoder_beat = BeatPhaseDecoder(num_tcn_outputs=num_channels, num_classes=num_classes) diff --git a/phasefinder/model/model_noattn.py b/phasefinder/model/model_noattn.py index 76bd2f1..8250039 100644 --- a/phasefinder/model/model_noattn.py +++ b/phasefinder/model/model_noattn.py @@ -1,12 +1,14 @@ import torch.nn as nn from pytorch_tcn import TCN -from phasefinder.model.feature1d import FeatureExtraction from phasefinder.model.decoder import BeatPhaseDecoder +from phasefinder.model.feature1d import FeatureExtraction + class PhasefinderModelNoattn(nn.Module): - def __init__(self, num_bands=81, num_channels=36, num_classes=360, - kernel_size=5, dropout=0.1, num_tcn_layers=16, dilation=8): + def __init__( + self, num_bands=81, num_channels=36, num_classes=360, kernel_size=5, dropout=0.1, num_tcn_layers=16, dilation=8 + ): super(PhasefinderModelNoattn, self).__init__() self.feature_extraction = FeatureExtraction(num_bands=num_bands, num_channels=num_channels) self.tcn_beat = TCN( @@ -16,10 +18,10 @@ def __init__(self, num_bands=81, num_channels=36, num_classes=360, dropout=dropout, causal=False, use_skip_connections=True, - kernel_initializer='kaiming_normal', - use_norm='layer_norm', - activation='relu', - dilation_reset=dilation + kernel_initializer="kaiming_normal", + use_norm="layer_norm", + activation="relu", + dilation_reset=dilation, ) self.decoder_beat = BeatPhaseDecoder(num_tcn_outputs=num_channels, num_classes=num_classes) diff --git a/phasefinder/model/pos_encoding.py b/phasefinder/model/pos_encoding.py index 4bf9de0..ae3a95c 100644 --- a/phasefinder/model/pos_encoding.py +++ b/phasefinder/model/pos_encoding.py @@ -1,6 +1,8 @@ +import math + import torch import torch.nn as nn -import math + class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout=0.1, max_len=5000): @@ -13,8 +15,8 @@ def __init__(self, d_model, dropout=0.1, max_len=5000): pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) def forward(self, x): - x = x + self.pe[:x.size(0), :] - return self.dropout(x) \ No newline at end of file + x = x + self.pe[: x.size(0), :] + return self.dropout(x) diff --git a/phasefinder/postproc/__init__.py b/phasefinder/postproc/__init__.py index e69de29..57c666e 100644 --- a/phasefinder/postproc/__init__.py +++ b/phasefinder/postproc/__init__.py @@ -0,0 +1,32 @@ +from typing import Union + +import numpy as np +import torch + +from phasefinder.constants import BEAT_ONSET_THRESHOLD, FRAME_RATE, HOP, SAMPLE_RATE +from phasefinder.postproc.cleaner import clean_beats +from phasefinder.postproc.hmm import hmm_beat_estimation + + +def extract_beat_times( + phase_tensor: torch.Tensor, + bpm: float, + bpm_confidence: float = 0.9, + distance_threshold_factor: float = 0.2, + clean: bool = True, + device: Union[str, torch.device] = "cpu", +) -> np.ndarray: + """Run HMM beat estimation, onset detection, and optional cleaning.""" + res = hmm_beat_estimation( + phase_tensor, + bpm, + bpm_confidence=bpm_confidence, + distance_threshold_factor=distance_threshold_factor, + frame_rate=FRAME_RATE, + device=device, + ) + bt = torch.tensor(res) + onset = torch.abs(bt[1:] - bt[:-1]) + beat_frames = np.array([i for i, x in enumerate(onset) if x > BEAT_ONSET_THRESHOLD]) + pred_beat_times = beat_frames * HOP / SAMPLE_RATE + return clean_beats(pred_beat_times) if clean else pred_beat_times diff --git a/phasefinder/postproc/cleaner.py b/phasefinder/postproc/cleaner.py index 0b1c1d8..4a794cc 100644 --- a/phasefinder/postproc/cleaner.py +++ b/phasefinder/postproc/cleaner.py @@ -1,18 +1,6 @@ from collections import defaultdict -import numpy as np - -""" -f 0.873 - -CLEAN_THRESHOLD = 0.07 -OVERLAP_THRESHOLD = 0.3 -EARLY_THRESHOLD = 0.6 -LATE_THRESHOLD = 1.4 -MISSED_THRESHOLD = 1.6 -MODE_THRESHOLD = 0.001 -NUDGE_AMOUNT = 0.5 -""" +import numpy as np CLEAN_THRESHOLD = 0.03 OVERLAP_THRESHOLD = 0.35 @@ -22,31 +10,68 @@ MODE_THRESHOLD = 0.04 NUDGE_AMOUNT = 0.5 -def clean_beats(beat_times, clean_beats_threshold=CLEAN_THRESHOLD, overlap_threshold=OVERLAP_THRESHOLD, early_threshold=EARLY_THRESHOLD, late_threshold=LATE_THRESHOLD, missed_threshold=MISSED_THRESHOLD, nudge_amount=NUDGE_AMOUNT, mode_threshold=MODE_THRESHOLD): + +def clean_beats( + beat_times: np.ndarray, + clean_beats_threshold: float = CLEAN_THRESHOLD, + overlap_threshold: float = OVERLAP_THRESHOLD, + early_threshold: float = EARLY_THRESHOLD, + late_threshold: float = LATE_THRESHOLD, + missed_threshold: float = MISSED_THRESHOLD, + nudge_amount: float = NUDGE_AMOUNT, + mode_threshold: float = MODE_THRESHOLD, +) -> np.ndarray: cleaned_beats = _clean_beat_times(beat_times, clean_beats_threshold, mode_threshold) if len(cleaned_beats) < 3: - print(len(beat_times)) - corrected_beats = _correct_beat_sequence(beat_times, overlap_threshold, early_threshold, late_threshold, missed_threshold, nudge_amount, mode_threshold) + corrected_beats = _correct_beat_sequence( + beat_times, + overlap_threshold, + early_threshold, + late_threshold, + missed_threshold, + nudge_amount, + mode_threshold, + ) else: - corrected_beats = _correct_beat_sequence(cleaned_beats, overlap_threshold, early_threshold, late_threshold, missed_threshold, nudge_amount, mode_threshold) + corrected_beats = _correct_beat_sequence( + cleaned_beats, + overlap_threshold, + early_threshold, + late_threshold, + missed_threshold, + nudge_amount, + mode_threshold, + ) return np.array(corrected_beats) + def _clean_beat_times(beat_times, threshold, mode_threshold=MODE_THRESHOLD): if len(beat_times) < 2: return beat_times - intervals = [beat_times[i+1] - beat_times[i] for i in range(len(beat_times) - 1)] + intervals = [beat_times[i + 1] - beat_times[i] for i in range(len(beat_times) - 1)] interval_mode = find_interval_mode(intervals, threshold=mode_threshold) cleaned_beats = [beat_times[0]] i = 1 while i < len(beat_times) - 1: - if abs((beat_times[i+1] - beat_times[i-1]) - interval_mode) > threshold: + if abs((beat_times[i + 1] - beat_times[i - 1]) - interval_mode) > threshold: cleaned_beats.append(beat_times[i]) i += 1 cleaned_beats.append(beat_times[-1]) return cleaned_beats if cleaned_beats else beat_times -def _correct_beat_sequence(beat_times, overlap_threshold, early_beat_threshold, late_beat_threshold, missed_beat_threshold, nudge_amount, mode_threshold): - median_interval = find_interval_mode([beat_times[i] - beat_times[i-1] for i in range(1, len(beat_times))], mode_threshold) + +def _correct_beat_sequence( + beat_times, + overlap_threshold, + early_beat_threshold, + late_beat_threshold, + missed_beat_threshold, + nudge_amount, + mode_threshold, +): + median_interval = find_interval_mode( + [beat_times[i] - beat_times[i - 1] for i in range(1, len(beat_times))], mode_threshold + ) adjusted_beats = [beat_times[0]] for i in range(1, len(beat_times)): @@ -64,33 +89,15 @@ def _correct_beat_sequence(beat_times, overlap_threshold, early_beat_threshold, missed_beats = round(interval_ratio) for n in range(missed_beats - 1): adjusted_beats.append(adjusted_beats[-1] + median_interval) - adjusted_beats.append(beat_times[i]) + adjusted_beats.append(beat_times[i]) return adjusted_beats -def find_interval_mode(intervals, threshold=None): +def find_interval_mode(intervals: list[float], threshold: float = 0.001) -> float: thresh = threshold if threshold else 0.001 rounded_intervals = [round(interval / thresh) * thresh for interval in intervals] - interval_counts = defaultdict(int) + interval_counts: dict[float, int] = defaultdict(int) for interval in rounded_intervals: interval_counts[interval] += 1 interval_mode = max(interval_counts, key=interval_counts.get) return interval_mode - -def nudge(beat_times, interval_ratio): - """ - Moves all beats by a proportion of their interval mode. - - Parameters: - beat_times (list): List of beat times. - interval_ratio (float): Proportion of the interval mode to nudge the beats. - - Returns: - list: Nudged beat times. - """ - - interval_mode = find_interval_mode([beat_times[i] - beat_times[i-1] for i in range(1, len(beat_times))]) - nudge_amount = interval_ratio * interval_mode - - nudged_beats = [beat_time + nudge_amount for beat_time in beat_times] - return nudged_beats diff --git a/phasefinder/postproc/hmm.py b/phasefinder/postproc/hmm.py index 69232f1..113f511 100644 --- a/phasefinder/postproc/hmm.py +++ b/phasefinder/postproc/hmm.py @@ -1,45 +1,70 @@ +from typing import Union + import torch -def hmm_beat_estimation(phase_prediction, bpm, frame_rate, bpm_confidence=0.9, distance_threshold_factor=0.2, device='cpu'): + +def hmm_beat_estimation( + phase_prediction: torch.Tensor, + bpm: float, + frame_rate: float, + bpm_confidence: float = 0.9, + distance_threshold_factor: float = 0.2, + device: Union[str, torch.device] = "cpu", +) -> list[int]: num_states = phase_prediction.shape[1] seq_len = phase_prediction.shape[0] - transition_probs = calculate_transition_probs(num_states, bpm, frame_rate, bpm_confidence, distance_threshold_factor, device) + transition_probs = calculate_transition_probs( + num_states, bpm, frame_rate, bpm_confidence, distance_threshold_factor, device + ) emission_probs = phase_prediction.to(device) viterbi_probs = torch.zeros((seq_len, num_states), device=device) backpointers = torch.zeros((seq_len, num_states), dtype=torch.long, device=device) viterbi_probs[0] = emission_probs[0] + transition_probs[0] for t in range(1, seq_len): - prev_probs = viterbi_probs[t-1].unsqueeze(1) + prev_probs = viterbi_probs[t - 1].unsqueeze(1) curr_emis = emission_probs[t].unsqueeze(0) curr_probs = prev_probs + transition_probs + curr_emis viterbi_probs[t], backpointers[t] = torch.max(curr_probs, dim=0) beat_positions = backtrack(backpointers) return beat_positions -def calculate_transition_probs(num_states, bpm, frame_rate, bpm_confidence, distance_threshold_factor, device='cpu'): + +def calculate_transition_probs( + num_states: int, + bpm: float, + frame_rate: float, + bpm_confidence: float, + distance_threshold_factor: float, + device: Union[str, torch.device] = "cpu", +) -> torch.Tensor: frames_per_beat = frame_rate * 60 / bpm phase_change_per_frame = 360 / frames_per_beat i = torch.arange(num_states, device=device).float() * (360 / num_states) j = torch.arange(num_states, device=device).float() * (360 / num_states) expected_phase_diff = (i.unsqueeze(1) + phase_change_per_frame) % 360 - j.unsqueeze(0) - expected_phase_diff = torch.min(abs(expected_phase_diff), 360 - abs(expected_phase_diff)) - + expected_phase_diff = torch.min(torch.abs(expected_phase_diff), 360 - torch.abs(expected_phase_diff)) + distance_threshold = phase_change_per_frame * distance_threshold_factor - transition_probs = torch.where(expected_phase_diff <= distance_threshold, 1.0 - (expected_phase_diff / distance_threshold), torch.tensor(1e-10, device=device)) - + transition_probs = torch.where( + expected_phase_diff <= distance_threshold, + 1.0 - (expected_phase_diff / distance_threshold), + torch.tensor(1e-10, device=device), + ) + uniform_probs = torch.ones_like(transition_probs) / num_states transition_probs = bpm_confidence * transition_probs + (1 - bpm_confidence) * uniform_probs - + transition_probs = transition_probs / transition_probs.sum(dim=1, keepdim=True) return torch.log(transition_probs) -def backtrack(backpointers): + +def backtrack(backpointers: torch.Tensor) -> list[int]: seq_len = backpointers.shape[0] - beat_positions = [] + beat_positions: list[int] = [] curr_state = torch.argmax(backpointers[-1]) beat_positions.append(curr_state.item()) for t in range(seq_len - 2, -1, -1): curr_state = backpointers[t, curr_state] beat_positions.append(curr_state.item()) beat_positions = beat_positions[::-1] - return beat_positions \ No newline at end of file + return beat_positions diff --git a/phasefinder/postproc_gridsearch.py b/phasefinder/postproc_gridsearch.py index 40f7ce8..8176ea8 100644 --- a/phasefinder/postproc_gridsearch.py +++ b/phasefinder/postproc_gridsearch.py @@ -1,28 +1,43 @@ +import argparse +import csv import itertools +import os +import random +from datetime import datetime + +import mir_eval import numpy as np import torch -from model.model_noattn import PhasefinderModelNoattn -from dataset import BeatDataset from torch.utils.data import DataLoader from tqdm import tqdm -from postproc.hmm import hmm_beat_estimation -from postproc.cleaner import clean_beats -import mir_eval -import argparse -import csv -from datetime import datetime -import os -import random -def main(modelname, bpm_confidence=1.0, distance_threshold_factor=0.1, clean_beats_threshold=0.001, - overlap_threshold=0.3, early_beat_threshold=0.7, late_beat_threshold=1.3, missed_beat_threshold=1.7, mode_threshold=0.001, nudge_amount=0.5): - datapath = '../stft_db_b_phase_cleaned.h5' +from phasefinder.dataset import BeatDataset +from phasefinder.model.model_noattn import PhasefinderModelNoattn +from phasefinder.postproc.cleaner import clean_beats +from phasefinder.postproc.hmm import hmm_beat_estimation + + +def main( + modelname, + bpm_confidence=1.0, + distance_threshold_factor=0.1, + clean_beats_threshold=0.001, + overlap_threshold=0.3, + early_beat_threshold=0.7, + late_beat_threshold=1.3, + missed_beat_threshold=1.7, + mode_threshold=0.001, + nudge_amount=0.5, +): + datapath = "../stft_db_b_phase_cleaned.h5" beat_model = PhasefinderModelNoattn().cuda() - beat_model.load_state_dict(torch.load(modelname, map_location=torch.device('cuda'), weights_only=True), strict=False) + beat_model.load_state_dict( + torch.load(modelname, map_location=torch.device("cuda"), weights_only=True), strict=False + ) beat_model.eval() - dataset = BeatDataset(datapath, 'test', mode='beat', items=['stft', 'time', 'bpm'], device='cuda') + dataset = BeatDataset(datapath, "test", mode="beat", items=["stft", "time", "bpm"], device="cuda") dataloader = DataLoader(dataset, batch_size=1, shuffle=False) f_measures = [] @@ -31,31 +46,37 @@ def main(modelname, bpm_confidence=1.0, distance_threshold_factor=0.1, clean_bea with torch.no_grad(): for i, (stft, beat_times, bpm) in enumerate(tqdm(dataloader)): - stft, beat_times, bpm = stft.to('cuda'), beat_times[0].to('cuda'), bpm.to('cuda') + stft, beat_times, bpm = stft.to("cuda"), beat_times[0].to("cuda"), bpm.to("cuda") phase_preds = beat_model(stft) - - frame_rate = 22050. / 512 - res = hmm_beat_estimation(phase_preds[0][0].to('cuda'), bpm.item(), frame_rate, - bpm_confidence=bpm_confidence, - distance_threshold_factor=distance_threshold_factor, - device='cuda') + + frame_rate = 22050.0 / 512 + res = hmm_beat_estimation( + phase_preds[0][0].to("cuda"), + bpm.item(), + frame_rate, + bpm_confidence=bpm_confidence, + distance_threshold_factor=distance_threshold_factor, + device="cuda", + ) bt = torch.tensor(res) pred_beat_label_onset = torch.abs(bt[1:] - bt[:-1]) beat_frames = np.array([i for i, x in enumerate(pred_beat_label_onset) if x > 300]) - + pred_beat_times = beat_frames / frame_rate - cleaned_times = clean_beats(pred_beat_times, - clean_beats_threshold=clean_beats_threshold, - overlap_threshold=overlap_threshold, - early_threshold=early_beat_threshold, - late_threshold=late_beat_threshold, - missed_threshold=missed_beat_threshold, - mode_threshold=mode_threshold, - nudge_amount=nudge_amount) - + cleaned_times = clean_beats( + pred_beat_times, + clean_beats_threshold=clean_beats_threshold, + overlap_threshold=overlap_threshold, + early_threshold=early_beat_threshold, + late_threshold=late_beat_threshold, + missed_threshold=missed_beat_threshold, + mode_threshold=mode_threshold, + nudge_amount=nudge_amount, + ) + np_actual = beat_times.cpu().numpy() np_pred = cleaned_times - + cmlc, cmlt, amlc, amlt = mir_eval.beat.continuity(np_actual, np_pred) f_corr = mir_eval.beat.f_measure(np_actual, np_pred) @@ -63,7 +84,7 @@ def main(modelname, bpm_confidence=1.0, distance_threshold_factor=0.1, clean_bea cmlt_scores.append(cmlt) amlt_scores.append(amlt) - overall_f_measure = sum(f_measures) / len(f_measures) + overall_f_measure = sum(f_measures) / len(f_measures) print(f"Overall F-measure: {overall_f_measure:.3f}") overall_cmlt = sum(cmlt_scores) / len(cmlt_scores) overall_amlt = sum(amlt_scores) / len(amlt_scores) @@ -72,13 +93,14 @@ def main(modelname, bpm_confidence=1.0, distance_threshold_factor=0.1, clean_bea return overall_f_measure, overall_cmlt, overall_amlt + def read_existing_results(filename): tested_combinations = set() best_f_measure = 0 best_params = None - + if os.path.exists(filename): - with open(filename, 'r') as csvfile: + with open(filename, "r") as csvfile: reader = csv.reader(csvfile) next(reader) # Skip header for row in reader: @@ -88,29 +110,36 @@ def read_existing_results(filename): if f_measure > best_f_measure: best_f_measure = f_measure best_params = params - + return tested_combinations, best_f_measure, best_params + def grid_search(input_csv=None): # Define parameter ranges - bpm_confidences = [0.9] # [0.1, 0.5, 0.7, 0.9] + bpm_confidences = [0.9] distance_threshold_factors = [0.15, 0.2, 0.25] clean_beats_thresholds = [0.03, 0.05, 0.1] - overlap_thresholds = [ 0.15, 0.2,0.25, 0.35, 0.4] + overlap_thresholds = [0.15, 0.2, 0.25, 0.35, 0.4] early_beat_thresholds = [0.55, 0.6, 0.65] - late_beat_thresholds = [1.3,1.35, 1.4] - missed_beat_thresholds = [1.6,1.65,1.8] - mode_thresholds = [0.01,0.04,0.08, 0.1] + late_beat_thresholds = [1.3, 1.35, 1.4] + missed_beat_thresholds = [1.6, 1.65, 1.8] + mode_thresholds = [0.01, 0.04, 0.08, 0.1] nudge_amounts = [0.45, 0.5, 0.55] - # Generate all combinations of parameters - param_combinations = list(itertools.product( - bpm_confidences, distance_threshold_factors, clean_beats_thresholds, - overlap_thresholds, early_beat_thresholds, late_beat_thresholds, missed_beat_thresholds, mode_thresholds, nudge_amounts - )) - - + param_combinations = list( + itertools.product( + bpm_confidences, + distance_threshold_factors, + clean_beats_thresholds, + overlap_thresholds, + early_beat_thresholds, + late_beat_thresholds, + missed_beat_thresholds, + mode_thresholds, + nudge_amounts, + ) + ) # If input CSV is provided, read existing results if input_csv: @@ -124,19 +153,32 @@ def grid_search(input_csv=None): results_filename = f"grid_search_results_{timestamp}.csv" # Write header to the CSV file if it's a new file - with open(results_filename, 'w', newline='') as csvfile: + with open(results_filename, "w", newline="") as csvfile: writer = csv.writer(csvfile) - writer.writerow(['bpm_confidence', 'distance_threshold_factor', 'clean_beats_threshold', - 'overlap_threshold', 'early_beat_threshold', 'late_beat_threshold', - 'missed_beat_threshold','mode_threshold', 'nudge_amount', 'f_measure', 'cmlt', 'amlt']) + writer.writerow( + [ + "bpm_confidence", + "distance_threshold_factor", + "clean_beats_threshold", + "overlap_threshold", + "early_beat_threshold", + "late_beat_threshold", + "missed_beat_threshold", + "mode_threshold", + "nudge_amount", + "f_measure", + "cmlt", + "amlt", + ] + ) random.shuffle(param_combinations) - print(f'num combos = {len(param_combinations)}') + print(f"num combos = {len(param_combinations)}") for params in param_combinations: print(f"Parameters: {params}") - f_measure, cmlt, amlt = main('../phasefinder-0.1-noattn.pt', *params) + f_measure, cmlt, amlt = main("../phasefinder-0.1-noattn.pt", *params) # Append results to the CSV file - with open(results_filename, 'a', newline='') as csvfile: + with open(results_filename, "a", newline="") as csvfile: writer = csv.writer(csvfile) writer.writerow(list(params) + [f_measure, cmlt, amlt]) @@ -149,9 +191,10 @@ def grid_search(input_csv=None): print(f"Best F-measure: {best_f_measure}") print(f"Results saved to: {results_filename}") + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Run grid search for beat estimation parameters') - parser.add_argument('--input_csv', type=str, help='Path to input CSV file with existing results', default=None) + parser = argparse.ArgumentParser(description="Run grid search for beat estimation parameters") + parser.add_argument("--input_csv", type=str, help="Path to input CSV file with existing results", default=None) args = parser.parse_args() - - grid_search(args.input_csv) \ No newline at end of file + + grid_search(args.input_csv) diff --git a/phasefinder/predictor.py b/phasefinder/predictor.py index b2a62ab..197582e 100644 --- a/phasefinder/predictor.py +++ b/phasefinder/predictor.py @@ -1,28 +1,29 @@ -import torch +from pathlib import Path +from typing import Optional, Tuple, Union + import librosa -import torchaudio -from nnAudio.features import STFT -from deeprhythm import DeepRhythmPredictor import numpy as np import soundfile as sf +import torch +import torchaudio +from deeprhythm import DeepRhythmPredictor +from nnAudio.features import STFT +from phasefinder.audio.log_filter import apply_log_filter, create_log_filter +from phasefinder.constants import CLICK_SAMPLE_RATE, HOP, N_FFT, SAMPLE_RATE from phasefinder.model import PhasefinderModelAttn, PhasefinderModelNoattn -from phasefinder.utils import get_weights, get_device -from phasefinder.audio.log_filter import create_log_filter, apply_log_filter -from phasefinder.postproc.hmm import hmm_beat_estimation -from phasefinder.postproc.cleaner import clean_beats +from phasefinder.postproc import extract_beat_times +from phasefinder.utils import get_device, get_weights -N_FFT = 2048 -HOP = 512 -SAMPLE_RATE = 22050 class Phasefinder: - def __init__(self, modelname='phasefinder-0.1.pt', device=None, quiet=False, attention=False) -> None: - if device: - self.device = device - else: - self.device = get_device() - + def __init__( + self, + modelname: str = "phasefinder-0.1.pt", + device: Optional[str] = None, + quiet: bool = False, + attention: bool = False, + ) -> None: self.attention = attention self.quiet = quiet self.model_path = get_weights(modelname, quiet=quiet) @@ -32,62 +33,69 @@ def __init__(self, modelname='phasefinder-0.1.pt', device=None, quiet=False, att self.device = torch.device(device) self.load_model() - def load_model(self): + def load_model(self) -> None: if self.attention: self.model = PhasefinderModelAttn() else: self.model = PhasefinderModelNoattn() - - self.model.load_state_dict(torch.load(self.model_path, map_location=torch.device(self.device), weights_only=True), strict=False) + + self.model.load_state_dict( + torch.load(self.model_path, map_location=torch.device(self.device), weights_only=True), + strict=False, + ) self.model = self.model.to(self.device) self.model.eval() - self.bpm_model = DeepRhythmPredictor('deeprhythm-0.7.pth', device=self.device, quiet=self.quiet) + self.bpm_model = DeepRhythmPredictor("deeprhythm-0.7.pth", device=self.device, quiet=self.quiet) - fft_bins = int((N_FFT/2)+1) + fft_bins = int((N_FFT / 2) + 1) self.filter_matrix = create_log_filter(fft_bins, 81, device=self.device) - self.stft = STFT( - n_fft=N_FFT, - hop_length=HOP, - sr = SAMPLE_RATE, - output_format='Magnitude' - ) + self.stft = STFT(n_fft=N_FFT, hop_length=HOP, sr=SAMPLE_RATE, output_format="Magnitude") self.stft = self.stft.to(self.device) - - def predict(self, audio_path, include_bpm=False, clean=True): + + def predict( + self, + audio_path: Union[str, Path], + include_bpm: bool = False, + clean: bool = True, + ) -> Union[np.ndarray, Tuple[np.ndarray, float]]: bpm, confidence = self.bpm_model.predict(audio_path, include_confidence=True) audio, _ = librosa.load(audio_path, sr=SAMPLE_RATE) audio_tens = torch.tensor(audio).unsqueeze(0).unsqueeze(1).to(self.device) - - stft_batch = torchaudio.functional.amplitude_to_DB(torch.abs(self.stft(audio_tens)), multiplier=10., amin=0.00001, db_multiplier=1) - + + stft_batch = torchaudio.functional.amplitude_to_DB( + torch.abs(self.stft(audio_tens)), multiplier=10.0, amin=0.00001, db_multiplier=1 + ) + song_spec = apply_log_filter(stft_batch, self.filter_matrix) song_spec = (song_spec - song_spec.min()) / (song_spec.max() - song_spec.min()) - + phase_preds = self.model(song_spec) - - frame_rate = 22050. / 512 - res = hmm_beat_estimation(phase_preds[0, :, :].squeeze(0).to(self.device), bpm, bpm_confidence=confidence, frame_rate=frame_rate, device=self.device) - bt = torch.tensor(res) - - pred_beat_label_onset = torch.abs(bt[1:] - bt[:-1]) - beat_frames = torch.tensor([i for i, x in enumerate(pred_beat_label_onset) if x > 300]) - - pred_beat_times = beat_frames * HOP / SAMPLE_RATE - if clean: - pred_beat_times = clean_beats(pred_beat_times.numpy()) + + pred_beat_times = extract_beat_times( + phase_preds[0, :, :].squeeze(0).to(self.device), + bpm, + bpm_confidence=confidence, + clean=clean, + device=self.device, + ) if include_bpm: return pred_beat_times, bpm else: return pred_beat_times - - def make_click_track(self, audio_path, output_path='output.wav', beats=None, clean=True): - if not beats: + + def make_click_track( + self, + audio_path: Union[str, Path], + output_path: str = "output.wav", + beats: Optional[np.ndarray] = None, + clean: bool = True, + ) -> None: + if beats is None: beats = self.predict(audio_path, clean=clean) - audio, _ = librosa.load(audio_path, sr=44100) - click_track = librosa.clicks(times=beats, sr=44100, length=len(audio)) - audio_with_clicks = np.array([click_track, audio]) + audio, _ = librosa.load(audio_path, sr=CLICK_SAMPLE_RATE) + click_track = librosa.clicks(times=beats, sr=CLICK_SAMPLE_RATE, length=len(audio)) audio_with_clicks = np.vstack([click_track, audio]).T - sf.write(output_path, audio_with_clicks, 44100) \ No newline at end of file + sf.write(output_path, audio_with_clicks, CLICK_SAMPLE_RATE) diff --git a/phasefinder/test_model.py b/phasefinder/test_model.py index 967be77..b3f4276 100644 --- a/phasefinder/test_model.py +++ b/phasefinder/test_model.py @@ -1,19 +1,25 @@ -import torch -from model import PhasefinderModelNoattn -from val import test_model_f_measure import argparse +import torch + +from phasefinder.model import PhasefinderModelNoattn +from phasefinder.val import test_model_f_measure + + def main(modelname): - datapath = '../stft_db_b_phase_cleaned.h5' + datapath = "../stft_db_b_phase_cleaned.h5" beat_model = PhasefinderModelNoattn().cuda() - beat_model.load_state_dict(torch.load(modelname, map_location=torch.device('cuda'), weights_only=True), strict=False) + beat_model.load_state_dict( + torch.load(modelname, map_location=torch.device("cuda"), weights_only=True), strict=False + ) beat_model.eval() test_model_f_measure(beat_model, datapath) + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Test PhasefinderModel') - parser.add_argument('modelname', type=str, help='Path to the model file') + parser = argparse.ArgumentParser(description="Test PhasefinderModel") + parser.add_argument("modelname", type=str, help="Path to the model file") args = parser.parse_args() main(args.modelname) diff --git a/phasefinder/test_postproc.py b/phasefinder/test_postproc.py index 748a026..0e2dbaf 100644 --- a/phasefinder/test_postproc.py +++ b/phasefinder/test_postproc.py @@ -1,10 +1,9 @@ -from val import test_postprocessing_f_measure +from phasefinder.val import test_postprocessing_f_measure - -if __name__ == '__main__': - data_path = '../stft_db_b_phase_cleaned.h5' +if __name__ == "__main__": + data_path = "../stft_db_b_phase_cleaned.h5" f_measure, cmlt, amlt = test_postprocessing_f_measure(data_path) - print(f"Postprocessing Results:") + print("Postprocessing Results:") print(f"F-measure: {f_measure:.3f}") print(f"CMLt: {cmlt:.3f}") - print(f"AMLt: {amlt:.3f}") \ No newline at end of file + print(f"AMLt: {amlt:.3f}") diff --git a/phasefinder/train.py b/phasefinder/train.py index e14cdc6..ff0fdfa 100644 --- a/phasefinder/train.py +++ b/phasefinder/train.py @@ -1,131 +1,143 @@ -import torch -from torch.utils.data import DataLoader -from torch import nn, optim -from dataset import BeatDataset -from tqdm import tqdm -from torch.utils.tensorboard import SummaryWriter -from torch.optim.lr_scheduler import LambdaLR import argparse import json import os -from phasefinder.val import test_model_f_measure -from phasefinder.model.model_noattn import PhasefinderModelNoattn +import torch +from torch import nn, optim +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +from phasefinder.dataset import BeatDataset from phasefinder.model.model_attn import PhasefinderModelAttn +from phasefinder.model.model_noattn import PhasefinderModelNoattn +from phasefinder.utils import get_device +from phasefinder.val import test_model_f_measure -parser = argparse.ArgumentParser(description='Train PhasefinderModel') -parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') -parser.add_argument('--phase_width', type=int, default=5, help='Phase width') -parser.add_argument('--model_root', type=str, default='kl9-pw5', help='Model root name') -parser.add_argument('--num_channels', type=int, default=36, help='Number of channels') -parser.add_argument('--num_classes', type=int, default=360, help='Number of classes') -parser.add_argument('--num_tcn_layers', type=int, default=16, help='Number of TCN layers') -parser.add_argument('--dilation', type=int, default=8, help='Dilation') -parser.add_argument('--start_epoch', type=int, default=0, help='Start Epoch') -parser.add_argument('--use_attention', action='store_true', help='Use attention mechanism') -parser.add_argument('--load_weights', type=str, default=None, help='Path to model weights file') -parser.add_argument('--data_path', type=str, default='stft_db_b_phase.hdf5', help='Path to dataset') -parser.add_argument('--max_epochs', type=int, default=20, help='Max epochs to train') +parser = argparse.ArgumentParser(description="Train PhasefinderModel") +parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate") +parser.add_argument("--phase_width", type=int, default=5, help="Phase width") +parser.add_argument("--model_root", type=str, default="kl9-pw5", help="Model root name") +parser.add_argument("--num_channels", type=int, default=36, help="Number of channels") +parser.add_argument("--num_classes", type=int, default=360, help="Number of classes") +parser.add_argument("--num_tcn_layers", type=int, default=16, help="Number of TCN layers") +parser.add_argument("--dilation", type=int, default=8, help="Dilation") +parser.add_argument("--start_epoch", type=int, default=0, help="Start Epoch") +parser.add_argument("--use_attention", action="store_true", help="Use attention mechanism") +parser.add_argument("--load_weights", type=str, default=None, help="Path to model weights file") +parser.add_argument("--data_path", type=str, default="stft_db_b_phase.hdf5", help="Path to dataset") +parser.add_argument("--max_epochs", type=int, default=20, help="Max epochs to train") args = parser.parse_args() -print(args.use_attention) LR = args.lr PHASE_WIDTH = args.phase_width START_EPOCH = args.start_epoch model_root = args.model_root -def warmup_lambda(epoch): +device = get_device() + + +def warmup_lambda(epoch: int) -> float: ep = max(epoch, START_EPOCH) if ep < 5: return (ep + 1) / 5 return 1.0 + if args.use_attention: model = PhasefinderModelAttn( - num_bands=81, - num_channels=args.num_channels, - num_classes=args.num_classes, - num_tcn_layers=args.num_tcn_layers, - dilation=args.dilation, -) + num_bands=81, + num_channels=args.num_channels, + num_classes=args.num_classes, + num_tcn_layers=args.num_tcn_layers, + dilation=args.dilation, + ) else: model = PhasefinderModelNoattn( - num_bands=81, - num_channels=args.num_channels, - num_classes=args.num_classes, - num_tcn_layers=args.num_tcn_layers, - dilation=args.dilation, + num_bands=81, + num_channels=args.num_channels, + num_classes=args.num_classes, + num_tcn_layers=args.num_tcn_layers, + dilation=args.dilation, ) -if(args.load_weights): +if args.load_weights: model.load_state_dict(torch.load(args.load_weights, weights_only=True), strict=False) -model = model.cuda() +model = model.to(device) optimizer = optim.Adam(model.parameters(), lr=LR) warmup_scheduler = LambdaLR(optimizer, warmup_lambda) criterion = nn.KLDivLoss(reduction="batchmean") -writer = SummaryWriter(f'runs/{model_root}') -scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3) +writer = SummaryWriter(f"runs/{model_root}") +scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", factor=0.5, patience=3) + -def train_one_epoch(epoch): - train_dataset = BeatDataset(args.data_path, 'train', mode='beat', items=['stft', 'phase'], phase_width=PHASE_WIDTH) +def train_one_epoch(epoch: int) -> float: + train_dataset = BeatDataset(args.data_path, "train", mode="beat", items=["stft", "phase"], phase_width=PHASE_WIDTH) train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True) model.train() - running_loss = 0.0 + total_loss = 0.0 + window_loss = 0.0 for i, (stft, beat_phase) in enumerate(tqdm(train_loader)): optimizer.zero_grad() - - stft = stft.cuda() - beat_phase = beat_phase.cuda() - + + stft = stft.to(device) + beat_phase = beat_phase.to(device) + beat_phase_pred = model(stft) loss = criterion(beat_phase_pred.unsqueeze(0), beat_phase) loss.backward() - + optimizer.step() - - running_loss += loss.item() + + total_loss += loss.item() + window_loss += loss.item() if (i + 1) % 1000 == 0: - writer.add_scalar('Loss/Train', running_loss / 1000, epoch * len(train_loader) + i) - running_loss = 0.0 - print(f"Epoch: {epoch}, Train Loss: {running_loss}") - return running_loss / len(train_loader) + writer.add_scalar("Loss/Train", window_loss / 1000, epoch * len(train_loader) + i) + window_loss = 0.0 + avg_loss = total_loss / len(train_loader) + print(f"Epoch: {epoch}, Train Loss: {avg_loss}") + return avg_loss -def validate(epoch): - val_dataset = BeatDataset(args.data_path, 'val', mode='beat', items=['stft', 'phase'], phase_width=PHASE_WIDTH) + +def validate(epoch: int) -> float: + val_dataset = BeatDataset(args.data_path, "val", mode="beat", items=["stft", "phase"], phase_width=PHASE_WIDTH) val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False) model.eval() val_loss = 0.0 with torch.no_grad(): for i, (stft, beat_phase) in enumerate(tqdm(val_loader)): - stft = stft.cuda() - beat_phase = beat_phase.cuda() + stft = stft.to(device) + beat_phase = beat_phase.to(device) beat_phase_pred = model(stft) - + loss = criterion(beat_phase_pred.unsqueeze(0), beat_phase) val_loss += loss.item() val_loss /= len(val_loader) print(f"Epoch: {epoch}, Val Loss: {val_loss}") - writer.add_scalar('Loss/Validate', val_loss, epoch) + writer.add_scalar("Loss/Validate", val_loss, epoch) return val_loss -def test(epoch): - overall_f_measure, overall_cmlt, overall_amlt = test_model_f_measure(model, args.data_path) - writer.add_scalar('F-Measure/Val', overall_f_measure, epoch) - writer.add_scalar('Accuracy/CMLt', overall_cmlt, epoch) - writer.add_scalar('Accuracy/AMLt', overall_amlt, epoch) + +def test(epoch: int) -> tuple[float, float, float]: + overall_f_measure, overall_cmlt, overall_amlt = test_model_f_measure(model, args.data_path, device=device) + writer.add_scalar("F-Measure/Val", overall_f_measure, epoch) + writer.add_scalar("Accuracy/CMLt", overall_cmlt, epoch) + writer.add_scalar("Accuracy/AMLt", overall_amlt, epoch) return overall_f_measure, overall_cmlt, overall_amlt -best_val_loss = float('inf') + +best_val_loss = float("inf") best_f_measure = 0 -best_model_path = '' +best_model_path = "" epochs_no_improve = 0 max_epochs = args.max_epochs -results_file_path = f'{model_root}_results.json' +results_file_path = f"{model_root}_results.json" if os.path.exists(results_file_path): - with open(results_file_path, 'r') as f: + with open(results_file_path, "r") as f: results = json.load(f) else: results = { @@ -134,15 +146,15 @@ def test(epoch): "val_loss": [], "f_measure": [], "cmlt": [], - "amlt": [] + "amlt": [], } -if __name__ == '__main__': +if __name__ == "__main__": for epoch in range(START_EPOCH, max_epochs): train_loss = train_one_epoch(epoch) val_loss = validate(epoch) f_measure, cmlt, amlt = test(epoch) - + # Log results results["epochs"].append(epoch) results["train_loss"].append(train_loss) @@ -150,14 +162,14 @@ def test(epoch): results["f_measure"].append(f_measure) results["cmlt"].append(cmlt) results["amlt"].append(amlt) - + # Save results to JSON file after each epoch - with open(f'{model_root}_results.json', 'w') as f: + with open(f"{model_root}_results.json", "w") as f: json.dump(results, f, indent=4) - + warmup_scheduler.step() scheduler.step(val_loss) - + save_model = False if val_loss < best_val_loss: best_val_loss = val_loss @@ -166,14 +178,14 @@ def test(epoch): if f_measure > best_f_measure: best_f_measure = f_measure save_model = True - + if save_model: - best_model_path = f'{model_root}_f_{f_measure:.3f}_epoch_{epoch}_loss_{val_loss:.4f}.pt' + best_model_path = f"{model_root}_f_{f_measure:.3f}_epoch_{epoch}_loss_{val_loss:.4f}.pt" torch.save(model.state_dict(), best_model_path) - + if not save_model: epochs_no_improve += 1 if epochs_no_improve == 15: print("Early stopping due to no improvement in validation loss.") break - writer.close() \ No newline at end of file + writer.close() diff --git a/phasefinder/utils/__init__.py b/phasefinder/utils/__init__.py index adf28de..3fdf33d 100644 --- a/phasefinder/utils/__init__.py +++ b/phasefinder/utils/__init__.py @@ -1,2 +1,2 @@ -from .get_weights import get_weights -from .get_device import get_device \ No newline at end of file +from .get_device import get_device as get_device +from .get_weights import get_weights as get_weights diff --git a/phasefinder/utils/detect_corruption.py b/phasefinder/utils/detect_corruption.py index 56a0839..dd3c1d4 100644 --- a/phasefinder/utils/detect_corruption.py +++ b/phasefinder/utils/detect_corruption.py @@ -2,61 +2,64 @@ import numpy as np from tqdm import tqdm + def compute_noise_features(stft): avg_spectrum = np.mean(np.abs(stft), axis=1) - + # Spectral Flatness geometric_mean = np.exp(np.mean(np.log(avg_spectrum + 1e-10))) arithmetic_mean = np.mean(avg_spectrum) spectral_flatness = geometric_mean / arithmetic_mean - + # High Frequency Ratio high_freq_start = int(0.7 * len(avg_spectrum)) high_freq_ratio = np.sum(avg_spectrum[high_freq_start:]) / np.sum(avg_spectrum) - + return spectral_flatness, high_freq_ratio + def detect_noisy_tracks(file_path, group, flatness_threshold=0.9985, high_freq_threshold=0.275): - with h5py.File(file_path, 'r') as file: + with h5py.File(file_path, "r") as file: noisy_tracks = [] for track in tqdm(file[group].keys(), desc=f"Analyzing {group}"): - stft = file[group][track]['stft'][:] + stft = file[group][track]["stft"][:] flatness, high_freq_ratio = compute_noise_features(stft) if flatness > flatness_threshold and high_freq_ratio > high_freq_threshold: noisy_tracks.append(track) return noisy_tracks + def remove_noisy_tracks(file_path, flatness_threshold=0.9985, high_freq_threshold=0.275): - with h5py.File(file_path, 'a') as h5_file: + with h5py.File(file_path, "a") as h5_file: # Identify noisy tracks in all groups noisy_tracks = {} - for group in ['train', 'test', 'val']: + for group in ["train", "test", "val"]: if group in h5_file: noisy_tracks[group] = detect_noisy_tracks(file_path, group, flatness_threshold, high_freq_threshold) print(f"Detected {len(noisy_tracks[group])} noisy tracks in {group} group") # Delete noisy tracks - for group in ['train', 'test', 'val']: + for group in ["train", "test", "val"]: if group in h5_file: group_noisy = set(noisy_tracks[group]) group_ref = h5_file[group] - + for track in tqdm(group_noisy, desc=f"Removing noisy tracks from {group}"): del group_ref[track] print(f"Cleaned dataset saved to {file_path}") - + # Verify the cleaning process def count_tracks(file_path): - with h5py.File(file_path, 'r') as f: - return {group: len(f[group]) for group in ['train', 'test', 'val'] if group in f} + with h5py.File(file_path, "r") as f: + return {group: len(f[group]) for group in ["train", "test", "val"] if group in f} -if __name__ == '__main__': +if __name__ == "__main__": # Example usage - file_path = 'stft_db_b_phase.hdf5' - print("\Original dataset:") + file_path = "stft_db_b_phase.hdf5" + print(r"\Original dataset:") print(count_tracks(file_path)) remove_noisy_tracks(file_path) diff --git a/phasefinder/utils/get_device.py b/phasefinder/utils/get_device.py index b5d2481..8e738f7 100644 --- a/phasefinder/utils/get_device.py +++ b/phasefinder/utils/get_device.py @@ -1,9 +1,10 @@ import torch -def get_device(): + +def get_device() -> str: if torch.cuda.is_available(): - return 'cuda' + return "cuda" elif torch.backends.mps.is_available(): - return 'mps' + return "mps" else: - return 'cpu' \ No newline at end of file + return "cpu" diff --git a/phasefinder/utils/get_weights.py b/phasefinder/utils/get_weights.py index be39b64..c128ce8 100644 --- a/phasefinder/utils/get_weights.py +++ b/phasefinder/utils/get_weights.py @@ -1,10 +1,11 @@ - import os + import requests -model_url = 'https://github.com/bleugreen/phasefinder/raw/main/' +model_url = "https://github.com/bleugreen/phasefinder/raw/main/" + -def get_weights(filename="phasefinder-0.1-noattn.pt", quiet=False): +def get_weights(filename: str = "phasefinder-0.1-noattn.pt", quiet: bool = False) -> str: # Construct the path to save the model weights home_dir = os.path.expanduser("~") model_dir = os.path.join(home_dir, ".local", "share", "phasefinder") @@ -20,9 +21,9 @@ def get_weights(filename="phasefinder-0.1-noattn.pt", quiet=False): print("Downloading model weights...") # Download the model weights try: - r = requests.get(model_url+filename, allow_redirects=True) + r = requests.get(model_url + filename, allow_redirects=True) if r.status_code == 200: - with open(model_path, 'wb') as f: + with open(model_path, "wb") as f: f.write(r.content) print("Model weights downloaded successfully.") else: @@ -33,4 +34,4 @@ def get_weights(filename="phasefinder-0.1-noattn.pt", quiet=False): if not quiet: print("Model weights already exist.") - return model_path \ No newline at end of file + return model_path diff --git a/phasefinder/utils/one_hots.py b/phasefinder/utils/one_hots.py index 643b309..a20a35b 100644 --- a/phasefinder/utils/one_hots.py +++ b/phasefinder/utils/one_hots.py @@ -1,5 +1,6 @@ -import torch import numpy as np +import torch + def calculate_beat_phase(num_frames, beat_times, sr, hop, K=360): frame_rate = sr / hop # Frames per second @@ -8,33 +9,35 @@ def calculate_beat_phase(num_frames, beat_times, sr, hop, K=360): # Add num_frames as a virtual beat end point to include the last segment beat_indices = torch.tensor(sorted(beat_indices), dtype=torch.long) - for i in range(len(beat_indices) -1): + for i in range(len(beat_indices) - 1): start_idx = beat_indices[i] end_idx = beat_indices[i + 1] num_intervals = end_idx - start_idx # Fill the phase up to just before the next beat index - if num_intervals > 0: + if num_intervals > 0: phase_step = torch.linspace(0, K, num_intervals + 1)[:-1] end = min(num_frames, end_idx) - phase_end = end-start_idx + phase_end = end - start_idx if end > start_idx: phase[start_idx:end] = phase_step[:phase_end] return phase.int() + def triangular_label(width): - peak = (width + 1) / 2 - tri_list= [min(i, width - i + 1) / peak for i in range(1, width + 1)] + peak = (width + 1) / 2 + tri_list = [min(i, width - i + 1) / peak for i in range(1, width + 1)] return torch.tensor(tri_list) / np.sum(tri_list) + def generate_blurred_one_hots_wrapped(indices, K=360, width=5): blur_vector = triangular_label(width) num_frames = indices.shape[-1] one_hots = torch.zeros((num_frames, K)) - for offset, weight in enumerate(blur_vector, -len(blur_vector)//2): + for offset, weight in enumerate(blur_vector, -len(blur_vector) // 2): wrapped_indices = (indices + offset) % K one_hots[torch.arange(num_frames), wrapped_indices] += weight diff --git a/phasefinder/val.py b/phasefinder/val.py index de40e56..a47fb2d 100644 --- a/phasefinder/val.py +++ b/phasefinder/val.py @@ -1,72 +1,84 @@ +import os +from typing import Tuple + +import librosa +import mir_eval import torch from torch.utils.data import DataLoader from tqdm import tqdm -import librosa -import os -import mir_eval -import numpy as np from phasefinder.dataset import BeatDataset -from phasefinder.postproc.hmm import hmm_beat_estimation -from phasefinder.postproc.cleaner import clean_beats +from phasefinder.postproc import extract_beat_times + + +def _report_scores( + f_measures: list[float], + cmlt_scores: list[float], + amlt_scores: list[float], +) -> Tuple[float, float, float]: + """Average and print evaluation scores.""" + overall_f_measure = sum(f_measures) / len(f_measures) + overall_cmlt = sum(cmlt_scores) / len(cmlt_scores) + overall_amlt = sum(amlt_scores) / len(amlt_scores) + print(f"Overall F-measure: {overall_f_measure:.3f}") + print(f"Overall CMLt: {overall_cmlt:.3f}") + print(f"Overall AMLt: {overall_amlt:.3f}") + return overall_f_measure, overall_cmlt, overall_amlt + -def test_model_f_measure(model, data_path, device='cuda'): - dataset = BeatDataset(data_path, 'test', mode='beat', items=['stft', 'time', 'bpm'], device=device) +def test_model_f_measure( + model: torch.nn.Module, + data_path: str, + device: str = "cuda", +) -> Tuple[float, float, float]: + dataset = BeatDataset(data_path, "test", mode="beat", items=["stft", "time", "bpm"], device=device) dataloader = DataLoader(dataset, batch_size=1, shuffle=False) model.eval() - cmlt_scores = [] - amlt_scores = [] - f_measures = [] + cmlt_scores: list[float] = [] + amlt_scores: list[float] = [] + f_measures: list[float] = [] with torch.no_grad(): for stft, beat_times, bpm in tqdm(dataloader): stft, beat_times, bpm = stft.to(device), beat_times[0].to(device), bpm.to(device) phase_preds = model(stft) - - frame_rate = 22050. / 512 - res = hmm_beat_estimation(phase_preds[0][0].to(device), bpm.item(), bpm_confidence=0.9, distance_threshold_factor=0.2, frame_rate=frame_rate, device=device) - bt = torch.tensor(res) - pred_beat_label_onset = torch.abs(bt[1:] - bt[:-1]) - beat_frames = np.array([i for i, x in enumerate(pred_beat_label_onset) if x > 300]) - - pred_beat_times = beat_frames * 512 / 22050 - cleaned_times = clean_beats(pred_beat_times) - + + cleaned_times = extract_beat_times( + phase_preds[0][0].to(device), + bpm.item(), + bpm_confidence=0.9, + distance_threshold_factor=0.2, + clean=True, + device=device, + ) + np_actual = beat_times.cpu().numpy() - np_pred = cleaned_times - - cmlc, cmlt, amlc, amlt = mir_eval.beat.continuity(np_actual, np_pred) - f_corr = mir_eval.beat.f_measure(np_actual, np_pred) + + cmlc, cmlt, amlc, amlt = mir_eval.beat.continuity(np_actual, cleaned_times) + f_corr = mir_eval.beat.f_measure(np_actual, cleaned_times) f_measures.append(f_corr) cmlt_scores.append(cmlt) amlt_scores.append(amlt) - overall_f_measure = sum(f_measures) / len(f_measures) - print(f"Overall F-measure: {overall_f_measure:.3f}") - overall_cmlt = sum(cmlt_scores) / len(cmlt_scores) - overall_amlt = sum(amlt_scores) / len(amlt_scores) - print(f"Overall CMLt: {overall_cmlt:.3f}") - print(f"Overall AMLt: {overall_amlt:.3f}") + return _report_scores(f_measures, cmlt_scores, amlt_scores) - return overall_f_measure, overall_cmlt, overall_amlt -def test_librosa_f_measure(data_path): - dataset = BeatDataset(data_path, 'test', mode='beat', items=['filepath', 'time', 'bpm'], device='cpu') +def test_librosa_f_measure(data_path: str) -> Tuple[float, float, float]: + dataset = BeatDataset(data_path, "test", mode="beat", items=["filepath", "time", "bpm"], device="cpu") dataloader = DataLoader(dataset, batch_size=1, shuffle=False) - f_measures = [] - cmlt_scores = [] - amlt_scores = [] + f_measures: list[float] = [] + cmlt_scores: list[float] = [] + amlt_scores: list[float] = [] with torch.no_grad(): - for filepath, beat_times, bpm in tqdm(dataloader): + for filepath, beat_times, bpm in tqdm(dataloader): assert os.path.exists(filepath[0]) - + audio, _ = librosa.load(filepath[0], sr=22050) tempo, beat_frames = librosa.beat.beat_track(y=audio, sr=22050, bpm=bpm.item()) np_actual = beat_times.cpu().numpy()[0] - - # Convert beat frames to timestamps + pred_beat_times = librosa.frames_to_time(beat_frames, sr=22050) cmlc, cmlt, amlc, amlt = mir_eval.beat.continuity(np_actual, pred_beat_times) @@ -75,53 +87,39 @@ def test_librosa_f_measure(data_path): cmlt_scores.append(cmlt) amlt_scores.append(amlt) - overall_f_measure = sum(f_measures) / len(f_measures) - print(f"Overall F-measure: {overall_f_measure:.3f}") - overall_cmlt = sum(cmlt_scores) / len(cmlt_scores) - overall_amlt = sum(amlt_scores) / len(amlt_scores) - print(f"Overall CMLt: {overall_cmlt:.3f}") - print(f"Overall AMLt: {overall_amlt:.3f}") + return _report_scores(f_measures, cmlt_scores, amlt_scores) - return overall_f_measure, overall_cmlt, overall_amlt - -def test_postprocessing_f_measure(data_path, device='cuda'): - dataset = BeatDataset(data_path, 'test', mode='beat', items=['stft', 'phase', 'time', 'bpm'], device=device) +def test_postprocessing_f_measure( + data_path: str, + device: str = "cuda", +) -> Tuple[float, float, float]: + dataset = BeatDataset(data_path, "test", mode="beat", items=["stft", "phase", "time", "bpm"], device=device) dataloader = DataLoader(dataset, batch_size=1, shuffle=False) - f_measures = [] - cmlt_scores = [] - amlt_scores = [] + f_measures: list[float] = [] + cmlt_scores: list[float] = [] + amlt_scores: list[float] = [] with torch.no_grad(): for stft, phase, beat_times, bpm in tqdm(dataloader): stft, phase, beat_times, bpm = stft.to(device), phase.to(device), beat_times[0].to(device), bpm.to(device) - - frame_rate = 22050. / 512 - frame_rate = 22050. / 512 - res = hmm_beat_estimation(phase[0][0].to(device), bpm.item(), bpm_confidence=0.9, frame_rate=frame_rate, device=device) - bt = torch.tensor(res) - pred_beat_label_onset = torch.abs(bt[1:] - bt[:-1]) - beat_frames = np.array([i for i, x in enumerate(pred_beat_label_onset) if x > 300]) - - pred_beat_times = beat_frames * 512 / 22050 - cleaned_times = clean_beats(pred_beat_times) - + + cleaned_times = extract_beat_times( + phase[0][0].to(device), + bpm.item(), + bpm_confidence=0.9, + clean=True, + device=device, + ) + np_actual = beat_times.cpu().numpy() - np_pred = cleaned_times - - cmlc, cmlt, amlc, amlt = mir_eval.beat.continuity(np_actual, np_pred) - f_corr = mir_eval.beat.f_measure(np_actual, np_pred) + + cmlc, cmlt, amlc, amlt = mir_eval.beat.continuity(np_actual, cleaned_times) + f_corr = mir_eval.beat.f_measure(np_actual, cleaned_times) f_measures.append(f_corr) cmlt_scores.append(cmlt) amlt_scores.append(amlt) - overall_f_measure = sum(f_measures) / len(f_measures) - print(f"Overall F-measure: {overall_f_measure:.3f}") - overall_cmlt = sum(cmlt_scores) / len(cmlt_scores) - overall_amlt = sum(amlt_scores) / len(amlt_scores) - print(f"Overall CMLt: {overall_cmlt:.3f}") - print(f"Overall AMLt: {overall_amlt:.3f}") - - return overall_f_measure, overall_cmlt, overall_amlt \ No newline at end of file + return _report_scores(f_measures, cmlt_scores, amlt_scores) diff --git a/pyproject.toml b/pyproject.toml index 0c3d4a4..688f71f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ dependencies = [ "numpy<2.0.0", "torch", - "pytorch-tcn", + "pytorch-tcn", "h5py", "librosa", "nnAudio", @@ -27,10 +27,23 @@ dependencies = [ "soundfile", "deeprhythm", "tqdm", - "librosa", - "mir_eval" + "mir_eval", ] +[project.optional-dependencies] +dev = ["pytest", "ruff"] + [project.urls] Homepage = "https://github.com/bleugreen/phasefinder" -Issues = "https://github.com/bleugreen/phasefinder/issues" \ No newline at end of file +Issues = "https://github.com/bleugreen/phasefinder/issues" + +[tool.pytest.ini_options] +testpaths = ["tests"] + +[tool.ruff] +line-length = 120 +target-version = "py38" + +[tool.ruff.lint] +select = ["E", "F", "W", "I"] +ignore = ["E501"] diff --git a/requirements.txt b/requirements.txt index be61022..0742fd7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,4 @@ torchaudio soundfile deeprhythm tqdm -librosa mir_eval diff --git a/test_fix.py b/test_fix.py index 88d49d7..066e5d0 100644 --- a/test_fix.py +++ b/test_fix.py @@ -6,12 +6,12 @@ try: from phasefinder import Phasefinder - + print("Testing Phasefinder initialization...") pf = Phasefinder(quiet=True) print("✓ Success! Phasefinder initialized without RuntimeError") print(f"✓ Model loaded on device: {pf.device}") - + except RuntimeError as e: if "Missing key(s) in state_dict" in str(e): print("✗ FAILED: RuntimeError still occurring") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_weight_loading.py b/tests/test_weight_loading.py new file mode 100644 index 0000000..29fdbd5 --- /dev/null +++ b/tests/test_weight_loading.py @@ -0,0 +1,24 @@ +"""Test that Phasefinder model weights load without errors.""" + +import pytest + + +def test_phasefinder_init(): + """Verify Phasefinder initializes and loads weights without RuntimeError.""" + from phasefinder import Phasefinder + + try: + pf = Phasefinder(quiet=True) + except FileNotFoundError: + pytest.skip("Model weights not available") + assert pf.device is not None + assert pf.model is not None + + +def test_constants_importable(): + """Verify constants module is importable and has expected values.""" + from phasefinder.constants import FRAME_RATE, HOP, SAMPLE_RATE + + assert SAMPLE_RATE == 22050 + assert HOP == 512 + assert abs(FRAME_RATE - 43.066) < 0.01