From 6ac0187fe2c3bc0fb00fa2677e780bc4ae56c86f Mon Sep 17 00:00:00 2001 From: bleugreen Date: Sat, 21 Feb 2026 22:50:23 -0500 Subject: [PATCH 1/8] Remove internal code: training, dataset, benchmarks, root-level scripts --- batch_librosa.py | 133 ----------------- benchmark-tf.py | 65 --------- benchmark.py | 111 -------------- datacreate.py | 184 ------------------------ requirements.txt | 7 - src/deeprhythm/dataset/clip_dataset.py | 43 ------ src/deeprhythm/dataset/intake.py | 126 ---------------- src/deeprhythm/dataset/song_dataset.py | 43 ------ src/deeprhythm/dataset/utils.py | 35 ----- src/deeprhythm/model/infer.py | 61 -------- src/deeprhythm/model/train.py | 192 ------------------------- training.ipynb | 59 -------- 12 files changed, 1059 deletions(-) delete mode 100644 batch_librosa.py delete mode 100644 benchmark-tf.py delete mode 100644 benchmark.py delete mode 100644 datacreate.py delete mode 100644 requirements.txt delete mode 100644 src/deeprhythm/dataset/clip_dataset.py delete mode 100644 src/deeprhythm/dataset/intake.py delete mode 100644 src/deeprhythm/dataset/song_dataset.py delete mode 100644 src/deeprhythm/dataset/utils.py delete mode 100644 src/deeprhythm/model/infer.py delete mode 100644 src/deeprhythm/model/train.py delete mode 100644 training.ipynb diff --git a/batch_librosa.py b/batch_librosa.py deleted file mode 100644 index a319f47..0000000 --- a/batch_librosa.py +++ /dev/null @@ -1,133 +0,0 @@ - -import sys -sys.path.append('../deeprhythm/src') -import json - -import os -import torch.multiprocessing as multiprocessing -import torch -import time -from deeprhythm.utils import load_and_split_audio -from deeprhythm.audio_proc.hcqm import make_kernels, compute_hcqm -from deeprhythm.model.infer import load_cnn_model -from deeprhythm.utils import class_to_bpm -from deeprhythm.utils import get_device -import librosa - -NUM_WORKERS = 8 -NUM_BATCH = 256 - - -def producer(task_queue, result_queue, completion_event, queue_condition, queue_threshold=NUM_BATCH*2): - """ - Producer function that waits on a shared condition if the result_queue is above a certain threshold - immediately after getting a task and before loading and processing the audio. - """ - while True: - task = task_queue.get() - if task is None: - result_queue.put(None) # Send termination signal to indicate this producer is done - completion_event.wait() # Wait for the signal to exit - break - filename = task - with queue_condition: # Use the condition to wait if the queue is too full before loading audio - while result_queue.qsize() >= queue_threshold: - queue_condition.wait() - # After ensuring the queue is not full, proceed to load and process audio - y, sr = librosa.load(filename, sr=22050) - bpm = librosa.beat.tempo(y=y, sr=sr) - if bpm: - result_queue.put((bpm, filename)) - - -def init_workers(dataset, data_path, group, n_workers=NUM_WORKERS): - """ - Initializes worker processes for multiprocessing, setting up the required queues, - an event for coordinated exit, and a condition for queue threshold management. - - Parameters: - - n_workers: Number of worker processes to start. - - dataset: The dataset items to process. - - queue_threshold: The threshold for the result queue before producers wait. - """ - manager = multiprocessing.Manager() - task_queue = multiprocessing.Queue() - result_queue = manager.Queue() # Managed Queue for sharing across processes - completion_event = manager.Event() - queue_condition = manager.Condition() - - # Create producer processes - producers = [ - multiprocessing.Process( - target=producer, - args=(task_queue, result_queue, completion_event, queue_condition) - ) for _ in range(n_workers) - ] - - # Start all producers - for p in producers: - p.start() - - for item in dataset: - task_queue.put(item) - - # Signal each producer to terminate once all tasks are processed - for _ in range(n_workers): - task_queue.put(None) - - return task_queue, result_queue, producers, completion_event, queue_condition - - - -def consume_and_process(result_queue, data_path, queue_condition, n_workers=NUM_WORKERS, max_len_batch=NUM_BATCH, device='cuda'): - - active_producers = n_workers - print(f'producers = {active_producers}') - while active_producers > 0: - result = result_queue.get() - with queue_condition: - queue_condition.notify_all() - if result is None: - active_producers -= 1 - print(f'producers = {active_producers}') - continue - bpm, filename = result - print(f'filename: {filename}, bpm: {bpm}') - -def main(dataset, n_workers=NUM_WORKERS, max_len_batch=NUM_BATCH, data_path='output.hdf5', device='cuda'): - task_queue, result_queue, producers, completion_event, queue_condition = init_workers(dataset,data_path, 'group', n_workers) - try: - consume_and_process(result_queue, data_path, queue_condition, n_workers=n_workers,max_len_batch=max_len_batch, device=device) - finally: - completion_event.set() - for p in producers: - p.join() # Ensure all producer processes have finished - - -def get_audio_files(dir_path): - """ - Collects all audio files recursively from a specified directory. - """ - audio_files = [] - for root, _, files in os.walk(dir_path): - for file in files: - if file.lower().endswith(('.wav', '.mp3', '.flac')): - audio_files.append(os.path.join(root, file)) - return audio_files - -if __name__ == '__main__': - multiprocessing.set_start_method('spawn', force=True) - torch.cuda.empty_cache() - - root_dir = sys.argv[1] - songs = get_audio_files(root_dir) - print(len(songs),'songs found') - data_path = sys.argv[2] if len(sys.argv) > 2 else 'batch_results.jsonl' - - start = time.time() - main(songs, n_workers=NUM_WORKERS, data_path=data_path) - - print(f'Total Duration: {time.time()-start:.2f}') - torch.cuda.empty_cache() - with open(data_path, 'r') as f: - print('Total Length', sum(1 for _ in f)) diff --git a/benchmark-tf.py b/benchmark-tf.py deleted file mode 100644 index d4adfa3..0000000 --- a/benchmark-tf.py +++ /dev/null @@ -1,65 +0,0 @@ -import pandas as pd -import os -from tempocnn.classifier import TempoClassifier -from tempocnn.feature import read_features -import time - -import tensorflow as tf -print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU'))) -def estimate_tempo_cnn(audio_path, model): - features = read_features(audio_path) - bpm = model.estimate_tempo(features, interpolate=False) - print(bpm) - return bpm - - -def is_within_tolerance(predicted_bpm, true_bpm, tolerance=0.02, multiples=[1]): - for multiple in multiples: - if true_bpm * multiple * (1 - tolerance) <= predicted_bpm <= true_bpm * multiple * (1 + tolerance): - return True - return False - - -def run_benchmark(test_set, estimation_methods): - results = {method: {'times': [], 'accuracy1': [], 'accuracy2':[]} for method in estimation_methods} - for method_name, method_func in estimation_methods.items(): - for _, row in test_set.iterrows(): - true_bpm = row['bpm'] - audio_path = os.path.join('/media/bleu/bulkdata2/deeprhythmdata', row['filename']) - start_time = time.time() - predicted_bpm = method_func(audio_path) - elapsed_time = time.time() - start_time - results[method_name]['times'].append(elapsed_time) - correct1 = is_within_tolerance(predicted_bpm, true_bpm) - results[method_name]['accuracy1'].append(correct1) - correct2 = is_within_tolerance(predicted_bpm, true_bpm, multiples=[0.5, 1, 2, 3]) - results[method_name]['accuracy2'].append(correct2) - - return results - - -def generate_report(results): - for method, metrics in results.items(): - accuracy1 = sum(metrics['accuracy1']) / len(metrics['accuracy1']) * 100 - accuracy2 = sum(metrics['accuracy2']) / len(metrics['accuracy2']) * 100 - - avg_time = sum(metrics['times']) / len(metrics['times']) - print('-----'*20) - print(f"{method:<18}: Acc1 = {accuracy1:.2f}%, Acc2 = {accuracy2:.2f}%, Avg Time = {avg_time:.4f}s, Total={sum(metrics['times']):.2f}s") - -if __name__ == '__main__': - test_set = pd.read_csv('/media/bleu/bulkdata2/deeprhythmdata/test.csv') - fcn_model = TempoClassifier('fcn') - cnn_model = TempoClassifier('cnn') - - # Define the estimation methods - methods = { - 'TempoCNN (cnn)': lambda audio_path: estimate_tempo_cnn(audio_path, cnn_model), - 'TempoCNN (fcn)': lambda audio_path: estimate_tempo_cnn(audio_path, fcn_model), - } - - # Run the benchmark - results = run_benchmark(test_set, methods) - - # Generate the report - generate_report(results) diff --git a/benchmark.py b/benchmark.py deleted file mode 100644 index 2d31195..0000000 --- a/benchmark.py +++ /dev/null @@ -1,111 +0,0 @@ -import pandas as pd -import os -import librosa -import essentia.standard as es -import time -import sys -sys.path.append('/home/bleu/ai/deeprhythm/src') - -from deeprhythm.model.infer import predict_global_bpm, make_kernels, load_cnn_model, predict_global_bpm_cont - - - -def estimate_tempo_essentia_multi(audio_path): - audio = es.MonoLoader(filename=audio_path)() - extractor_multi = es.RhythmExtractor2013(method="multifeature") - bpm, beats, beats_confidence, _, beats_intervals = extractor_multi(audio) - print(bpm) - return bpm - -def estimate_tempo_essentia_percival(audio_path): - audio = es.MonoLoader(filename=audio_path)() - bpm = es.PercivalBpmEstimator()(audio) - print(bpm) - return bpm - -def estimate_tempo_essentia_degara(audio_path): - audio = es.MonoLoader(filename=audio_path)() - extractor_deg = es.RhythmExtractor2013(method="degara") - bpm, beats, beats_confidence, _, beats_intervals = extractor_deg(audio) - print(bpm) - return bpm - -def estimate_tempo_librosa(audio_path): - audio, _ = librosa.load(audio_path, sr=22050) - bpm = librosa.beat.tempo(y=audio, sr=22050)[0] - print(bpm) - return bpm - -def estimate_tempo_cnn(audio_path, model, specs): - bpm= predict_global_bpm(audio_path, model=model, specs=specs)[0] - print(bpm) - return bpm - -def estimate_tempo_cnn_cont(audio_path, model, specs): - bpm= predict_global_bpm_cont(audio_path, model=model, specs=specs)[0] - print(bpm) - return bpm - -def is_within_tolerance(predicted_bpm, true_bpm, tolerance=0.02, multiples=[1]): - for multiple in multiples: - if true_bpm * multiple * (1 - tolerance) <= predicted_bpm <= true_bpm * multiple * (1 + tolerance): - return True - return False - - -def run_benchmark(test_set, estimation_methods): - results = {method: {'times': [], 'accuracy1': [], 'accuracy2':[]} for method in estimation_methods} - for method_name, method_func in estimation_methods.items(): - for _, row in test_set.iterrows(): - if row['source'] == 'fma': - continue - true_bpm = row['bpm'] - audio_path = os.path.join('/media/bleu/bulkdata2/deeprhythmdata', row['filename']) - start_time = time.time() - predicted_bpm = method_func(audio_path) - elapsed_time = time.time() - start_time - results[method_name]['times'].append(elapsed_time) - correct1 = is_within_tolerance(predicted_bpm, true_bpm) - results[method_name]['accuracy1'].append(correct1) - correct2 = is_within_tolerance(predicted_bpm, true_bpm, multiples=[0.5, 1, 2, 3]) - results[method_name]['accuracy2'].append(correct2) - - return results - - -def generate_report(results): - print('Test Songs:', len(results['DeepRhythm (cpu)']['times'])) - for method, metrics in results.items(): - accuracy1 = sum(metrics['accuracy1']) / len(metrics['accuracy1']) * 100 - accuracy2 = sum(metrics['accuracy2']) / len(metrics['accuracy2']) * 100 - - avg_time = sum(metrics['times']) / len(metrics['times']) - print('-----'*20) - print(f"{method:<18}: Acc1 = {accuracy1:.2f}%, Acc2 = {accuracy2:.2f}%, Avg Time = {avg_time:.4f}s, Total={sum(metrics['times']):.2f}s") - -if __name__ == '__main__': - test_set = pd.read_csv('/media/bleu/bulkdata2/deeprhythmdata/test.csv') - - cpu_model = load_cnn_model(device='cpu') - cpu_specs = make_kernels(device='cpu') - - cuda_model = load_cnn_model(device='cuda') - cuda_specs = make_kernels(device='cuda') - - # Define the estimation methods - methods = { - 'Essentia (multi)': lambda audio_path: estimate_tempo_essentia_multi(audio_path), - 'Essentia (percival)':estimate_tempo_essentia_percival, - 'Essentia (degara)': lambda audio_path: estimate_tempo_essentia_degara(audio_path), - 'Librosa': estimate_tempo_librosa, - 'DeepRhythm (cuda)': lambda audio_path: estimate_tempo_cnn(audio_path, cuda_model, cuda_specs), - 'DeepRhythm (cpu)': lambda audio_path: estimate_tempo_cnn(audio_path, cpu_model, cpu_specs), - - - } - - # Run the benchmark - results = run_benchmark(test_set, methods) - - # Generate the report - generate_report(results) diff --git a/datacreate.py b/datacreate.py deleted file mode 100644 index 3b33c9d..0000000 --- a/datacreate.py +++ /dev/null @@ -1,184 +0,0 @@ - -import sys -sys.path.append('/home/bleu/ai/deeprhythm/src') - -import h5py -import os -import torch.multiprocessing as multiprocessing -from deeprhythm.audio_proc.hcqm import make_kernels, compute_hcqm -import torch -import time -from deeprhythm.utils import load_and_split_audio -import csv - -NUM_WORKERS = 16 -NUM_BATCH = 1024 - - -def producer(task_queue, result_queue, completion_event, queue_condition, queue_threshold=NUM_BATCH): - """ - Producer function that waits on a shared condition if the result_queue is above a certain threshold - immediately after getting a task and before loading and processing the audio. - """ - while True: - task = task_queue.get() - if task is None: - result_queue.put(None) # Send termination signal to indicate this producer is done - completion_event.wait() # Wait for the signal to exit - break - id, filename, genre, source,num_clips, bpm = task - with queue_condition: # Use the condition to wait if the queue is too full before loading audio - while result_queue.qsize() >= queue_threshold: - queue_condition.wait() - root_dir = '/media/bleu/bulkdata2/deeprhythmdata' - full_path = os.path.join(root_dir, filename) - # After ensuring the queue is not full, proceed to load and process audio - clips = load_and_split_audio(full_path, share_mem=True) - if clips is not None: - result_queue.put((clips, filename, bpm, genre, source)) - -def init_workers(dataset, data_path, group, n_workers=NUM_WORKERS): - """ - Initializes worker processes for multiprocessing, setting up the required queues, - an event for coordinated exit, and a condition for queue threshold management. - - Parameters: - - n_workers: Number of worker processes to start. - - dataset: The dataset items to process. - - queue_threshold: The threshold for the result queue before producers wait. - """ - manager = multiprocessing.Manager() - task_queue = multiprocessing.Queue() - result_queue = manager.Queue() # Managed Queue for sharing across processes - completion_event = manager.Event() - queue_condition = manager.Condition() - - # Create producer processes - producers = [ - multiprocessing.Process( - target=producer, - args=(task_queue, result_queue, completion_event, queue_condition) - ) for _ in range(n_workers) - ] - - # Start all producers - for p in producers: - p.start() - with h5py.File(data_path, 'r') as h5f: - for item in dataset: - id, filename, genre, bpm, source, _ = item - if f'{group}/{os.path.basename(filename)}' not in h5f: - task_queue.put(item) - - # Signal each producer to terminate once all tasks are processed - for _ in range(n_workers): - task_queue.put(None) - - return task_queue, result_queue, producers, completion_event, queue_condition - -def process_and_save(batch_audio, batch_meta, specs, h5f_path, group): - """ - Processes a batch of audio clips and saves the result along with metadata to an HDF5 file. - """ - # print('batch tensor shape', batch_audio.shape) - stft, band, cqt = specs - hcqm = compute_hcqm(batch_audio, stft, band, cqt) - torch.cuda.empty_cache() - print('hcqm done', hcqm.shape) - for meta in batch_meta: - filename, bpm, genre, source, num_clips, start_idx = meta - song_clips = hcqm[start_idx:start_idx+num_clips, :, :, :] - with h5py.File(h5f_path, 'a') as h5f: - if f'{group}/{os.path.basename(filename)}' in h5f: - return - clip_group = h5f.create_group(f'{group}/{os.path.basename(filename)}') - clip_group.create_dataset('hcqm', data=song_clips.cpu().numpy()) - clip_group.attrs['bpm'] = float(bpm) - clip_group.attrs['genre'] = genre - clip_group.attrs['filepath'] = filename - clip_group.attrs['source'] = source - -def consume_and_process(result_queue, data_path, queue_condition, n_workers=NUM_WORKERS, max_len_batch=NUM_BATCH, group='data'): - batch_audio = [] - batch_meta = [] - active_producers = n_workers - sr = 22050 - len_audio = sr * 8 - specs = make_kernels(len_audio, sr) - total_clips = 0 - print(f'producers = {active_producers}') - while active_producers > 0: - result = result_queue.get() - with queue_condition: - queue_condition.notify_all() - if result is None: - active_producers -= 1 - print(f'producers = {active_producers}') - continue - clips, filename, bpm, genre, source = result - with h5py.File(data_path, 'r') as h5f: - if f'{group}/{os.path.basename(filename)}' not in h5f: - batch_audio.append(clips) - num_clips = clips.shape[0] - start_idx = total_clips - batch_meta.append((filename, bpm, genre, source, num_clips, start_idx)) - total_clips += num_clips - if total_clips >= max_len_batch: - stacked_batch_audio = torch.cat(batch_audio, dim=0).cuda() - process_and_save(stacked_batch_audio, batch_meta, specs, data_path, group) - total_clips = 0 - batch_audio = [] - batch_meta = [] - - # Make sure to process any remaining clips - if batch_audio: - stacked_batch_audio = torch.cat(batch_audio, dim=0).cuda() - process_and_save(stacked_batch_audio, batch_meta, specs, data_path, group) - pass - - -def main(dataset, n_workers=NUM_WORKERS, max_len_batch=NUM_BATCH, data_path='output.hdf5', group='data'): - task_queue, result_queue, producers, completion_event, queue_condition = init_workers(dataset,data_path, group, n_workers) - try: - consume_and_process(result_queue, data_path, queue_condition, n_workers=n_workers,max_len_batch=max_len_batch, group=group, ) - finally: - completion_event.set() - for p in producers: - p.join() # Ensure all producer processes have finished - - -def read_csv_to_tuples(csv_file_path): - data_tuples = [] - with open(csv_file_path, newline='') as csvfile: - reader = csv.reader(csvfile) - next(reader) # Skip the header row - for row in reader: - modified_row = row - data_tuples.append(tuple(modified_row)) - return data_tuples - - -if __name__ == '__main__': - multiprocessing.set_start_method('spawn', force=True) - torch.cuda.empty_cache() - train_songs = read_csv_to_tuples('/media/bleu/bulkdata2/deeprhythmdata/train.csv') - test_songs = read_csv_to_tuples('/media/bleu/bulkdata2/deeprhythmdata/test.csv') - val_songs = read_csv_to_tuples('/media/bleu/bulkdata2/deeprhythmdata/val.csv') - # idx, id, bpm, filename, genre, source - # print(test_songs[0]) - data_path = '/media/bleu/bulkdata2/deeprhythmdata/hcqm-split.hdf5' - with h5py.File(data_path, 'w') as hdf5_file: - # Create groups 'train', 'test', and 'val' within the HDF5 file - hdf5_file.create_group('train') - hdf5_file.create_group('test') - hdf5_file.create_group('val') - start = time.time() - main(train_songs, n_workers=16, data_path=data_path, group='train') - main(test_songs, n_workers=16, data_path=data_path, group='test') - main(val_songs, n_workers=16, data_path=data_path, group='val') - - print(f'Total Duration: {time.time()-start:.2f}') - torch.cuda.empty_cache() - hdf5_filename = data_path - with h5py.File(hdf5_filename, 'r') as f: - print('Total Length', sum([len(f.get(key)) for key in f.keys()])) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 48c7b28..0000000 --- a/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -librosa -torch -pandas -numpy -nnAudio -h5py -torchaudio \ No newline at end of file diff --git a/src/deeprhythm/dataset/clip_dataset.py b/src/deeprhythm/dataset/clip_dataset.py deleted file mode 100644 index 47b3f69..0000000 --- a/src/deeprhythm/dataset/clip_dataset.py +++ /dev/null @@ -1,43 +0,0 @@ -import h5py -import torch -from torch.utils.data import Dataset -from deeprhythm.utils import bpm_to_class -class ClipDataset(Dataset): - def __init__(self, hdf5_file, group, use_float=False): - """ - :param hdf5_file: Path to the HDF5 file. - :param group: Group in the HDF5 file to use ('train', 'test', 'validate'). - """ - self.use_float = use_float - self.hdf5_file = hdf5_file - self.group = group - self.index_map = [] - self.file_ref = h5py.File(self.hdf5_file, 'r') - group_data = self.file_ref[group] - for song_key in group_data.keys(): - song_data = group_data[song_key] - if song_data.attrs['source'] == 'fma': - continue - num_clips = song_data['hcqm'].shape[0] - if num_clips > 5: - clip_start = 1 - clip_range = num_clips-2 - else: - clip_start, clip_range = 0, num_clips - - for clip_index in range(clip_start, clip_range): - self.index_map.append((song_key, clip_index)) - - def __len__(self): - return len(self.index_map) - - def __getitem__(self, idx): - song_key, clip_index = self.index_map[idx] - song_data = self.file_ref[self.group][song_key] - hcqm = song_data['hcqm'][clip_index] - bpm = torch.tensor(float(song_data.attrs['bpm']), dtype=torch.float32) - hcqm_tensor = torch.tensor(hcqm, dtype=torch.float).permute(2, 0, 1) - if self.use_float: - return hcqm_tensor, bpm - label_class_index = bpm_to_class(int(bpm)) # Convert BPM to class index - return hcqm_tensor, label_class_index diff --git a/src/deeprhythm/dataset/intake.py b/src/deeprhythm/dataset/intake.py deleted file mode 100644 index fbac63f..0000000 --- a/src/deeprhythm/dataset/intake.py +++ /dev/null @@ -1,126 +0,0 @@ -import zipfile -import os -from data.fma import utils -import pandas as pd - - -def extract_fma_file(filename): - path = '/media/bleu/bulkdata/datasets/fma_full.zip' - extract_path = 'data/' - - with zipfile.ZipFile(path, 'r') as zip_ref: - zip_ref.extract(filename, extract_path) - return extract_path+filename - -def extract_ballroom_song_bpm(file_path): - """ - Extracts (song_id, bpm) tuples from the given metadata file. - - Args: - - file_path (str): The path to the metadata file. - - Returns: - - List[Tuple[str, str]]: A list of tuples containing the song ID and BPM. - """ - song_bpm_pairs = [] - with open(file_path, 'r', encoding='ISO-8859-1') as file: - lines = file.readlines() - - i = 0 - while i < len(lines): - line = lines[i].strip() - if line.startswith('http://www.ballroomdancers.com/Music/'): - # Extract song ID - parts = line.split('/') - if 'Media' in parts: - song_name = 'Media-' + parts[-1].split('.')[0] - else: - song_name = parts[-3] + '-' + parts[-2] + '-' + parts[-1].split('.')[0] - song_name = song_name.replace('.ram', '.wav') - - i += 1 - song_id = 'Media-' +lines[i].split('Song=')[1].split(' ')[0] - while i < len(lines) and ' BPM' not in lines[i]: - i += 1 - if i < len(lines): - bpm = lines[i].strip().split(' ')[0] - song_bpm_pairs.append((song_name, song_id, bpm)) - i += 1 - - return song_bpm_pairs - -def process_ballroom_folder(genre): - file_path = f'data/BallroomData/nada/{genre}.log' - dir_path = f'data/BallroomData/{genre}/' - dircount = len(os.listdir(dir_path)) - extracted_song_bpm = extract_ballroom_song_bpm(file_path) - results = [] - for (track, media_id, bpm) in extracted_song_bpm: - path = dir_path + track + '.wav' - results.append((path, bpm, genre)) - return results, dircount - -def make_ballroom_dataset(): - dirpath = 'data/BallroomData' - all_entries = os.listdir(dirpath) - genres = [entry for entry in all_entries if os.path.isdir(os.path.join(dirpath, entry)) and entry[0].isupper()] - - invalid = [] - all_tracks = [] - track_count = 0 - for genre in genres: - tracks, dircount = process_ballroom_folder(genre) - track_count += dircount - for (path, bpm, genre) in tracks: - if not os.path.isfile(path): - invalid.append((path, bpm, genre)) - else: - all_tracks.append((path, bpm, genre, 'ballroom')) - return all_tracks - -def make_giantsteps_dataset(): - audio_dir = 'data/giantsteps-tempo-dataset/audio' - tempo_dir = 'data/giantsteps-tempo-dataset/annotations_v2/tempo' - genre_dir = 'data/giantsteps-tempo-dataset/annotations/genre' - audio_files = os.listdir(audio_dir) - dataset = [] - for audio_file in audio_files: - if audio_file.endswith('.mp3'): - audio_id = audio_file[:-4] - bpm_path = os.path.join(tempo_dir, f'{audio_id}.bpm') - genre_path = os.path.join(genre_dir, f'{audio_id}.genre') - with open(bpm_path, 'r') as bpm_file: - bpm = float(bpm_file.read().strip()) - with open(genre_path, 'r') as genre_file: - genre = genre_file.read().strip() - dataset.append((os.path.join(audio_dir, audio_file), bpm, genre, 'giantsteps')) - return dataset - -def make_fma_dataset(): - tracks = utils.load('data/fma/data/fma_metadata/tracks.csv')['track'] - echonest = utils.load('data/fma/data/fma_metadata/echonest.csv') - genres = utils.load('data/fma/data/fma_metadata/genres.csv') - AUDIO_DIR = 'data/fma_full/' - - def get_genre(row): - if row['genres_all']: - genre_titles = [genres.loc[genre_id, 'title'] for genre_id in row['genres_all'] if genre_id in genres.index] - genre_full = '' - for g in genre_titles: - genre_full += g+', ' - return genre_full[:-2] - return None - - tracks['genre'] = tracks.apply(get_genre, axis=1) - - tempo_data = echonest['echonest', 'audio_features']['tempo'].astype(int) - genre_data = tracks['genre'].dropna() - - merged_data = pd.DataFrame({'tempo': tempo_data, 'genre': genre_data}).dropna() - - merged_data['filename'] = merged_data.index.map(lambda track_id: utils.get_audio_path(AUDIO_DIR, track_id)) - merged_data['source'] = 'fma' - dataset = merged_data.reset_index(drop=True)[['filename', 'tempo', 'genre', 'source']] - - dataset = list(dataset.itertuples(index=False, name=None)) - return dataset \ No newline at end of file diff --git a/src/deeprhythm/dataset/song_dataset.py b/src/deeprhythm/dataset/song_dataset.py deleted file mode 100644 index a9f37c0..0000000 --- a/src/deeprhythm/dataset/song_dataset.py +++ /dev/null @@ -1,43 +0,0 @@ -import h5py -import torch -from torch.utils.data import Dataset -from deeprhythm.utils import bpm_to_class - -def song_collate(batch): - # Each element in `batch` is a tuple (song_clips, global_bpm) - # Where song_clips is a tensor of shape [num_clips, 240, 8, 6] - inputs = [item[0] for item in batch] - labels = torch.tensor([item[1] for item in batch]) - return inputs, labels - -class SongDataset(Dataset): - def __init__(self, hdf5_path, group): - """ - Args: - hdf5_path (str): Path to the HDF5 file. - group (str): Group in HDF5 file ('train', 'test', 'validate'). - """ - super(SongDataset, self).__init__() - self.hdf5_path = hdf5_path - self.group = group - self.file = h5py.File(hdf5_path, 'r') - self.group_file = self.file[group] - self.keys = [] - for key in self.group_file.keys(): - if self.group_file[key].attrs['source'] == 'fma': - continue - else: - self.keys.append(key) - - def __len__(self): - return len(self.keys) - - def __getitem__(self, idx): - song_key = self.keys[idx] - song_data = self.group_file[song_key] - hcqm = torch.tensor(song_data['hcqm'][:]) - bpm_class = bpm_to_class(int(float(song_data.attrs['bpm']))) - return hcqm, bpm_class - - def close(self): - self.file.close() diff --git a/src/deeprhythm/dataset/utils.py b/src/deeprhythm/dataset/utils.py deleted file mode 100644 index 4de109d..0000000 --- a/src/deeprhythm/dataset/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from torch.utils.data import Subset -from torch.utils.data.dataset import random_split -import json - -def split_dataset(dataset, train_ratio, test_ratio, validate_ratio): - total_ratio = train_ratio + test_ratio + validate_ratio - assert abs(total_ratio - 1) < 1e-6, "Ratios must sum to 1" - - dataset_size = len(dataset) - train_size = int(train_ratio * dataset_size) - test_size = int(test_ratio * dataset_size) - validate_size = dataset_size - train_size - test_size - - train_dataset, test_dataset, validate_dataset = random_split(dataset, [train_size, test_size, validate_size]) - return train_dataset, test_dataset, validate_dataset - -def save_split_indices(train_dataset, test_dataset, validate_dataset, filename="dataset_splits.json"): - splits = { - 'train_indices': train_dataset.indices, - 'test_indices': test_dataset.indices, - 'validate_indices': validate_dataset.indices - } - with open(filename, 'w') as f: - json.dump(splits, f) - - -def load_split_datasets(dataset, filename="dataset_splits.json"): - with open(filename, 'r') as f: - splits = json.load(f) - - train_dataset = Subset(dataset, splits['train_indices']) - test_dataset = Subset(dataset, splits['test_indices']) - validate_dataset = Subset(dataset, splits['validate_indices']) - - return train_dataset, test_dataset, validate_dataset \ No newline at end of file diff --git a/src/deeprhythm/model/infer.py b/src/deeprhythm/model/infer.py deleted file mode 100644 index 51602f7..0000000 --- a/src/deeprhythm/model/infer.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch -import time -import os -from deeprhythm.utils import get_weights -from deeprhythm.utils import load_and_split_audio, split_audio -from deeprhythm.audio_proc.hcqm import make_kernels, compute_hcqm -from deeprhythm.utils import class_to_bpm -from deeprhythm.model.frame_cnn import DeepRhythmModel - -def load_cnn_model(path='deeprhythm-0.7.pth', device=None, quiet=False): - model = DeepRhythmModel(256) - if device is None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if not os.path.exists(path): - path = get_weights(quiet=quiet) - model.load_state_dict(torch.load(path, map_location=torch.device(device))) - model = model.to(device=device) - model.eval() - return model - -def predict_global_bpm(input_path, model_path='deeprhythm-0.7.pth', model=None, specs=None, device='cpu'): - if model is None: - model = load_cnn_model(model_path, device=device) - clips = load_and_split_audio(input_path, sr=22050) - model_device = next(model.parameters()).device - if specs is None: - stft, band, cqt = make_kernels(device=model_device) - else: - stft, band, cqt = specs - input_batch = compute_hcqm(clips.to(device=model_device), stft, band, cqt).permute(0,3,1,2) - model.eval() - start = time.time() - with torch.no_grad(): - input_batch = input_batch.to(device=model_device) - outputs = model(input_batch) - probabilities = torch.softmax(outputs, dim=1) - mean_probabilities = probabilities.mean(dim=0) - _, predicted_class = torch.max(mean_probabilities, 0) - predicted_global_bpm = class_to_bpm(predicted_class.item()) - return predicted_global_bpm, time.time()-start - -def predict_global_bpm_from_audio(audio, sr, model_path='deeprhythm-0.7.pth', model=None, specs=None, device='cpu'): - if model is None: - model = load_cnn_model(model_path, device=device) - clips = split_audio(audio, sr=sr) - model_device = next(model.parameters()).device - if specs is None: - stft, band, cqt = make_kernels(device=model_device) - else: - stft, band, cqt = specs - input_batch = compute_hcqm(clips.to(device=model_device), stft, band, cqt).permute(0,3,1,2) - model.eval() - start = time.time() - with torch.no_grad(): - input_batch = input_batch.to(device=model_device) - outputs = model(input_batch) - probabilities = torch.softmax(outputs, dim=1) - mean_probabilities = probabilities.mean(dim=0) - _, predicted_class = torch.max(mean_probabilities, 0) - predicted_global_bpm = class_to_bpm(predicted_class.item()) - return predicted_global_bpm, time.time()-start diff --git a/src/deeprhythm/model/train.py b/src/deeprhythm/model/train.py deleted file mode 100644 index 15db85b..0000000 --- a/src/deeprhythm/model/train.py +++ /dev/null @@ -1,192 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader -from deeprhythm.dataset.clip_dataset import ClipDataset -from deeprhythm.model.frame_cnn import DeepRhythmModel -from deeprhythm.dataset.song_dataset import SongDataset,song_collate -from torch.optim.lr_scheduler import ReduceLROnPlateau - -def train_cnn(data_path, model_name='deeprhythm', start_weights=None, batch_size=256, early_stopping_patience=5): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = DeepRhythmModel() - if start_weights is not None: - model.load_state_dict(torch.load(start_weights, map_location=torch.device(device))) - data_path = '/media/bleu/bulkdata2/deeprhythmdata/hcqm-split.hdf5' - train_dataset = ClipDataset(data_path, 'train') - test_dataset = ClipDataset(data_path, 'test') - validate_dataset = ClipDataset(data_path, 'val') - train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) - validate_loader = DataLoader(validate_dataset, batch_size=batch_size, shuffle=False) - test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) - model = model.to(device) - criterion = nn.CrossEntropyLoss() - optimizer = optim.Adam(model.parameters(), lr=0.00001, betas=(0.9, 0.999), eps=1e-8) - scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5, verbose=True) - early_stopping_counter = 0 - best_validate_loss = float('inf') - # Training loop - num_epochs = 40 - for epoch in range(num_epochs): - model.train() - running_loss = 0.0 - for inputs, labels in train_loader: - inputs, labels = inputs.to(device), labels.to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = criterion(outputs, labels) - loss.backward() - optimizer.step() - - running_loss += loss.item() - - # Validation loop - model.eval() - validate_loss = 0.0 - with torch.no_grad(): - for inputs, labels in validate_loader: - inputs, labels = inputs.to(device), labels.to(device) - outputs = model(inputs) - loss = criterion(outputs, labels) - validate_loss += loss.item() - - average_train_loss = running_loss / len(train_loader) - average_validate_loss = validate_loss / len(validate_loader) - print(f"Epoch {epoch+1}, Train Loss: {average_train_loss:.4f}, Validate Loss: {average_validate_loss:.4f}") - scheduler.step(average_validate_loss) - - # Check for early stopping - if average_validate_loss < best_validate_loss: - best_validate_loss = average_validate_loss - early_stopping_counter = 0 - # save the best version of the model - model_path = f'{model_name}-best.pth' - torch.save(model.state_dict(), model_path) - else: - early_stopping_counter += 1 - if early_stopping_counter >= early_stopping_patience: - print("Early stopping triggered.") - break - test_loss = 0.0 - correct1 = 0 - correct2 = 0 - total_predictions = 0 - model.eval() - with torch.no_grad(): - for inputs, labels in test_loader: - inputs, labels = inputs.to(device), labels.to(device) - outputs = model(inputs) - loss = criterion(outputs, labels) - test_loss += loss.item() - _, predicted = torch.max(outputs, 1) - total_predictions += labels.size(0) - for prediction, label in zip(predicted, labels): - tolerance = 0.04 * label.item() - if abs(prediction.item() - label.item()) <= tolerance: - correct1 += 1 - for multiple in range(1, int(label.item() / (label.item() * tolerance)) + 1): - if abs(prediction.item() - (label.item() * multiple)) <= tolerance: - correct2 += 1 - break - - average_test_loss = test_loss / len(test_loader) - accuracy1 = correct1 / total_predictions - accuracy2 = correct1 / total_predictions - print(f"Test Loss: {average_test_loss:.4f}, Accuracy1: {accuracy1:.4f}, Accuracy2: {accuracy2:.4f}") - - - -def train_cnn_cont(data_path, model_name='deeprhythm', start_weights=None, batch_size=256, early_stopping_patience=5): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model = DeepRhythmModel() - if start_weights is not None: - model.load_state_dict(torch.load(start_weights, map_location=torch.device(device))) - data_path = '/media/bleu/bulkdata2/deeprhythmdata/hcqm-split.hdf5' - train_dataset = ClipDataset(data_path, 'train', use_float=True) - test_dataset = ClipDataset(data_path, 'test', use_float=True) - validate_dataset = ClipDataset(data_path, 'val', use_float=True) - train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) - validate_loader = DataLoader(validate_dataset, batch_size=batch_size, shuffle=False) - test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) - model = model.to(device) - - criterion = nn.HuberLoss() - optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-8) - scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5, verbose=True) - early_stopping_counter = 0 - best_validate_loss = float('inf') - # Training loop - num_epochs = 40 - for epoch in range(num_epochs): - model.train() - running_loss = 0.0 - for inputs, labels in train_loader: - inputs, labels = inputs.to(device), labels.to(device) - optimizer.zero_grad() - outputs = model(inputs) - # Calculate expected BPM for each output - probabilities = torch.softmax(outputs, dim=1) - expected_bpm = torch.sum(probabilities * torch.arange(256).float().to(device), dim=1) - loss = criterion(expected_bpm, labels) - loss.backward() - optimizer.step() - - running_loss += loss.item() - - # Validation loop - model.eval() - validate_loss = 0.0 - with torch.no_grad(): - for inputs, labels in validate_loader: - inputs, labels = inputs.to(device), labels.to(device) - outputs = model(inputs) - probabilities = torch.softmax(outputs, dim=1) - expected_bpm = torch.sum(probabilities * torch.arange(256).float().to(device), dim=1) - loss = criterion(expected_bpm, labels) - validate_loss += loss.item() - - average_train_loss = running_loss / len(train_loader) - average_validate_loss = validate_loss / len(validate_loader) - print(f"Epoch {epoch+1}, Train Loss: {average_train_loss:.4f}, Validate Loss: {average_validate_loss:.4f}") - scheduler.step(average_validate_loss) - - # Check for early stopping - if average_validate_loss < best_validate_loss: - best_validate_loss = average_validate_loss - early_stopping_counter = 0 - # save the best version of the model - model_path = f'{model_name}-best.pth' - torch.save(model.state_dict(), model_path) - else: - early_stopping_counter += 1 - if early_stopping_counter >= early_stopping_patience: - print("Early stopping triggered.") - break - test_loss = 0.0 - correct1 = 0 - correct2 = 0 - total_predictions = 0 - model.eval() - with torch.no_grad(): - for inputs, labels in test_loader: - inputs, labels = inputs.to(device), labels.to(device) - outputs = model(inputs) - probabilities = torch.softmax(outputs, dim=1) - expected_bpm = torch.sum(probabilities * torch.arange(256).float().to(device), dim=1) - loss = criterion(expected_bpm, labels) - test_loss += loss.item() - _, predicted = torch.max(outputs, 1) - total_predictions += labels.size(0) - for prediction, label in zip(predicted, labels): - tolerance = 0.04 * label.item() - if abs(prediction.item() - label.item()) <= tolerance: - correct1 += 1 - for multiple in range(1, int(label.item() / (label.item() * tolerance)) + 1): - if abs(prediction.item() - (label.item() * multiple)) <= tolerance: - correct2 += 1 - break - - average_test_loss = test_loss / len(test_loader) - accuracy1 = correct1 / total_predictions - accuracy2 = correct1 / total_predictions - print(f"Test Loss: {average_test_loss:.4f}, Accuracy1: {accuracy1:.4f}, Accuracy2: {accuracy2:.4f}") diff --git a/training.ipynb b/training.ipynb deleted file mode 100644 index 6497c1e..0000000 --- a/training.ipynb +++ /dev/null @@ -1,59 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "sys.path.append('/home/bleu/ai/deeprhythm/src')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from deeprhythm.model.train import train_cnn_cont\n", - "\n", - "data_path = '/media/bleu/bulkdata2/deeprhythmdata/hcqm-split.hdf5'\n", - "train_cnn_cont(data_path, 'deeprhythm-cont', start_weights='deeprhythm-cont-best.pth', batch_size=512)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from deeprhythm.model.train import train_cnn\n", - "data_path = '/media/bleu/bulkdata2/deeprhythmdata/hcqm-split.hdf5'\n", - "\n", - "train_cnn(data_path, 'deeprhythm-2.3', start_weights='deeprhythm-2.2-best.pth', batch_size=512)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "autoawq", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 67b19adabb6f5ac1d2413809a58ed90d699f0793 Mon Sep 17 00:00:00 2001 From: bleugreen Date: Sat, 21 Feb 2026 22:50:54 -0500 Subject: [PATCH 2/8] [Build] Add load_cnn_model to predictor.py, fix weights_only=True, break circular import --- src/deeprhythm/model/predictor.py | 71 +++++++++++++++++-------------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/src/deeprhythm/model/predictor.py b/src/deeprhythm/model/predictor.py index 50fd8e6..e54423d 100644 --- a/src/deeprhythm/model/predictor.py +++ b/src/deeprhythm/model/predictor.py @@ -1,32 +1,42 @@ -import torch -from deeprhythm.utils import load_and_split_audio, split_audio, AudioTooShortError, AudioLoadError -from deeprhythm.audio_proc.hcqm import make_kernels, compute_hcqm -from deeprhythm.utils import class_to_bpm -from deeprhythm.model.frame_cnn import DeepRhythmModel -from deeprhythm.utils import get_weights, get_device -from deeprhythm.batch_infer import get_audio_files, main as batch_infer_main import json -import tempfile import os +import tempfile + +import torch + +from deeprhythm.audio_proc.hcqm import compute_hcqm, make_kernels +from deeprhythm.model.frame_cnn import DeepRhythmModel +from deeprhythm.utils import class_to_bpm, get_device, get_weights, load_and_split_audio, split_audio + + +def load_cnn_model(path='deeprhythm-0.7.pth', device=None, quiet=False): + model = DeepRhythmModel() + if device is None: + device = get_device() + if not os.path.exists(path): + path = get_weights(quiet=quiet) + model.load_state_dict(torch.load(path, map_location=torch.device(device), weights_only=True)) + model = model.to(device=device) + model.eval() + return model + class DeepRhythmPredictor: """ DeepRhythm tempo prediction model. Args: - model_path (str, optional): Path to a custom model weights file (.pth). + model_path (str, optional): Path to a custom model weights file (.pth). If None, automatically downloads the default model to ~/.local/share/deeprhythm/. Defaults to None. - device (str, optional): Device to run inference on ('cpu', 'cuda', 'mps'). + device (str, optional): Device to run inference on ('cpu', 'cuda', 'mps'). If None, automatically selects best available device. quiet (bool, optional): Suppress download progress messages. Defaults to False. """ def __init__(self, model_path=None, device=None, quiet=False): - # Handle model path: use provided path or auto-download default if model_path is None: self.model_path = get_weights(quiet=quiet) else: - # User provided custom path - validate it exists if not os.path.isfile(model_path): raise FileNotFoundError( f"Model file not found at: {model_path}\n" @@ -34,7 +44,7 @@ def __init__(self, model_path=None, device=None, quiet=False): f"to auto-download the default model." ) self.model_path = model_path - + if device is None: self.device = get_device() else: @@ -44,7 +54,7 @@ def __init__(self, model_path=None, device=None, quiet=False): def load_model(self): model = DeepRhythmModel() - model.load_state_dict(torch.load(self.model_path, map_location=self.device)) + model.load_state_dict(torch.load(self.model_path, map_location=self.device, weights_only=True)) model = model.to(device=self.device) model.eval() return model @@ -57,7 +67,7 @@ def make_kernels(self, device=None): def predict(self, filename, include_confidence=False): clips = load_and_split_audio(filename, sr=22050) - input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0,3,1,2) + input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0, 3, 1, 2) self.model.eval() with torch.no_grad(): input_batch = input_batch.to(device=self.device) @@ -67,12 +77,12 @@ def predict(self, filename, include_confidence=False): confidence_score, predicted_class = torch.max(mean_probabilities, 0) predicted_global_bpm = class_to_bpm(predicted_class.item()) if include_confidence: - return predicted_global_bpm, confidence_score.item(), + return predicted_global_bpm, confidence_score.item() return predicted_global_bpm - + def predict_from_audio(self, audio, sr, include_confidence=False): clips = split_audio(audio, sr) - input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0,3,1,2) + input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0, 3, 1, 2) self.model.eval() with torch.no_grad(): input_batch = input_batch.to(device=self.device) @@ -82,26 +92,27 @@ def predict_from_audio(self, audio, sr, include_confidence=False): confidence_score, predicted_class = torch.max(mean_probabilities, 0) predicted_global_bpm = class_to_bpm(predicted_class.item()) if include_confidence: - return predicted_global_bpm, confidence_score.item(), + return predicted_global_bpm, confidence_score.item() return predicted_global_bpm def predict_batch(self, dirname, include_confidence=False, workers=8, batch=128, quiet=True): """ Predict BPM for all audio files in a directory using efficient batch processing. - + Args: dirname: Directory containing audio files include_confidence: Whether to include confidence scores in results - + Returns: dict: Mapping of filenames to their predicted BPMs (and optionally confidence scores) """ - # Create a temporary file to store batch results + from deeprhythm.batch_infer import get_audio_files + from deeprhythm.batch_infer import main as batch_infer_main + with tempfile.NamedTemporaryFile(mode='w+', suffix='.jsonl', delete=False) as tmp_file: temp_path = tmp_file.name - + try: - # Run batch inference batch_infer_main( dataset=get_audio_files(dirname), data_path=temp_path, @@ -111,8 +122,7 @@ def predict_batch(self, dirname, include_confidence=False, workers=8, batch=128, n_workers=workers, max_len_batch=batch ) - - # Read and parse results + results = {} with open(temp_path, 'r') as f: for line in f: @@ -122,17 +132,16 @@ def predict_batch(self, dirname, include_confidence=False, workers=8, batch=128, results[filename] = (result['bpm'], result['confidence']) else: results[filename] = result['bpm'] - + return results - + finally: - # Clean up temporary file if os.path.exists(temp_path): os.remove(temp_path) def predict_per_frame(self, filename, include_confidence=False): clips = load_and_split_audio(filename, sr=22050) - input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0,3,1,2) + input_batch = compute_hcqm(clips.to(device=self.device), *self.specs).permute(0, 3, 1, 2) self.model.eval() with torch.no_grad(): input_batch = input_batch.to(device=self.device) @@ -140,7 +149,7 @@ def predict_per_frame(self, filename, include_confidence=False): probabilities = torch.softmax(outputs, dim=1) confidence_scores, predicted_classes = torch.max(probabilities, dim=1) predicted_bpms = [class_to_bpm(cls.item()) for cls in predicted_classes] - + if include_confidence: return predicted_bpms, confidence_scores.tolist() return predicted_bpms From 83678bc5508d4c0ccd58123ec4f6b5330bb0a921 Mon Sep 17 00:00:00 2001 From: bleugreen Date: Sat, 21 Feb 2026 22:51:35 -0500 Subject: [PATCH 3/8] [Build] Fix batch_infer.py: update import, fix indent, remove pass, guard cuda, organize imports --- src/deeprhythm/batch_infer.py | 92 +++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 41 deletions(-) diff --git a/src/deeprhythm/batch_infer.py b/src/deeprhythm/batch_infer.py index 18875d6..83e4a04 100644 --- a/src/deeprhythm/batch_infer.py +++ b/src/deeprhythm/batch_infer.py @@ -1,15 +1,15 @@ +import argparse import json import os -import torch.multiprocessing as multiprocessing -import torch -import warnings import time -import argparse -from deeprhythm.utils import load_and_split_audio, AudioTooShortError, AudioLoadError -from deeprhythm.audio_proc.hcqm import make_kernels, compute_hcqm -from deeprhythm.model.infer import load_cnn_model -from deeprhythm.utils import class_to_bpm -from deeprhythm.utils import get_device +import warnings + +import torch +import torch.multiprocessing as multiprocessing + +from deeprhythm.audio_proc.hcqm import compute_hcqm, make_kernels +from deeprhythm.model.predictor import load_cnn_model +from deeprhythm.utils import AudioLoadError, AudioTooShortError, class_to_bpm, get_device, load_and_split_audio NUM_WORKERS = 8 @@ -23,11 +23,11 @@ def producer(task_queue, result_queue, completion_event, queue_condition, queue_ while True: task = task_queue.get() if task is None: - result_queue.put(None) # Send termination signal to indicate this producer is done - completion_event.wait() # Wait for the signal to exit + result_queue.put(None) + completion_event.wait() break filename = task - with queue_condition: # Use the condition to wait if the queue is too full before loading audio + with queue_condition: while result_queue.qsize() >= queue_threshold: queue_condition.wait() try: @@ -36,19 +36,15 @@ def producer(task_queue, result_queue, completion_event, queue_condition, queue_ except (AudioTooShortError, AudioLoadError) as e: print(f"Skipping {filename}: {e}") + def init_workers(dataset, n_workers=NUM_WORKERS): """ Initializes worker processes for multiprocessing, setting up the required queues, an event for coordinated exit, and a condition for queue threshold management. - - Parameters: - - n_workers: Number of worker processes to start. - - dataset: The dataset items to process. - - queue_threshold: The threshold for the result queue before producers wait. """ manager = multiprocessing.Manager() task_queue = multiprocessing.Queue() - result_queue = manager.Queue() # Managed Queue for sharing across processes + result_queue = manager.Queue() completion_event = manager.Event() queue_condition = manager.Condition() producers = [ @@ -60,15 +56,16 @@ def init_workers(dataset, n_workers=NUM_WORKERS): for p in producers: p.start() for item in dataset: - task_queue.put(item) + task_queue.put(item) for _ in range(n_workers): task_queue.put(None) return task_queue, result_queue, producers, completion_event, queue_condition + def process_and_save(batch_audio, batch_meta, specs, model, out_path, conf=False, quiet=False): """ - Processes a batch of audio clips and saves the result along with metadata to an HDF5 file. + Processes a batch of audio clips and saves the results to a JSONL file. """ stft, band, cqt = specs hcqm = compute_hcqm(batch_audio, stft, band, cqt) @@ -76,11 +73,12 @@ def process_and_save(batch_audio, batch_meta, specs, model, out_path, conf=False if not quiet: print('hcqm done', hcqm.shape) with torch.no_grad(): - hcqm = hcqm.permute(0,3,1,2).to(device=model_device) + hcqm = hcqm.permute(0, 3, 1, 2).to(device=model_device) outputs = model(hcqm) if not quiet: print('model done', outputs.shape) - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() results = [] for meta in batch_meta: filename, num_clips, start_idx = meta @@ -100,7 +98,11 @@ def process_and_save(batch_audio, batch_meta, specs, model, out_path, conf=False for result in results: f.write(json.dumps(result) + "\n") -def consume_and_process(result_queue, data_path, queue_condition, n_workers=NUM_WORKERS, max_len_batch=NUM_BATCH, device='cuda', conf=False, quiet=False): + +def consume_and_process( + result_queue, data_path, queue_condition, n_workers=NUM_WORKERS, + max_len_batch=NUM_BATCH, device='cuda', conf=False, quiet=False +): batch_audio = [] batch_meta = [] active_producers = n_workers @@ -108,7 +110,7 @@ def consume_and_process(result_queue, data_path, queue_condition, n_workers=NUM_ len_audio = sr * 8 if not quiet: print(f'Using device: {device}') - specs = make_kernels(len_audio, sr, device=device) + specs = make_kernels(len_audio, sr, device=device) if not quiet: print('made kernels') model = load_cnn_model(device=device, quiet=quiet) @@ -135,26 +137,31 @@ def consume_and_process(result_queue, data_path, queue_condition, n_workers=NUM_ total_clips += num_clips if total_clips >= max_len_batch: stacked_batch_audio = torch.cat(batch_audio, dim=0).to(device=device) - process_and_save(stacked_batch_audio, batch_meta, specs,model, data_path, conf=conf, quiet=quiet) + process_and_save(stacked_batch_audio, batch_meta, specs, model, data_path, conf=conf, quiet=quiet) total_clips = 0 batch_audio = [] batch_meta = [] - # Make sure to process any remaining clips if batch_audio: stacked_batch_audio = torch.cat(batch_audio, dim=0).to(device=device) - process_and_save(stacked_batch_audio, batch_meta, specs,model, data_path, conf=conf, quiet=quiet) - pass + process_and_save(stacked_batch_audio, batch_meta, specs, model, data_path, conf=conf, quiet=quiet) -def main(dataset, n_workers=NUM_WORKERS, max_len_batch=NUM_BATCH, data_path='output.jsonl', device='cuda', conf=False, quiet=False): +def main( + dataset, n_workers=NUM_WORKERS, max_len_batch=NUM_BATCH, + data_path='output.jsonl', device='cuda', conf=False, quiet=False +): task_queue, result_queue, producers, completion_event, queue_condition = init_workers(dataset, n_workers) try: - consume_and_process(result_queue, data_path, queue_condition, n_workers=n_workers,max_len_batch=max_len_batch, device=device, conf=conf, quiet=quiet) + consume_and_process( + result_queue, data_path, queue_condition, + n_workers=n_workers, max_len_batch=max_len_batch, + device=device, conf=conf, quiet=quiet + ) finally: completion_event.set() for p in producers: - p.join() # Ensure all producer processes have finished + p.join() def get_audio_files(dir_path): @@ -168,32 +175,35 @@ def get_audio_files(dir_path): audio_files.append(os.path.join(root, file)) return audio_files + if __name__ == '__main__': warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) multiprocessing.set_start_method('spawn', force=True) - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() parser = argparse.ArgumentParser() parser.add_argument('input_path', type=str, help='Directory containing audio files') parser.add_argument('-o', '--output_path', type=str, default='batch_results.jsonl', help='Output path for results') - parser.add_argument('-d','--device', type=str, default=get_device(), help='Device to use for inference') - parser.add_argument('-c','--conf', action='store_true', help='Include confidence score in output') - parser.add_argument('-q','--quiet', action='store_true', help='Use minimal output format') + parser.add_argument('-d', '--device', type=str, default=get_device(), help='Device to use for inference') + parser.add_argument('-c', '--conf', action='store_true', help='Include confidence score in output') + parser.add_argument('-q', '--quiet', action='store_true', help='Use minimal output format') args = parser.parse_args() songs = get_audio_files(args.input_path) if not args.quiet: - print(len(songs),'songs found') + print(len(songs), 'songs found') start = time.time() - main(songs, - n_workers=NUM_WORKERS, - data_path=args.output_path, - device=args.device, + main(songs, + n_workers=NUM_WORKERS, + data_path=args.output_path, + device=args.device, conf=args.conf, quiet=args.quiet ) if not args.quiet: print(f'{time.time()-start:.2f}') - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() From e2033bd20591aaea5c5e73a72551b75b412d1c46 Mon Sep 17 00:00:00 2001 From: bleugreen Date: Sat, 21 Feb 2026 22:52:03 -0500 Subject: [PATCH 4/8] [Build] Fix hcqm.py: change default device to None, resolve with get_device() --- src/deeprhythm/audio_proc/hcqm.py | 35 ++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/deeprhythm/audio_proc/hcqm.py b/src/deeprhythm/audio_proc/hcqm.py index 7acdccc..a2c29cb 100644 --- a/src/deeprhythm/audio_proc/hcqm.py +++ b/src/deeprhythm/audio_proc/hcqm.py @@ -1,13 +1,16 @@ import numpy as np import torch import nnAudio.features as feat -from deeprhythm.audio_proc.bandfilter import create_log_filter, apply_log_filter + +from deeprhythm.audio_proc.bandfilter import apply_log_filter, create_log_filter from deeprhythm.audio_proc.onset import onset_strength +from deeprhythm.utils import get_device N_BINS = 240 N_BANDS = 8 -def make_kernels(len_audio=22050*8, sr=22050, device='cuda'): + +def make_kernels(len_audio=22050*8, sr=22050, device=None): """ Create the kernels for the STFT and CQT based on the input parameters. @@ -28,30 +31,34 @@ def make_kernels(len_audio=22050*8, sr=22050, device='cuda'): cqt_specs : list of CQT objects A list of Constant-Q Transform (CQT) objects for different harmonics. """ + if device is None: + device = get_device() n_fft = 2048 hop = 512 n_fft_bins = int(1+n_fft/2) band_filter = create_log_filter(n_fft_bins, N_BANDS, device=device) - stft_spec = feat.stft.STFT(sr=sr, n_fft=n_fft, hop_length=hop, output_format='Magnitude', verbose=False).to(device=device) + stft_spec = feat.stft.STFT( + sr=sr, n_fft=n_fft, hop_length=hop, output_format='Magnitude', verbose=False + ).to(device=device) cqt_specs = [] for h in [1/2, 1, 2, 3, 4, 5]: - # Convert from BPM to Hz fmin = (32.7*h)/60 sr_cqt = len_audio//(hop*8) - fmax =sr_cqt/2 + fmax = sr_cqt/2 num_octaves = np.log2(fmax/fmin) bins_per_octave = N_BINS / num_octaves cqt_spec = feat.cqt.CQT(sr=sr_cqt, - hop_length=len_audio//hop, - n_bins=N_BINS, - bins_per_octave=bins_per_octave, - fmin=fmin, - output_format='Magnitude', - verbose=False, - pad_mode='constant').to(device=device) + hop_length=len_audio//hop, + n_bins=N_BINS, + bins_per_octave=bins_per_octave, + fmin=fmin, + output_format='Magnitude', + verbose=False, + pad_mode='constant').to(device=device) cqt_specs.append(cqt_spec) return stft_spec, band_filter, cqt_specs + def compute_hcqm(y, stft_spec, band_filter, cqt_specs): """ Compute the Harmonic Constant-Q Modulation (HCQM) for an input signal. @@ -63,10 +70,10 @@ def compute_hcqm(y, stft_spec, band_filter, cqt_specs): - y (Tensor): The input signal tensor of shape (batch_size, num_samples). - stft_spec (STFT object): An object to compute the Short-Time Fourier Transform (STFT). - band_filter (Tensor): A filter matrix of shape (num_bands, num_bins) to apply to the STFT. - - cqt_specs (list of CQT objects): A list of Constant-Q Transform (CQT) objects for different harmonics / bands + - cqt_specs (list of CQT objects): A list of CQT objects for different harmonics / bands Returns: - - hcqm (Tensor): The computed HCQM of shape (batch_size, N_BINS, N_BANDS, N_HARMONICS), where 6 corresponds to the number of different harmonics analyzed. + - hcqm (Tensor): The computed HCQM of shape (batch_size, N_BINS, N_BANDS, N_HARMONICS) """ stft = stft_spec(y) stft_bands = apply_log_filter(stft, band_filter) From 33228985a0540b61204cce09fdb1c82c55e0afd2 Mon Sep 17 00:00:00 2001 From: bleugreen Date: Sat, 21 Feb 2026 22:52:23 -0500 Subject: [PATCH 5/8] [Build] Deduplicate utils.py: split_audio as core, load_and_split_audio as wrapper --- src/deeprhythm/utils.py | 107 +++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 63 deletions(-) diff --git a/src/deeprhythm/utils.py b/src/deeprhythm/utils.py index c1d730f..96840cf 100644 --- a/src/deeprhythm/utils.py +++ b/src/deeprhythm/utils.py @@ -1,9 +1,8 @@ -import librosa -import torch -import zipfile - import os + +import librosa import requests +import torch model_url = 'https://github.com/bleugreen/deeprhythm/raw/main/' @@ -26,20 +25,18 @@ def get_device(): else: return 'cpu' + def get_weights(filename="deeprhythm-0.7.pth", quiet=False): - # Construct the path to save the model weights home_dir = os.path.expanduser("~") model_dir = os.path.join(home_dir, ".local", "share", "deeprhythm") if not os.path.exists(model_dir): os.makedirs(model_dir, exist_ok=True) model_path = os.path.join(model_dir, filename) - # Check if the model weights already exist if not os.path.isfile(model_path): print("Downloading model weights...") - # Download the model weights try: - r = requests.get(model_url+filename, allow_redirects=True) + r = requests.get(model_url + filename, allow_redirects=True) if r.status_code == 200: with open(model_path, 'wb') as f: f.write(r.content) @@ -55,91 +52,75 @@ def get_weights(filename="deeprhythm-0.7.pth", quiet=False): return model_path -def load_and_split_audio(filename, sr=22050, clip_length=8, share_mem=False): +def split_audio(audio, sr, clip_length=8, share_mem=False): """ - Load an audio file, split it into 8-second clips, and return a single tensor of all clips. + Split audio into fixed-length clips and return a stacked tensor. Parameters: - - filename: Path to the audio file. - - sr: Sampling rate to use for loading the audio. + - audio: Audio array (e.g. from librosa.load). + - sr: Sampling rate. - clip_length: Length of each clip in seconds. + - share_mem: Whether to put the tensor in shared memory (for multiprocessing). Returns: - A tensor of shape [clips, audio] where each row is an 8-second clip. - """ + A tensor of shape [num_clips, clip_samples]. + Raises: + AudioTooShortError: If audio is too short for even one clip. + """ clips = [] clip_samples = sr * clip_length - try: - audio, _ = librosa.load(filename, sr=sr) - for i in range(0, len(audio), clip_samples): - if i + clip_samples <= len(audio): - clip_tensor = torch.tensor(audio[i:i + clip_samples], dtype=torch.float32) - clips.append(clip_tensor) - if clips: - stacked_clips = torch.stack(clips, dim=0) - else: - raise AudioTooShortError( - f"Audio file must be at least {clip_length} seconds long. " - f"File '{filename}' is too short to extract any {clip_length}-second clips." - ) - - if share_mem: - stacked_clips.share_memory_() - - return stacked_clips - except AudioTooShortError: - raise - except Exception as e: - raise AudioLoadError( - f"Failed to load audio file '{filename}': {str(e)}" - ) from e + for i in range(0, len(audio), clip_samples): + if i + clip_samples <= len(audio): + clip_tensor = torch.tensor(audio[i:i + clip_samples], dtype=torch.float32) + clips.append(clip_tensor) + if not clips: + raise AudioTooShortError( + f"Audio must be at least {clip_length} seconds long to extract clips. " + f"Provided audio has {len(audio)/sr:.2f} seconds." + ) + + stacked_clips = torch.stack(clips, dim=0) + if share_mem: + stacked_clips.share_memory_() + return stacked_clips -def split_audio(audio, sr, clip_length=8, share_mem=False): + +def load_and_split_audio(filename, sr=22050, clip_length=8, share_mem=False): """ - Load an audio file, split it into 8-second clips, and return a single tensor of all clips. + Load an audio file and split it into fixed-length clips. Parameters: - - audio: Array generated by librosa.load representing the audio. - - sr: Sampling rate to used for loading the audio. + - filename: Path to the audio file. + - sr: Sampling rate to use for loading the audio. - clip_length: Length of each clip in seconds. + - share_mem: Whether to put the tensor in shared memory (for multiprocessing). Returns: - A tensor of shape [clips, audio] where each row is an 8-second clip. - """ + A tensor of shape [num_clips, clip_samples]. - clips = [] - clip_samples = sr * clip_length + Raises: + AudioTooShortError: If audio is too short for even one clip. + AudioLoadError: If the audio file cannot be loaded. + """ try: - for i in range(0, len(audio), clip_samples): - if i + clip_samples <= len(audio): - clip_tensor = torch.tensor(audio[i:i + clip_samples], dtype=torch.float32) - clips.append(clip_tensor) - if clips: - stacked_clips = torch.stack(clips, dim=0) - else: - raise AudioTooShortError( - f"Audio must be at least {clip_length} seconds long to extract clips. " - f"Provided audio has {len(audio)/sr:.2f} seconds." - ) - - if share_mem: - stacked_clips.share_memory_() - - return stacked_clips + audio, _ = librosa.load(filename, sr=sr) + return split_audio(audio, sr, clip_length=clip_length, share_mem=share_mem) except AudioTooShortError: raise except Exception as e: raise AudioLoadError( - f"Failed to process audio array: {str(e)}" + f"Failed to load audio file '{filename}': {str(e)}" ) from e + def bpm_to_class(bpm, min_bpm=30, max_bpm=286, num_classes=256): """Map a BPM value to a class index.""" class_width = (max_bpm - min_bpm) / num_classes class_index = int((bpm - min_bpm) // class_width) return max(0, min(num_classes - 1, class_index)) + def class_to_bpm(class_index, min_bpm=30, max_bpm=286, num_classes=256): """Map a class index back to a BPM value (to the center of the class interval).""" class_width = (max_bpm - min_bpm) / num_classes From acb873cc99bdd4588597477e3cc0f3491339fe69 Mon Sep 17 00:00:00 2001 From: bleugreen Date: Sat, 21 Feb 2026 22:52:35 -0500 Subject: [PATCH 6/8] [Build] Add __init__.py files, fix re-export, update pyproject.toml --- pyproject.toml | 15 ++++++++++++++- src/deeprhythm/__init__.py | 2 +- src/deeprhythm/audio_proc/__init__.py | 0 src/deeprhythm/model/__init__.py | 0 4 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 src/deeprhythm/audio_proc/__init__.py create mode 100644 src/deeprhythm/model/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 6fb9648..a2b1a81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,22 @@ dependencies = [ "nnaudio==0.3.3", "librosa", "numpy", - "h5py", + "requests", ] +[project.optional-dependencies] +dev = ["pytest", "ruff"] + [project.urls] Homepage = "https://github.com/bleugreen/deeprhythm" Issues = "https://github.com/bleugreen/deeprhythm/issues" + +[tool.ruff] +line-length = 120 +target-version = "py38" + +[tool.ruff.lint] +select = ["E", "F", "I"] + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/src/deeprhythm/__init__.py b/src/deeprhythm/__init__.py index 8966c49..d7ca098 100644 --- a/src/deeprhythm/__init__.py +++ b/src/deeprhythm/__init__.py @@ -1 +1 @@ -from deeprhythm.model.predictor import DeepRhythmPredictor \ No newline at end of file +from deeprhythm.model.predictor import DeepRhythmPredictor as DeepRhythmPredictor \ No newline at end of file diff --git a/src/deeprhythm/audio_proc/__init__.py b/src/deeprhythm/audio_proc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/deeprhythm/model/__init__.py b/src/deeprhythm/model/__init__.py new file mode 100644 index 0000000..e69de29 From 97779a1cdbf2a37c249267bca9e8e0e5fac1f9fb Mon Sep 17 00:00:00 2001 From: bleugreen Date: Sat, 21 Feb 2026 22:53:02 -0500 Subject: [PATCH 7/8] Fix import ordering across all source files --- src/deeprhythm/audio_proc/bandfilter.py | 1 + src/deeprhythm/audio_proc/hcqm.py | 2 +- src/deeprhythm/audio_proc/onset.py | 3 ++- src/deeprhythm/batch_infer.py | 1 - src/deeprhythm/infer.py | 1 - src/deeprhythm/model/frame_cnn.py | 1 + 6 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/deeprhythm/audio_proc/bandfilter.py b/src/deeprhythm/audio_proc/bandfilter.py index 0b29d75..2a312cd 100644 --- a/src/deeprhythm/audio_proc/bandfilter.py +++ b/src/deeprhythm/audio_proc/bandfilter.py @@ -1,6 +1,7 @@ import numpy as np import torch + def create_log_filter(num_bins, num_bands, device='cuda'): """ Create a logarithmically spaced filter matrix for audio processing. diff --git a/src/deeprhythm/audio_proc/hcqm.py b/src/deeprhythm/audio_proc/hcqm.py index a2c29cb..6ce1a72 100644 --- a/src/deeprhythm/audio_proc/hcqm.py +++ b/src/deeprhythm/audio_proc/hcqm.py @@ -1,6 +1,6 @@ +import nnAudio.features as feat import numpy as np import torch -import nnAudio.features as feat from deeprhythm.audio_proc.bandfilter import apply_log_filter, create_log_filter from deeprhythm.audio_proc.onset import onset_strength diff --git a/src/deeprhythm/audio_proc/onset.py b/src/deeprhythm/audio_proc/onset.py index 73d0b2d..1856ae5 100644 --- a/src/deeprhythm/audio_proc/onset.py +++ b/src/deeprhythm/audio_proc/onset.py @@ -1,6 +1,7 @@ import torch -import torchaudio import torch.nn.functional as F +import torchaudio + def onset_strength( y=None, n_fft=2048, hop_length=512, lag=1, ref=None, diff --git a/src/deeprhythm/batch_infer.py b/src/deeprhythm/batch_infer.py index 83e4a04..945ab4a 100644 --- a/src/deeprhythm/batch_infer.py +++ b/src/deeprhythm/batch_infer.py @@ -11,7 +11,6 @@ from deeprhythm.model.predictor import load_cnn_model from deeprhythm.utils import AudioLoadError, AudioTooShortError, class_to_bpm, get_device, load_and_split_audio - NUM_WORKERS = 8 NUM_BATCH = 128 diff --git a/src/deeprhythm/infer.py b/src/deeprhythm/infer.py index 290f01a..891e0b5 100644 --- a/src/deeprhythm/infer.py +++ b/src/deeprhythm/infer.py @@ -4,7 +4,6 @@ from deeprhythm.model.predictor import DeepRhythmPredictor from deeprhythm.utils import get_device - if __name__ == '__main__': warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) diff --git a/src/deeprhythm/model/frame_cnn.py b/src/deeprhythm/model/frame_cnn.py index 0115bd3..025a403 100644 --- a/src/deeprhythm/model/frame_cnn.py +++ b/src/deeprhythm/model/frame_cnn.py @@ -1,6 +1,7 @@ import torch.nn as nn import torch.nn.functional as F + class DeepRhythmModel(nn.Module): def __init__(self, num_classes=256): super(DeepRhythmModel, self).__init__() From 1e98553fe06b9a0a9a0cf0dc4467e907eff58eec Mon Sep 17 00:00:00 2001 From: bleugreen Date: Sat, 21 Feb 2026 22:53:17 -0500 Subject: [PATCH 8/8] [Build] Add tests and CI workflow --- .github/workflows/ci.yml | 26 +++++++++++++ tests/__init__.py | 0 tests/test_model.py | 79 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 tests/__init__.py create mode 100644 tests/test_model.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..6c58fb4 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,26 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint-and-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install dependencies + run: pip install -e ".[dev]" soundfile + + - name: Lint + run: ruff check src/ + + - name: Test + run: pytest diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..5a5ad58 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,79 @@ +import os +import tempfile + +import numpy as np +import pytest +import soundfile as sf + +from deeprhythm.utils import AudioTooShortError, bpm_to_class, class_to_bpm, split_audio + + +def test_bpm_class_roundtrip(): + """bpm_to_class and class_to_bpm should roundtrip within one class width.""" + class_width = (286 - 30) / 256 + for bpm in [30, 60, 90, 120, 150, 200, 285]: + cls = bpm_to_class(bpm) + recovered = class_to_bpm(cls) + assert abs(recovered - bpm) <= class_width, f"Roundtrip failed for {bpm}: got {recovered}" + + +def test_bpm_to_class_clamps(): + """Values outside [30, 286] should clamp to valid class range.""" + assert bpm_to_class(0) == 0 + assert bpm_to_class(500) == 255 + + +def test_split_audio_basic(): + """split_audio should produce correct number of clips from a synthetic signal.""" + sr = 22050 + clip_length = 8 + num_clips = 3 + audio = np.random.randn(sr * clip_length * num_clips + 1000).astype(np.float32) + clips = split_audio(audio, sr, clip_length=clip_length) + assert clips.shape == (num_clips, sr * clip_length) + + +def test_split_audio_too_short(): + """split_audio should raise AudioTooShortError when audio is shorter than one clip.""" + sr = 22050 + audio = np.zeros(100, dtype=np.float32) + with pytest.raises(AudioTooShortError): + split_audio(audio, sr) + + +def test_split_audio_share_mem(): + """split_audio with share_mem=True should return a shared memory tensor.""" + sr = 22050 + audio = np.random.randn(sr * 8).astype(np.float32) + clips = split_audio(audio, sr, share_mem=True) + assert clips.is_shared() + + +def test_predictor_instantiation(): + """DeepRhythmPredictor should load model and create kernels.""" + from deeprhythm.model.predictor import DeepRhythmPredictor + predictor = DeepRhythmPredictor(device='cpu', quiet=True) + assert predictor.model is not None + assert predictor.specs is not None + + +def test_predict_sine_wave(): + """Predicting on a synthetic sine wave should return a float in valid BPM range.""" + from deeprhythm.model.predictor import DeepRhythmPredictor + + sr = 22050 + duration = 16 + t = np.linspace(0, duration, sr * duration, dtype=np.float32) + audio = np.sin(2 * np.pi * 2.0 * t).astype(np.float32) + + with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f: + sf.write(f.name, audio, sr) + tmp_path = f.name + + try: + predictor = DeepRhythmPredictor(device='cpu', quiet=True) + result = predictor.predict(tmp_path) + assert isinstance(result, float) + assert 30 <= result <= 286 + finally: + os.unlink(tmp_path)