From ad30ba98a69dd6ca74d63d9379e387239ad210fc Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 4 Feb 2026 04:10:24 +0000 Subject: [PATCH 01/33] Add multi-provider support and local file media provider Major features: - Add local file media provider that scans directories for audio files using MP3 tags for metadata (ID3, Vorbis comments, etc.) - Introduce multi-provider architecture with database schema for tracking providers, tracks, and cross-provider linking - Add GUI setup wizard for configuring providers and settings - Simplify Docker deployment to 2 unified files (CPU/NVIDIA) Database changes: - New tables: provider, track, provider_track, app_settings - Add file_path and track_id columns to score table for linking - All changes are migration-safe with ON CONFLICT handling Key implementation details: - File path hash (SHA-256) used as stable identifier for local files - Providers can be tested before saving configuration - Existing installations auto-migrate with default provider from env - Analysis pipeline updated to store file_path for track linking Files added: - tasks/mediaserver_localfiles.py - Local file provider implementation - app_setup.py - Setup wizard API endpoints - templates/setup.html - Setup wizard UI - deployment/docker-compose-unified.yaml - Simplified CPU deployment - deployment/docker-compose-unified-nvidia.yaml - Simplified GPU deployment - docs/MULTI_PROVIDER_ARCHITECTURE.md - Architecture documentation https://claude.ai/code/session_011AebTWAucDafK4m6uoSSNg --- app.py | 2 + app_helper.py | 144 ++- app_setup.py | 682 +++++++++++ config.py | 10 +- deployment/.env.example | 102 +- deployment/docker-compose-unified-nvidia.yaml | 210 ++++ deployment/docker-compose-unified.yaml | 188 +++ docs/MULTI_PROVIDER_ARCHITECTURE.md | 323 ++++++ tasks/analysis.py | 10 +- tasks/mediaserver.py | 257 ++++- tasks/mediaserver_localfiles.py | 604 ++++++++++ templates/setup.html | 1028 +++++++++++++++++ templates/sidebar_navi.html | 1 + 13 files changed, 3511 
insertions(+), 50 deletions(-) create mode 100644 app_setup.py create mode 100644 deployment/docker-compose-unified-nvidia.yaml create mode 100644 deployment/docker-compose-unified.yaml create mode 100644 docs/MULTI_PROVIDER_ARCHITECTURE.md create mode 100644 tasks/mediaserver_localfiles.py create mode 100644 templates/setup.html diff --git a/app.py b/app.py index 4d4a5136..0b630480 100644 --- a/app.py +++ b/app.py @@ -598,6 +598,7 @@ def listen_for_index_reloads(): from app_artist_similarity import artist_similarity_bp from app_clap_search import clap_search_bp from app_mulan_search import mulan_search_bp +from app_setup import setup_bp # Setup wizard and provider configuration app.register_blueprint(chat_bp, url_prefix='/chat') app.register_blueprint(clustering_bp) @@ -614,6 +615,7 @@ def listen_for_index_reloads(): app.register_blueprint(artist_similarity_bp) app.register_blueprint(clap_search_bp) app.register_blueprint(mulan_search_bp) +app.register_blueprint(setup_bp) # Setup wizard if __name__ == '__main__': os.makedirs(TEMP_DIR, exist_ok=True) diff --git a/app_helper.py b/app_helper.py index 7afbd822..25a47f90 100644 --- a/app_helper.py +++ b/app_helper.py @@ -212,7 +212,100 @@ def init_db(): """, (query, 1.0, rank)) logger.info(f"Inserted {len(default_queries)} default CLAP search queries") - + + # ================================================================= + # MULTI-PROVIDER SUPPORT TABLES + # ================================================================= + + # Create 'provider' table - Registry of configured media providers + cur.execute(""" + CREATE TABLE IF NOT EXISTS provider ( + id SERIAL PRIMARY KEY, + provider_type VARCHAR(50) NOT NULL, + name VARCHAR(255) NOT NULL, + config JSONB NOT NULL DEFAULT '{}', + enabled BOOLEAN DEFAULT TRUE, + priority INTEGER DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(provider_type, name) + ) + """) + + # Create 'track' table - Stable 
track identity based on file path + cur.execute(""" + CREATE TABLE IF NOT EXISTS track ( + id SERIAL PRIMARY KEY, + file_path_hash VARCHAR(64) NOT NULL UNIQUE, + file_path TEXT NOT NULL, + file_size BIGINT, + file_modified TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + cur.execute("CREATE INDEX IF NOT EXISTS idx_track_file_path_hash ON track(file_path_hash)") + + # Create 'provider_track' table - Links provider item_ids to tracks + cur.execute(""" + CREATE TABLE IF NOT EXISTS provider_track ( + id SERIAL PRIMARY KEY, + provider_id INTEGER NOT NULL REFERENCES provider(id) ON DELETE CASCADE, + track_id INTEGER NOT NULL REFERENCES track(id) ON DELETE CASCADE, + item_id TEXT NOT NULL, + title TEXT, + artist TEXT, + album TEXT, + last_synced TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(provider_id, item_id), + UNIQUE(provider_id, track_id) + ) + """) + cur.execute("CREATE INDEX IF NOT EXISTS idx_provider_track_item_id ON provider_track(item_id)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_provider_track_track_id ON provider_track(track_id)") + + # Create 'app_settings' table - Application configuration storage + cur.execute(""" + CREATE TABLE IF NOT EXISTS app_settings ( + key VARCHAR(255) PRIMARY KEY, + value JSONB NOT NULL, + category VARCHAR(100), + description TEXT, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + + # Add 'track_id' column to 'score' table if not exists (for multi-provider linking) + cur.execute("SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'score' AND column_name = 'track_id')") + if not cur.fetchone()[0]: + logger.info("Adding 'track_id' column to 'score' table for multi-provider support.") + cur.execute("ALTER TABLE score ADD COLUMN track_id INTEGER REFERENCES track(id)") + cur.execute("CREATE INDEX IF NOT EXISTS idx_score_track_id ON score(track_id)") + + # Add 'file_path' column to 'score' table if not exists (for file 
identification) + cur.execute("SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'score' AND column_name = 'file_path')") + if not cur.fetchone()[0]: + logger.info("Adding 'file_path' column to 'score' table.") + cur.execute("ALTER TABLE score ADD COLUMN file_path TEXT") + cur.execute("CREATE INDEX IF NOT EXISTS idx_score_file_path ON score(file_path)") + + # Insert default settings if app_settings is empty + cur.execute("SELECT COUNT(*) FROM app_settings") + if cur.fetchone()[0] == 0: + default_settings = [ + ('setup_completed', 'false', 'system', 'Whether the setup wizard has been completed'), + ('setup_version', '"1.0"', 'system', 'Version of the setup wizard last completed'), + ('multi_provider_enabled', 'false', 'providers', 'Whether multi-provider mode is enabled'), + ('primary_provider_id', 'null', 'providers', 'ID of the primary provider for playlist creation'), + ] + for key, value, category, description in default_settings: + cur.execute(""" + INSERT INTO app_settings (key, value, category, description) + VALUES (%s, %s::jsonb, %s, %s) + ON CONFLICT (key) DO NOTHING + """, (key, value, category, description)) + logger.info("Inserted default app settings") + db.commit() # --- Status Constants --- @@ -427,14 +520,29 @@ def track_exists(item_id): cur.close() return row is not None -def save_track_analysis_and_embedding(item_id, title, author, tempo, key, scale, moods, embedding_vector, energy=None, other_features=None, album=None): - """Saves track analysis and embedding in a single transaction.""" - +def save_track_analysis_and_embedding(item_id, title, author, tempo, key, scale, moods, embedding_vector, energy=None, other_features=None, album=None, file_path=None): + """Saves track analysis and embedding in a single transaction. 
+ + Args: + item_id: Provider-specific track identifier + title: Track title + author: Artist name + tempo: BPM + key: Musical key + scale: Major/Minor scale + moods: Dict of mood labels and scores + embedding_vector: numpy array of embeddings + energy: Energy level (0.01-0.15) + other_features: JSON string of additional features + album: Album name + file_path: Full path to the audio file (for multi-provider track linking) + """ + def _sanitize_string(s, max_length=1000, field_name="field"): """Sanitize string for PostgreSQL insertion.""" if s is None: return None - + # Ensure it's a string if not isinstance(s, str): try: @@ -442,25 +550,25 @@ def _sanitize_string(s, max_length=1000, field_name="field"): except Exception: logger.warning(f"Could not convert {field_name} to string, using empty string") return "" - + # Remove problematic characters # NUL byte (0x00) - PostgreSQL cannot store s = s.replace('\x00', '') - + # Remove other control characters that could cause issues # Keep only printable ASCII, space, tab, newline, and common Unicode s = ''.join(char for char in s if char.isprintable() or char in '\n\t ') - + # Truncate to max length to prevent overly long strings if len(s) > max_length: logger.warning(f"{field_name} truncated from {len(s)} to {max_length} characters") s = s[:max_length] - + # Strip leading/trailing whitespace s = s.strip() - + return s - + # Sanitize all string inputs with field-specific limits title = _sanitize_string(title, max_length=500, field_name="title") author = _sanitize_string(author, max_length=200, field_name="author") @@ -468,16 +576,17 @@ def _sanitize_string(s, max_length=1000, field_name="field"): key = _sanitize_string(key, max_length=10, field_name="key") scale = _sanitize_string(scale, max_length=10, field_name="scale") other_features = _sanitize_string(other_features, max_length=2000, field_name="other_features") + file_path = _sanitize_string(file_path, max_length=2000, field_name="file_path") mood_str = 
','.join(f"{k}:{v:.3f}" for k, v in moods.items()) - + conn = get_db() # This now calls the function within this file cur = conn.cursor() try: - # Save analysis to score table + # Save analysis to score table (includes file_path for multi-provider linking) cur.execute(""" - INSERT INTO score (item_id, title, author, tempo, key, scale, mood_vector, energy, other_features, album) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + INSERT INTO score (item_id, title, author, tempo, key, scale, mood_vector, energy, other_features, album, file_path) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (item_id) DO UPDATE SET title = EXCLUDED.title, author = EXCLUDED.author, @@ -487,8 +596,9 @@ def _sanitize_string(s, max_length=1000, field_name="field"): mood_vector = EXCLUDED.mood_vector, energy = EXCLUDED.energy, other_features = EXCLUDED.other_features, - album = EXCLUDED.album - """, (item_id, title, author, tempo, key, scale, mood_str, energy, other_features, album)) + album = EXCLUDED.album, + file_path = EXCLUDED.file_path + """, (item_id, title, author, tempo, key, scale, mood_str, energy, other_features, album, file_path)) # Save embedding if isinstance(embedding_vector, np.ndarray) and embedding_vector.size > 0: diff --git a/app_setup.py b/app_setup.py new file mode 100644 index 00000000..db895861 --- /dev/null +++ b/app_setup.py @@ -0,0 +1,682 @@ +# app_setup.py +""" +Setup Wizard API for AudioMuse-AI + +This module provides the backend API for the setup wizard and provider configuration. 
+It handles: +- Initial setup detection +- Provider configuration (add, update, delete, test) +- Application settings management +- Multi-provider mode enablement +""" + +import logging +import json +from datetime import datetime +from flask import Blueprint, jsonify, request, render_template, redirect, url_for, g +from functools import wraps + +from app_helper import get_db +from tasks.mediaserver import ( + get_available_provider_types, + get_provider_info, + test_provider_connection, + PROVIDER_TYPES +) +import config + +logger = logging.getLogger(__name__) + +setup_bp = Blueprint('setup', __name__) + + +# ############################################################################## +# HELPER FUNCTIONS +# ############################################################################## + +def get_setting(key, default=None): + """Get a setting value from the database.""" + db = get_db() + with db.cursor() as cur: + cur.execute("SELECT value FROM app_settings WHERE key = %s", (key,)) + row = cur.fetchone() + if row: + return row[0] + return default + + +def set_setting(key, value, category=None, description=None): + """Set a setting value in the database.""" + db = get_db() + with db.cursor() as cur: + cur.execute(""" + INSERT INTO app_settings (key, value, category, description, updated_at) + VALUES (%s, %s, %s, %s, NOW()) + ON CONFLICT (key) DO UPDATE SET + value = EXCLUDED.value, + category = COALESCE(EXCLUDED.category, app_settings.category), + description = COALESCE(EXCLUDED.description, app_settings.description), + updated_at = NOW() + """, (key, json.dumps(value), category, description)) + db.commit() + + +def get_all_settings(): + """Get all settings grouped by category.""" + db = get_db() + with db.cursor() as cur: + cur.execute("SELECT key, value, category, description FROM app_settings ORDER BY category, key") + rows = cur.fetchall() + settings = {} + for row in rows: + key, value, category, description = row + if category not in settings: + 
settings[category] = {} + settings[category][key] = { + 'value': value, + 'description': description + } + return settings + + +def is_setup_completed(): + """Check if initial setup has been completed.""" + result = get_setting('setup_completed') + return result is True or result == 'true' or result == True + + +def is_multi_provider_enabled(): + """Check if multi-provider mode is enabled.""" + result = get_setting('multi_provider_enabled') + return result is True or result == 'true' or result == True + + +# ############################################################################## +# PROVIDER MANAGEMENT +# ############################################################################## + +def get_providers(): + """Get all configured providers.""" + db = get_db() + with db.cursor() as cur: + cur.execute(""" + SELECT id, provider_type, name, config, enabled, priority, created_at, updated_at + FROM provider + ORDER BY priority DESC, created_at ASC + """) + rows = cur.fetchall() + providers = [] + for row in rows: + provider = { + 'id': row[0], + 'provider_type': row[1], + 'name': row[2], + 'config': row[3], # JSONB is automatically parsed + 'enabled': row[4], + 'priority': row[5], + 'created_at': row[6].isoformat() if row[6] else None, + 'updated_at': row[7].isoformat() if row[7] else None, + } + # Don't expose sensitive config values + if provider['config']: + safe_config = {} + for k, v in provider['config'].items(): + if k in ('password', 'token', 'api_key'): + safe_config[k] = '********' if v else None + else: + safe_config[k] = v + provider['config_display'] = safe_config + providers.append(provider) + return providers + + +def get_provider_by_id(provider_id): + """Get a provider by ID.""" + db = get_db() + with db.cursor() as cur: + cur.execute(""" + SELECT id, provider_type, name, config, enabled, priority + FROM provider WHERE id = %s + """, (provider_id,)) + row = cur.fetchone() + if row: + return { + 'id': row[0], + 'provider_type': row[1], + 'name': 
row[2], + 'config': row[3], + 'enabled': row[4], + 'priority': row[5], + } + return None + + +def add_provider(provider_type, name, config_data, enabled=True, priority=0): + """Add a new provider configuration.""" + db = get_db() + with db.cursor() as cur: + cur.execute(""" + INSERT INTO provider (provider_type, name, config, enabled, priority) + VALUES (%s, %s, %s, %s, %s) + RETURNING id + """, (provider_type, name, json.dumps(config_data), enabled, priority)) + provider_id = cur.fetchone()[0] + db.commit() + return provider_id + + +def update_provider(provider_id, name=None, config_data=None, enabled=None, priority=None): + """Update an existing provider configuration.""" + db = get_db() + updates = [] + values = [] + + if name is not None: + updates.append("name = %s") + values.append(name) + if config_data is not None: + updates.append("config = %s") + values.append(json.dumps(config_data)) + if enabled is not None: + updates.append("enabled = %s") + values.append(enabled) + if priority is not None: + updates.append("priority = %s") + values.append(priority) + + if not updates: + return False + + updates.append("updated_at = NOW()") + values.append(provider_id) + + with db.cursor() as cur: + cur.execute(f""" + UPDATE provider SET {', '.join(updates)} + WHERE id = %s + """, values) + db.commit() + return cur.rowcount > 0 + + +def delete_provider(provider_id): + """Delete a provider configuration.""" + db = get_db() + with db.cursor() as cur: + cur.execute("DELETE FROM provider WHERE id = %s", (provider_id,)) + db.commit() + return cur.rowcount > 0 + + +def create_default_provider_from_env(): + """ + Create a default provider from environment variables if no providers exist. + This enables backward compatibility with existing installations. 
+ """ + existing = get_providers() + if existing: + return None # Providers already exist + + provider_type = config.MEDIASERVER_TYPE + if provider_type not in PROVIDER_TYPES: + logger.warning(f"Unknown provider type from env: {provider_type}") + return None + + # Build config from environment variables + config_data = {} + + if provider_type == 'jellyfin': + config_data = { + 'url': config.JELLYFIN_URL, + 'user_id': config.JELLYFIN_USER_ID, + 'token': config.JELLYFIN_TOKEN, + } + elif provider_type == 'navidrome': + config_data = { + 'url': config.NAVIDROME_URL, + 'user': config.NAVIDROME_USER, + 'password': config.NAVIDROME_PASSWORD, + } + elif provider_type == 'lyrion': + config_data = { + 'url': config.LYRION_URL, + } + elif provider_type == 'mpd': + config_data = { + 'host': config.MPD_HOST, + 'port': config.MPD_PORT, + 'password': config.MPD_PASSWORD, + 'music_directory': config.MPD_MUSIC_DIRECTORY, + } + elif provider_type == 'emby': + config_data = { + 'url': config.EMBY_URL, + 'user_id': config.EMBY_USER_ID, + 'token': config.EMBY_TOKEN, + } + elif provider_type == 'localfiles': + config_data = { + 'music_directory': config.LOCALFILES_MUSIC_DIRECTORY, + 'supported_formats': config.LOCALFILES_FORMATS, + 'scan_subdirectories': config.LOCALFILES_SCAN_SUBDIRS, + 'playlist_directory': config.LOCALFILES_PLAYLIST_DIR, + } + + name = f"{PROVIDER_TYPES[provider_type]['name']} (Default)" + provider_id = add_provider(provider_type, name, config_data, enabled=True, priority=100) + logger.info(f"Created default provider from environment: {provider_type} (id={provider_id})") + return provider_id + + +# ############################################################################## +# API ENDPOINTS +# ############################################################################## + +@setup_bp.route('/setup') +def setup_page(): + """Render the setup wizard page.""" + return render_template('setup.html', title='AudioMuse-AI - Setup', active='setup') + + 
+@setup_bp.route('/api/setup/status', methods=['GET']) +def get_setup_status(): + """ + Get the current setup status. + --- + tags: + - Setup + responses: + 200: + description: Setup status information + """ + completed = is_setup_completed() + multi_provider = is_multi_provider_enabled() + providers = get_providers() + + # Check if we need to create default provider from env + if not providers: + create_default_provider_from_env() + providers = get_providers() + + return jsonify({ + 'setup_completed': completed, + 'multi_provider_enabled': multi_provider, + 'provider_count': len(providers), + 'providers': providers, + 'current_mediaserver_type': config.MEDIASERVER_TYPE, + 'app_version': config.APP_VERSION, + }) + + +@setup_bp.route('/api/setup/providers/types', methods=['GET']) +def get_provider_types(): + """ + Get available provider types with their configuration fields. + --- + tags: + - Setup + responses: + 200: + description: List of provider types + """ + types = get_available_provider_types() + result = [] + for ptype, info in types.items(): + provider_info = get_provider_info(ptype) + result.append({ + 'type': ptype, + 'name': info['name'], + 'description': info['description'], + 'supports_user_auth': info['supports_user_auth'], + 'supports_play_history': info['supports_play_history'], + 'config_fields': provider_info.get('config_fields', []) if provider_info else [], + }) + return jsonify(result) + + +@setup_bp.route('/api/setup/providers', methods=['GET']) +def list_providers(): + """ + List all configured providers. + --- + tags: + - Setup + responses: + 200: + description: List of providers + """ + providers = get_providers() + return jsonify(providers) + + +@setup_bp.route('/api/setup/providers', methods=['POST']) +def create_provider(): + """ + Add a new provider configuration. 
+ --- + tags: + - Setup + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + provider_type: + type: string + name: + type: string + config: + type: object + enabled: + type: boolean + priority: + type: integer + responses: + 201: + description: Provider created + 400: + description: Invalid request + """ + data = request.get_json() + if not data: + return jsonify({'error': 'No data provided'}), 400 + + provider_type = data.get('provider_type') + name = data.get('name') + config_data = data.get('config', {}) + enabled = data.get('enabled', True) + priority = data.get('priority', 0) + + if not provider_type: + return jsonify({'error': 'provider_type is required'}), 400 + if not name: + return jsonify({'error': 'name is required'}), 400 + if provider_type not in PROVIDER_TYPES: + return jsonify({'error': f'Unknown provider type: {provider_type}'}), 400 + + try: + provider_id = add_provider(provider_type, name, config_data, enabled, priority) + return jsonify({'id': provider_id, 'message': 'Provider created'}), 201 + except Exception as e: + logger.error(f"Error creating provider: {e}") + return jsonify({'error': str(e)}), 500 + + +@setup_bp.route('/api/setup/providers/', methods=['PUT']) +def update_provider_endpoint(provider_id): + """ + Update an existing provider configuration. 
+ --- + tags: + - Setup + parameters: + - name: provider_id + in: path + required: true + schema: + type: integer + responses: + 200: + description: Provider updated + 404: + description: Provider not found + """ + provider = get_provider_by_id(provider_id) + if not provider: + return jsonify({'error': 'Provider not found'}), 404 + + data = request.get_json() + if not data: + return jsonify({'error': 'No data provided'}), 400 + + # Merge config if partial update + config_data = data.get('config') + if config_data and isinstance(config_data, dict): + # Don't allow updating password fields with '********' + for key in list(config_data.keys()): + if config_data[key] == '********': + config_data[key] = provider['config'].get(key) + + success = update_provider( + provider_id, + name=data.get('name'), + config_data=config_data, + enabled=data.get('enabled'), + priority=data.get('priority') + ) + + if success: + return jsonify({'message': 'Provider updated'}) + return jsonify({'error': 'Update failed'}), 500 + + +@setup_bp.route('/api/setup/providers/', methods=['DELETE']) +def delete_provider_endpoint(provider_id): + """ + Delete a provider configuration. + --- + tags: + - Setup + parameters: + - name: provider_id + in: path + required: true + schema: + type: integer + responses: + 200: + description: Provider deleted + 404: + description: Provider not found + """ + success = delete_provider(provider_id) + if success: + return jsonify({'message': 'Provider deleted'}) + return jsonify({'error': 'Provider not found'}), 404 + + +@setup_bp.route('/api/setup/providers//test', methods=['POST']) +def test_provider_endpoint(provider_id): + """ + Test connection to a provider. 
+ --- + tags: + - Setup + parameters: + - name: provider_id + in: path + required: true + schema: + type: integer + responses: + 200: + description: Connection test result + """ + provider = get_provider_by_id(provider_id) + if not provider: + return jsonify({'error': 'Provider not found'}), 404 + + success, message = test_provider_connection( + provider['provider_type'], + provider['config'] + ) + + return jsonify({ + 'success': success, + 'message': message, + 'provider_id': provider_id, + 'provider_type': provider['provider_type'], + }) + + +@setup_bp.route('/api/setup/providers/test', methods=['POST']) +def test_provider_config(): + """ + Test connection with provided configuration (without saving). + --- + tags: + - Setup + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + provider_type: + type: string + config: + type: object + responses: + 200: + description: Connection test result + """ + data = request.get_json() + if not data: + return jsonify({'error': 'No data provided'}), 400 + + provider_type = data.get('provider_type') + config_data = data.get('config', {}) + + if not provider_type: + return jsonify({'error': 'provider_type is required'}), 400 + + success, message = test_provider_connection(provider_type, config_data) + + return jsonify({ + 'success': success, + 'message': message, + 'provider_type': provider_type, + }) + + +@setup_bp.route('/api/setup/settings', methods=['GET']) +def get_settings(): + """ + Get all application settings. + --- + tags: + - Setup + responses: + 200: + description: All settings grouped by category + """ + settings = get_all_settings() + return jsonify(settings) + + +@setup_bp.route('/api/setup/settings', methods=['PUT']) +def update_settings(): + """ + Update application settings. 
+ --- + tags: + - Setup + requestBody: + required: true + content: + application/json: + schema: + type: object + additionalProperties: true + responses: + 200: + description: Settings updated + """ + data = request.get_json() + if not data: + return jsonify({'error': 'No data provided'}), 400 + + for key, value in data.items(): + set_setting(key, value) + + return jsonify({'message': 'Settings updated'}) + + +@setup_bp.route('/api/setup/complete', methods=['POST']) +def complete_setup(): + """ + Mark the setup as complete. + --- + tags: + - Setup + responses: + 200: + description: Setup marked as complete + """ + set_setting('setup_completed', True, 'system', 'Whether the setup wizard has been completed') + set_setting('setup_version', config.APP_VERSION, 'system', 'Version of the setup wizard last completed') + return jsonify({'message': 'Setup completed', 'setup_completed': True}) + + +@setup_bp.route('/api/setup/multi-provider', methods=['POST']) +def enable_multi_provider(): + """ + Enable or disable multi-provider mode. + --- + tags: + - Setup + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + enabled: + type: boolean + responses: + 200: + description: Multi-provider mode updated + """ + data = request.get_json() + if not data: + return jsonify({'error': 'No data provided'}), 400 + + enabled = data.get('enabled', False) + set_setting('multi_provider_enabled', enabled, 'providers', 'Whether multi-provider mode is enabled') + + return jsonify({ + 'message': f"Multi-provider mode {'enabled' if enabled else 'disabled'}", + 'multi_provider_enabled': enabled + }) + + +@setup_bp.route('/api/setup/primary-provider', methods=['PUT']) +def set_primary_provider(): + """ + Set the primary provider for playlist creation. 
+ --- + tags: + - Setup + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + provider_id: + type: integer + responses: + 200: + description: Primary provider set + """ + data = request.get_json() + if not data: + return jsonify({'error': 'No data provided'}), 400 + + provider_id = data.get('provider_id') + if provider_id is not None: + provider = get_provider_by_id(provider_id) + if not provider: + return jsonify({'error': 'Provider not found'}), 404 + + set_setting('primary_provider_id', provider_id, 'providers', 'ID of the primary provider for playlist creation') + + return jsonify({ + 'message': 'Primary provider set', + 'primary_provider_id': provider_id + }) diff --git a/config.py b/config.py index 9e9dec46..334bd411 100644 --- a/config.py +++ b/config.py @@ -2,7 +2,7 @@ import os # --- Media Server Type --- -MEDIASERVER_TYPE = os.environ.get("MEDIASERVER_TYPE", "jellyfin").lower() # Possible values: jellyfin, navidrome, lyrion, mpd, emby +MEDIASERVER_TYPE = os.environ.get("MEDIASERVER_TYPE", "jellyfin").lower() # Possible values: jellyfin, navidrome, lyrion, mpd, emby, localfiles # --- Jellyfin and DB Constants (Read from Environment Variables first) --- @@ -49,6 +49,14 @@ MPD_PASSWORD = os.environ.get("MPD_PASSWORD", "") # Optional password, leave empty if none MPD_MUSIC_DIRECTORY = os.environ.get("MPD_MUSIC_DIRECTORY", "/var/lib/mpd/music") # Path to MPD's music directory for file access +# --- Local Files Provider Constants --- +# These are used only if MEDIASERVER_TYPE is "localfiles". 
+LOCALFILES_MUSIC_DIRECTORY = os.environ.get("LOCALFILES_MUSIC_DIRECTORY", "/music") # Path to local music directory +LOCALFILES_FORMATS = os.environ.get("LOCALFILES_FORMATS", ".mp3,.flac,.ogg,.m4a,.mp4,.wav,.wma,.aac,.opus") # Supported audio formats +LOCALFILES_SCAN_SUBDIRS = os.environ.get("LOCALFILES_SCAN_SUBDIRS", "true").lower() == "true" # Scan subdirectories +LOCALFILES_USE_METADATA = os.environ.get("LOCALFILES_USE_METADATA", "true").lower() == "true" # Use embedded metadata +LOCALFILES_PLAYLIST_DIR = os.environ.get("LOCALFILES_PLAYLIST_DIR", "/music/playlists") # Where to save M3U playlists + # --- General Constants (Read from Environment Variables where applicable) --- APP_VERSION = "v0.8.8" diff --git a/deployment/.env.example b/deployment/.env.example index 0ec06cb6..d7e7db62 100644 --- a/deployment/.env.example +++ b/deployment/.env.example @@ -1,7 +1,16 @@ +# ============================================================================= +# AudioMuse-AI Configuration +# ============================================================================= # Copy this file to `.env` and fill in the values that match your setup. # Docker Compose files under deployment/ read these variables to keep settings in one place. # -# IMPORTANT: +# QUICK START: +# 1. Copy this file: cp .env.example .env +# 2. Set your media provider settings below (Jellyfin, Navidrome, etc.) +# 3. Start the containers: docker-compose -f docker-compose-unified.yaml up -d +# 4. Open http://localhost:8000 and complete the setup wizard +# +# IMPORTANT: # 1. This file must be named exactly ".env" (not .env.txt or .env.example) # 2. It must be in the SAME directory as your docker-compose-*.yaml file # 3. 
Do NOT use spaces around the = sign @@ -24,41 +33,90 @@ # - Restart containers after changing this file # If all else fails, try hardcoding the value directly in docker-compose-*.yaml to isolate the issue +# ============================================================================= +# MEDIA SERVER PROVIDER +# ============================================================================= +# Choose your primary media provider. Additional providers can be configured +# via the web-based setup wizard at http://localhost:8000/setup +# +# Options: jellyfin, navidrome, lyrion, mpd, emby, localfiles +# Default: localfiles (scans local music directory) +MEDIASERVER_TYPE=localfiles + +# --- Local Files Provider --- +# Path to your music library (mounted into the container as /music) +MUSIC_PATH=/path/to/your/music +LOCALFILES_MUSIC_DIRECTORY=/music +LOCALFILES_PLAYLIST_DIR=/music/playlists +# Supported formats (comma-separated, including the dot) +LOCALFILES_FORMATS=.mp3,.flac,.ogg,.m4a,.mp4,.wav,.wma,.aac,.opus +# Scan subdirectories for music files +LOCALFILES_SCAN_SUBDIRS=true + # --- Jellyfin --- +JELLYFIN_URL=http://jellyfin.example.com:8096 JELLYFIN_USER_ID=YOUR_JELLYFIN_USER_ID JELLYFIN_TOKEN=YOUR_JELLYFIN_API_TOKEN -JELLYFIN_URL=http://jellyfin.example.com:8096 # --- Emby --- +EMBY_URL= EMBY_USER_ID= EMBY_TOKEN= -EMBY_URL= # --- Navidrome --- NAVIDROME_URL= NAVIDROME_USER= NAVIDROME_PASSWORD= -# --- Lyrion --- -LYRION_URL=http://lyrion.example.com +# --- Lyrion (formerly LMS) --- +LYRION_URL=http://lyrion.example.com:9000 + +# --- MPD (Music Player Daemon) --- +MPD_HOST=localhost +MPD_PORT=6600 +MPD_PASSWORD= +MPD_MUSIC_DIRECTORY=/var/lib/mpd/music -# --- Shared backend configuration --- +# ============================================================================= +# DATABASE & INFRASTRUCTURE +# ============================================================================= POSTGRES_USER=audiomuse POSTGRES_PASSWORD=audiomusepassword POSTGRES_DB=audiomusedb 
POSTGRES_PORT=5432 POSTGRES_HOST=postgres -REDIS_URL=redis://redis:6379/0 # /!\ change port adress if you change REDIS_PORT below -# --- Timezone (optional) --- -# Set container timezone using TZ (examples: UTC, Europe/Berlin, America/Los_Angeles) -# If omitted, default is UTC. Containers read this env var and apply it at startup. -TZ=UTC +REDIS_URL=redis://redis:6379/0 REDIS_PORT=6379 + +# ============================================================================= +# WEB SERVER +# ============================================================================= FRONTEND_PORT=8000 WORKER_PORT=8029 +# Timezone (examples: UTC, Europe/Berlin, America/Los_Angeles) +TZ=UTC + +# ============================================================================= +# NVIDIA GPU (for docker-compose-unified-nvidia.yaml) +# ============================================================================= +# GPU device ID (usually 0 for single GPU systems) +NVIDIA_GPU_ID=0 +# Enable GPU-accelerated clustering using RAPIDS cuML +# Automatically falls back to CPU if GPU is unavailable +USE_GPU_CLUSTERING=false + +# ============================================================================= +# ML FEATURES +# ============================================================================= +# CLAP Text Search - natural language music search +# Allows queries like "upbeat summer songs" or "relaxing piano music" +# Disable to save memory (~750MB) on systems with limited RAM +CLAP_ENABLED=true -# --- AI Model Configuration --- -# Choose your AI provider: NONE, OLLAMA, OPENAI, GEMINI, OPENAI, or MISTRAL +# ============================================================================= +# AI PLAYLIST NAMING (Optional) +# ============================================================================= +# Choose an AI provider for creative playlist names: NONE, OLLAMA, OPENAI, GEMINI, MISTRAL AI_MODEL_PROVIDER=NONE # --- OpenAI / OpenRouter Configuration --- @@ -71,25 +129,9 @@ AI_MODEL_PROVIDER=NONE 
OPENAI_API_KEY= OPENAI_SERVER_URL=https://openrouter.ai/api/v1/chat/completions OPENAI_MODEL_NAME= -# Optional: Delay between API calls to respect rate limits (default: 7 seconds) +# Delay between API calls to respect rate limits (default: 7 seconds) OPENAI_API_CALL_DELAY_SECONDS=7 # --- Other AI Provider API Keys --- GEMINI_API_KEY= MISTRAL_API_KEY= - -# --- GPU Acceleration for Clustering --- -# Enable GPU-accelerated clustering using RAPIDS cuML (requires NVIDIA GPU) -# Set to true to use GPU for KMeans, DBSCAN, and PCA in clustering tasks -# Automatically falls back to CPU if GPU is unavailable -# Default: false (CPU only) -USE_GPU_CLUSTERING=false - -# --- CLAP Text Search Configuration --- -# Enable CLAP (Contrastive Language-Audio Pretraining) for natural language music search -# CLAP allows searching your music collection using text queries like "upbeat summer songs" or "relaxing piano music" -# Set to false to disable CLAP and save memory/CPU on slower systems -# WARNING: If disabled, text search functionality will not work (only similarity search will be available) -# Models: Audio model (~268MB) for analysis, Text model (~478MB) for search -# Default: true -CLAP_ENABLED=true diff --git a/deployment/docker-compose-unified-nvidia.yaml b/deployment/docker-compose-unified-nvidia.yaml new file mode 100644 index 00000000..b355eb1b --- /dev/null +++ b/deployment/docker-compose-unified-nvidia.yaml @@ -0,0 +1,210 @@ +# AudioMuse-AI Unified Docker Compose - NVIDIA GPU Edition +# ============================================================================= +# This deployment file includes NVIDIA GPU acceleration support. +# For CPU-only systems, use docker-compose-unified.yaml +# +# Requirements: +# - NVIDIA GPU with CUDA support +# - NVIDIA Container Toolkit installed +# - nvidia-docker2 or Docker 19.03+ with nvidia runtime +# +# Quick Start: +# 1. Copy .env.example to .env and configure your settings +# 2. 
Run: docker-compose -f docker-compose-unified-nvidia.yaml up -d +# 3. Open http://localhost:8000 and complete the setup wizard +# +# All provider-specific settings are now configured via the GUI setup wizard +# or the .env file. No need to use different docker-compose files for different +# media servers! +# ============================================================================= + +version: '3.8' + +services: + # --------------------------------------------------------------------------- + # Redis - Task Queue + # --------------------------------------------------------------------------- + redis: + image: redis:7-alpine + container_name: audiomuse-redis + ports: + - "${REDIS_PORT:-6379}:6379" + volumes: + - redis-data:/data + restart: unless-stopped + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 30s + timeout: 10s + retries: 3 + + # --------------------------------------------------------------------------- + # PostgreSQL - Database + # --------------------------------------------------------------------------- + postgres: + image: postgres:15-alpine + container_name: audiomuse-postgres + environment: + POSTGRES_USER: ${POSTGRES_USER:-audiomuse} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-audiomusepassword} + POSTGRES_DB: ${POSTGRES_DB:-audiomusedb} + ports: + - "${POSTGRES_PORT:-5432}:5432" + volumes: + - postgres-data:/var/lib/postgresql/data + restart: unless-stopped + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-audiomuse}"] + interval: 30s + timeout: 10s + retries: 3 + + # --------------------------------------------------------------------------- + # AudioMuse-AI Flask Application (Web UI & API) - NVIDIA GPU + # --------------------------------------------------------------------------- + audiomuse-ai-flask: + image: ghcr.io/neptunehub/audiomuse-ai:latest-nvidia + container_name: audiomuse-ai-flask-app + ports: + - "${FRONTEND_PORT:-8000}:8000" + environment: + SERVICE_TYPE: "flask" + TZ: "${TZ:-UTC}" + # Media 
Server Configuration + # Configure via GUI setup wizard or set here for legacy support + MEDIASERVER_TYPE: "${MEDIASERVER_TYPE:-localfiles}" + # Jellyfin (if using) + JELLYFIN_URL: "${JELLYFIN_URL:-}" + JELLYFIN_USER_ID: "${JELLYFIN_USER_ID:-}" + JELLYFIN_TOKEN: "${JELLYFIN_TOKEN:-}" + # Navidrome (if using) + NAVIDROME_URL: "${NAVIDROME_URL:-}" + NAVIDROME_USER: "${NAVIDROME_USER:-}" + NAVIDROME_PASSWORD: "${NAVIDROME_PASSWORD:-}" + # Lyrion (if using) + LYRION_URL: "${LYRION_URL:-}" + # MPD (if using) + MPD_HOST: "${MPD_HOST:-}" + MPD_PORT: "${MPD_PORT:-6600}" + MPD_PASSWORD: "${MPD_PASSWORD:-}" + MPD_MUSIC_DIRECTORY: "${MPD_MUSIC_DIRECTORY:-/music}" + # Emby (if using) + EMBY_URL: "${EMBY_URL:-}" + EMBY_USER_ID: "${EMBY_USER_ID:-}" + EMBY_TOKEN: "${EMBY_TOKEN:-}" + # Local Files Provider (default) + LOCALFILES_MUSIC_DIRECTORY: "${LOCALFILES_MUSIC_DIRECTORY:-/music}" + LOCALFILES_PLAYLIST_DIR: "${LOCALFILES_PLAYLIST_DIR:-/music/playlists}" + # Database Configuration + POSTGRES_USER: ${POSTGRES_USER:-audiomuse} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-audiomusepassword} + POSTGRES_DB: ${POSTGRES_DB:-audiomusedb} + POSTGRES_HOST: "postgres" + POSTGRES_PORT: "${POSTGRES_PORT:-5432}" + REDIS_URL: "${REDIS_URL:-redis://redis:6379/0}" + # AI Configuration + AI_MODEL_PROVIDER: "${AI_MODEL_PROVIDER:-NONE}" + OPENAI_API_KEY: "${OPENAI_API_KEY:-}" + OPENAI_SERVER_URL: "${OPENAI_SERVER_URL:-}" + OPENAI_MODEL_NAME: "${OPENAI_MODEL_NAME:-}" + GEMINI_API_KEY: "${GEMINI_API_KEY:-}" + MISTRAL_API_KEY: "${MISTRAL_API_KEY:-}" + # Features + CLAP_ENABLED: "${CLAP_ENABLED:-true}" + TEMP_DIR: "/app/temp_audio" + volumes: + - temp-audio-flask:/app/temp_audio + # Mount music directory for local files provider + - ${MUSIC_PATH:-./music}:/music:ro + depends_on: + redis: + condition: service_healthy + postgres: + condition: service_healthy + restart: unless-stopped + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ["${NVIDIA_GPU_ID:-0}"] + capabilities: 
[gpu] + + # --------------------------------------------------------------------------- + # AudioMuse-AI Worker (Background Tasks & ML Analysis) - NVIDIA GPU + # --------------------------------------------------------------------------- + audiomuse-ai-worker: + image: ghcr.io/neptunehub/audiomuse-ai:latest-nvidia + container_name: audiomuse-ai-worker-instance + environment: + SERVICE_TYPE: "worker" + TZ: "${TZ:-UTC}" + # NVIDIA GPU Settings + NVIDIA_VISIBLE_DEVICES: "${NVIDIA_GPU_ID:-0}" + NVIDIA_DRIVER_CAPABILITIES: "compute,utility" + # Media Server Configuration (same as flask service) + MEDIASERVER_TYPE: "${MEDIASERVER_TYPE:-localfiles}" + JELLYFIN_URL: "${JELLYFIN_URL:-}" + JELLYFIN_USER_ID: "${JELLYFIN_USER_ID:-}" + JELLYFIN_TOKEN: "${JELLYFIN_TOKEN:-}" + NAVIDROME_URL: "${NAVIDROME_URL:-}" + NAVIDROME_USER: "${NAVIDROME_USER:-}" + NAVIDROME_PASSWORD: "${NAVIDROME_PASSWORD:-}" + LYRION_URL: "${LYRION_URL:-}" + MPD_HOST: "${MPD_HOST:-}" + MPD_PORT: "${MPD_PORT:-6600}" + MPD_PASSWORD: "${MPD_PASSWORD:-}" + MPD_MUSIC_DIRECTORY: "${MPD_MUSIC_DIRECTORY:-/music}" + EMBY_URL: "${EMBY_URL:-}" + EMBY_USER_ID: "${EMBY_USER_ID:-}" + EMBY_TOKEN: "${EMBY_TOKEN:-}" + LOCALFILES_MUSIC_DIRECTORY: "${LOCALFILES_MUSIC_DIRECTORY:-/music}" + LOCALFILES_PLAYLIST_DIR: "${LOCALFILES_PLAYLIST_DIR:-/music/playlists}" + # Database Configuration + POSTGRES_USER: ${POSTGRES_USER:-audiomuse} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-audiomusepassword} + POSTGRES_DB: ${POSTGRES_DB:-audiomusedb} + POSTGRES_HOST: "postgres" + POSTGRES_PORT: "${POSTGRES_PORT:-5432}" + REDIS_URL: "${REDIS_URL:-redis://redis:6379/0}" + # AI Configuration + AI_MODEL_PROVIDER: "${AI_MODEL_PROVIDER:-NONE}" + OPENAI_API_KEY: "${OPENAI_API_KEY:-}" + OPENAI_SERVER_URL: "${OPENAI_SERVER_URL:-}" + OPENAI_MODEL_NAME: "${OPENAI_MODEL_NAME:-}" + GEMINI_API_KEY: "${GEMINI_API_KEY:-}" + MISTRAL_API_KEY: "${MISTRAL_API_KEY:-}" + # Features - Enable GPU clustering for NVIDIA + CLAP_ENABLED: "${CLAP_ENABLED:-true}" + 
USE_GPU_CLUSTERING: "${USE_GPU_CLUSTERING:-true}" + TEMP_DIR: "/app/temp_audio" + volumes: + - temp-audio-worker:/app/temp_audio + # Mount music directory for local files provider + - ${MUSIC_PATH:-./music}:/music:ro + depends_on: + redis: + condition: service_healthy + postgres: + condition: service_healthy + restart: unless-stopped + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ["${NVIDIA_GPU_ID:-0}"] + capabilities: [gpu] + +# ============================================================================= +# Volumes +# ============================================================================= +volumes: + redis-data: + name: audiomuse-redis-data + postgres-data: + name: audiomuse-postgres-data + temp-audio-flask: + name: audiomuse-temp-flask + temp-audio-worker: + name: audiomuse-temp-worker diff --git a/deployment/docker-compose-unified.yaml b/deployment/docker-compose-unified.yaml new file mode 100644 index 00000000..bdb8be31 --- /dev/null +++ b/deployment/docker-compose-unified.yaml @@ -0,0 +1,188 @@ +# AudioMuse-AI Unified Docker Compose +# ============================================================================= +# This is the unified deployment file for CPU-only systems. +# For NVIDIA GPU acceleration, use docker-compose-unified-nvidia.yaml +# +# Quick Start: +# 1. Copy .env.example to .env and configure your settings +# 2. Run: docker-compose -f docker-compose-unified.yaml up -d +# 3. Open http://localhost:8000 and complete the setup wizard +# +# All provider-specific settings are now configured via the GUI setup wizard +# or the .env file. No need to use different docker-compose files for different +# media servers! 
+# ============================================================================= + +version: '3.8' + +services: + # --------------------------------------------------------------------------- + # Redis - Task Queue + # --------------------------------------------------------------------------- + redis: + image: redis:7-alpine + container_name: audiomuse-redis + ports: + - "${REDIS_PORT:-6379}:6379" + volumes: + - redis-data:/data + restart: unless-stopped + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 30s + timeout: 10s + retries: 3 + + # --------------------------------------------------------------------------- + # PostgreSQL - Database + # --------------------------------------------------------------------------- + postgres: + image: postgres:15-alpine + container_name: audiomuse-postgres + environment: + POSTGRES_USER: ${POSTGRES_USER:-audiomuse} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-audiomusepassword} + POSTGRES_DB: ${POSTGRES_DB:-audiomusedb} + ports: + - "${POSTGRES_PORT:-5432}:5432" + volumes: + - postgres-data:/var/lib/postgresql/data + restart: unless-stopped + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-audiomuse}"] + interval: 30s + timeout: 10s + retries: 3 + + # --------------------------------------------------------------------------- + # AudioMuse-AI Flask Application (Web UI & API) + # --------------------------------------------------------------------------- + audiomuse-ai-flask: + image: ghcr.io/neptunehub/audiomuse-ai:latest + container_name: audiomuse-ai-flask-app + ports: + - "${FRONTEND_PORT:-8000}:8000" + environment: + SERVICE_TYPE: "flask" + TZ: "${TZ:-UTC}" + # Media Server Configuration + # Configure via GUI setup wizard or set here for legacy support + MEDIASERVER_TYPE: "${MEDIASERVER_TYPE:-localfiles}" + # Jellyfin (if using) + JELLYFIN_URL: "${JELLYFIN_URL:-}" + JELLYFIN_USER_ID: "${JELLYFIN_USER_ID:-}" + JELLYFIN_TOKEN: "${JELLYFIN_TOKEN:-}" + # Navidrome (if using) + NAVIDROME_URL: 
"${NAVIDROME_URL:-}" + NAVIDROME_USER: "${NAVIDROME_USER:-}" + NAVIDROME_PASSWORD: "${NAVIDROME_PASSWORD:-}" + # Lyrion (if using) + LYRION_URL: "${LYRION_URL:-}" + # MPD (if using) + MPD_HOST: "${MPD_HOST:-}" + MPD_PORT: "${MPD_PORT:-6600}" + MPD_PASSWORD: "${MPD_PASSWORD:-}" + MPD_MUSIC_DIRECTORY: "${MPD_MUSIC_DIRECTORY:-/music}" + # Emby (if using) + EMBY_URL: "${EMBY_URL:-}" + EMBY_USER_ID: "${EMBY_USER_ID:-}" + EMBY_TOKEN: "${EMBY_TOKEN:-}" + # Local Files Provider (default) + LOCALFILES_MUSIC_DIRECTORY: "${LOCALFILES_MUSIC_DIRECTORY:-/music}" + LOCALFILES_PLAYLIST_DIR: "${LOCALFILES_PLAYLIST_DIR:-/music/playlists}" + # Database Configuration + POSTGRES_USER: ${POSTGRES_USER:-audiomuse} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-audiomusepassword} + POSTGRES_DB: ${POSTGRES_DB:-audiomusedb} + POSTGRES_HOST: "postgres" + POSTGRES_PORT: "${POSTGRES_PORT:-5432}" + REDIS_URL: "${REDIS_URL:-redis://redis:6379/0}" + # AI Configuration + AI_MODEL_PROVIDER: "${AI_MODEL_PROVIDER:-NONE}" + OPENAI_API_KEY: "${OPENAI_API_KEY:-}" + OPENAI_SERVER_URL: "${OPENAI_SERVER_URL:-}" + OPENAI_MODEL_NAME: "${OPENAI_MODEL_NAME:-}" + GEMINI_API_KEY: "${GEMINI_API_KEY:-}" + MISTRAL_API_KEY: "${MISTRAL_API_KEY:-}" + # Features + CLAP_ENABLED: "${CLAP_ENABLED:-true}" + TEMP_DIR: "/app/temp_audio" + volumes: + - temp-audio-flask:/app/temp_audio + # Mount music directory for local files provider + - ${MUSIC_PATH:-./music}:/music:ro + depends_on: + redis: + condition: service_healthy + postgres: + condition: service_healthy + restart: unless-stopped + + # --------------------------------------------------------------------------- + # AudioMuse-AI Worker (Background Tasks & ML Analysis) + # --------------------------------------------------------------------------- + audiomuse-ai-worker: + image: ghcr.io/neptunehub/audiomuse-ai:latest + container_name: audiomuse-ai-worker-instance + environment: + SERVICE_TYPE: "worker" + TZ: "${TZ:-UTC}" + # Media Server Configuration (same as flask service) 
+ MEDIASERVER_TYPE: "${MEDIASERVER_TYPE:-localfiles}" + JELLYFIN_URL: "${JELLYFIN_URL:-}" + JELLYFIN_USER_ID: "${JELLYFIN_USER_ID:-}" + JELLYFIN_TOKEN: "${JELLYFIN_TOKEN:-}" + NAVIDROME_URL: "${NAVIDROME_URL:-}" + NAVIDROME_USER: "${NAVIDROME_USER:-}" + NAVIDROME_PASSWORD: "${NAVIDROME_PASSWORD:-}" + LYRION_URL: "${LYRION_URL:-}" + MPD_HOST: "${MPD_HOST:-}" + MPD_PORT: "${MPD_PORT:-6600}" + MPD_PASSWORD: "${MPD_PASSWORD:-}" + MPD_MUSIC_DIRECTORY: "${MPD_MUSIC_DIRECTORY:-/music}" + EMBY_URL: "${EMBY_URL:-}" + EMBY_USER_ID: "${EMBY_USER_ID:-}" + EMBY_TOKEN: "${EMBY_TOKEN:-}" + LOCALFILES_MUSIC_DIRECTORY: "${LOCALFILES_MUSIC_DIRECTORY:-/music}" + LOCALFILES_PLAYLIST_DIR: "${LOCALFILES_PLAYLIST_DIR:-/music/playlists}" + # Database Configuration + POSTGRES_USER: ${POSTGRES_USER:-audiomuse} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-audiomusepassword} + POSTGRES_DB: ${POSTGRES_DB:-audiomusedb} + POSTGRES_HOST: "postgres" + POSTGRES_PORT: "${POSTGRES_PORT:-5432}" + REDIS_URL: "${REDIS_URL:-redis://redis:6379/0}" + # AI Configuration + AI_MODEL_PROVIDER: "${AI_MODEL_PROVIDER:-NONE}" + OPENAI_API_KEY: "${OPENAI_API_KEY:-}" + OPENAI_SERVER_URL: "${OPENAI_SERVER_URL:-}" + OPENAI_MODEL_NAME: "${OPENAI_MODEL_NAME:-}" + GEMINI_API_KEY: "${GEMINI_API_KEY:-}" + MISTRAL_API_KEY: "${MISTRAL_API_KEY:-}" + # Features + CLAP_ENABLED: "${CLAP_ENABLED:-true}" + USE_GPU_CLUSTERING: "false" + TEMP_DIR: "/app/temp_audio" + volumes: + - temp-audio-worker:/app/temp_audio + # Mount music directory for local files provider + - ${MUSIC_PATH:-./music}:/music:ro + depends_on: + redis: + condition: service_healthy + postgres: + condition: service_healthy + restart: unless-stopped + +# ============================================================================= +# Volumes +# ============================================================================= +volumes: + redis-data: + name: audiomuse-redis-data + postgres-data: + name: audiomuse-postgres-data + temp-audio-flask: + name: audiomuse-temp-flask 
+ temp-audio-worker: + name: audiomuse-temp-worker diff --git a/docs/MULTI_PROVIDER_ARCHITECTURE.md b/docs/MULTI_PROVIDER_ARCHITECTURE.md new file mode 100644 index 00000000..83f5ebc3 --- /dev/null +++ b/docs/MULTI_PROVIDER_ARCHITECTURE.md @@ -0,0 +1,323 @@ +# Multi-Provider Architecture Design + +## Overview + +This document outlines the architecture for supporting multiple media providers simultaneously in AudioMuse-AI without requiring re-analysis of tracks. The design ensures: + +1. **No re-analysis required** when adding new providers +2. **Seamless migration** for existing installations +3. **Future-proof** extensibility for new providers +4. **Minimal schema changes** to existing tables + +## Key Design Decisions + +### 1. Primary Key Strategy for Local File Provider + +**Decision: Use normalized file path as the stable identifier** + +Rationale: +- File paths are unique within a music library +- Content hashes would require reading entire files (slow for large libraries) +- File path changes are rare and can be handled via re-scan +- Consistent with MPD provider which already uses file paths + +For the local file provider: +- `item_id` = SHA-256 hash of the normalized relative file path +- This creates a stable, predictable ID that won't change unless the file moves + +### 2. Linking Tracks Across Providers + +**Decision: Use file path as the universal linking key** + +The key insight is that most providers ultimately point to the same physical files: +- Jellyfin, Navidrome, Lyrion, Emby all index local music directories +- Local file provider scans the same directories +- The file path (relative to the music library root) is the common denominator + +### 3. Database Schema Design + +#### New Tables + +```sql +-- Provider configuration storage +CREATE TABLE provider ( + id SERIAL PRIMARY KEY, + provider_type VARCHAR(50) NOT NULL, -- jellyfin, navidrome, localfiles, etc. 
+ name VARCHAR(255) NOT NULL, -- User-friendly name + config JSONB NOT NULL, -- Provider-specific configuration + enabled BOOLEAN DEFAULT TRUE, + priority INTEGER DEFAULT 0, -- For ordering when same track in multiple providers + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(provider_type, name) +); + +-- Track identity table - links analysis to file paths +CREATE TABLE track ( + id SERIAL PRIMARY KEY, + file_path_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 of normalized relative path + file_path TEXT NOT NULL, -- Original file path for display + file_size BIGINT, -- For change detection + file_modified TIMESTAMP, -- For change detection + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Links provider-specific item_ids to tracks +CREATE TABLE provider_track ( + id SERIAL PRIMARY KEY, + provider_id INTEGER NOT NULL REFERENCES provider(id) ON DELETE CASCADE, + track_id INTEGER NOT NULL REFERENCES track(id) ON DELETE CASCADE, + item_id TEXT NOT NULL, -- Provider's native ID + title TEXT, -- Title from this provider + artist TEXT, -- Artist from this provider + album TEXT, -- Album from this provider + last_synced TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(provider_id, item_id), + UNIQUE(provider_id, track_id) +); +CREATE INDEX idx_provider_track_item_id ON provider_track(item_id); +CREATE INDEX idx_provider_track_track_id ON provider_track(track_id); + +-- Application settings stored in database (for GUI configuration) +CREATE TABLE app_settings ( + key VARCHAR(255) PRIMARY KEY, + value JSONB NOT NULL, + category VARCHAR(100), -- For UI grouping + description TEXT, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); +``` + +#### Modified Tables + +The `score` table remains largely unchanged, but we add a foreign key to `track`: + +```sql +ALTER TABLE score ADD COLUMN track_id INTEGER REFERENCES track(id); +CREATE INDEX idx_score_track_id ON 
score(track_id); +``` + +**Critical**: `item_id` remains the PRIMARY KEY for backward compatibility. The new `track_id` column provides the link to file-based identity. + +### 4. Migration Strategy + +**Phase 1: Schema Extension (Non-breaking)** +1. Add new tables (`provider`, `track`, `provider_track`, `app_settings`) +2. Add `track_id` column to `score` table (nullable initially) +3. Existing code continues to work unchanged + +**Phase 2: Data Migration** +1. Create default provider entry for current `MEDIASERVER_TYPE` +2. For each existing `score` record: + - Look up the track in the current provider + - Extract file path if available (via provider API) + - Create `track` record + - Create `provider_track` mapping + - Update `score.track_id` + +**Phase 3: Multi-Provider Activation** +1. Enable multi-provider mode via configuration +2. New providers can be added via GUI or API +3. When scanning new providers: + - Match tracks by file path + - Reuse existing analysis data + - Create new `provider_track` mappings + +### 5. API Changes + +#### New Endpoints + +``` +POST /api/setup/provider - Add/configure a provider +GET /api/setup/providers - List configured providers +PUT /api/setup/provider/{id} - Update provider config +DELETE /api/setup/provider/{id} - Remove provider + +GET /api/setup/settings - Get all settings +PUT /api/setup/settings - Update settings +GET /api/setup/wizard/status - Get setup wizard state + +POST /api/provider/{id}/sync - Sync tracks from provider +GET /api/provider/{id}/status - Get provider sync status +``` + +#### Modified Endpoints + +Existing endpoints remain unchanged but internally: +- `/api/analyze` - Uses active providers (priority-ordered) +- `/api/similarity` - Returns tracks with provider context +- Playlist creation - Creates in preferred provider(s) + +### 6. 
Provider Interface
+
+All providers must implement:
+
+```python
+class MediaProvider:
+    """Base class for media providers"""
+
+    def get_provider_type(self) -> str:
+        """Return provider type identifier (jellyfin, navidrome, etc.)"""
+
+    def test_connection(self) -> Tuple[bool, str]:
+        """Test if provider is reachable, return (success, message)"""
+
+    def get_all_songs(self) -> List[Dict]:
+        """Return all songs with metadata including file_path if available"""
+
+    def get_tracks_from_album(self, album_id: str) -> List[Dict]:
+        """Return tracks for an album"""
+
+    def download_track(self, temp_dir: str, item: Dict) -> Optional[str]:
+        """Download track to temp directory, return local path"""
+
+    def create_playlist(self, name: str, item_ids: List[str]) -> Optional[str]:
+        """Create playlist, return playlist ID"""
+
+    def get_file_path(self, item: Dict) -> Optional[str]:
+        """Extract file path from item metadata (for track linking)"""
+```
+
+### 7. Local File Provider Specifics
+
+```python
+# Configuration for local file provider
+{
+    "provider_type": "localfiles",
+    "name": "Local Music Library",
+    "config": {
+        "music_directory": "/path/to/music",
+        "supported_formats": ["mp3", "flac", "m4a", "ogg", "wav"],
+        "scan_subdirectories": true,
+        "use_embedded_metadata": true
+    }
+}
+```
+
+Features:
+- Scans directories for audio files
+- Extracts metadata from ID3 tags (MP3), Vorbis comments (FLAC/OGG), etc.
+- Creates item_id from file path hash
+- Supports playlist creation via M3U files
+
+### 8. Docker Compose Simplification
+
+**Before**: 16+ docker-compose files for various scenarios
+**After**: 2 main files + optional components
+
+```
+deployment/
+├── docker-compose-unified.yaml          # CPU version (default)
+├── docker-compose-unified-nvidia.yaml   # GPU/NVIDIA version
+└── docker-compose-extras.yaml           # Optional: pgAdmin, monitoring, etc.
+```
+
+All provider-specific configuration moves to:
+1. `.env` file for initial setup
+2. 
Database `app_settings` table for runtime configuration +3. GUI Setup Wizard for user-friendly configuration + +### 9. Setup Wizard Flow + +``` +1. Welcome Screen + - Detect if first run or existing installation + - Show version and hardware detection (GPU available?) + +2. Hardware Selection + - CPU-only or NVIDIA GPU acceleration + - Validate GPU drivers if selected + +3. Provider Configuration + - List of available providers with descriptions + - Multi-select enabled providers + - For each provider: configuration form + - Connection test for each provider + +4. Music Library Paths + - For local file provider: select directories + - For media servers: auto-detected from provider config + +5. Advanced Settings (collapsible) + - Database settings (show defaults, allow override) + - Analysis settings (CLAP, MuLan options) + - AI provider settings (optional) + +6. Review & Apply + - Summary of all settings + - Apply configuration + - Start initial sync (optional) + +7. Complete + - Link to main dashboard + - Quick start guide +``` + +### 10. Backward Compatibility + +The system maintains full backward compatibility: + +1. **Environment variables** still work: + - `MEDIASERVER_TYPE` creates a default provider on first run + - All `JELLYFIN_*`, `NAVIDROME_*` etc. variables honored + +2. **Existing data preserved**: + - `score` table unchanged except optional `track_id` column + - `embedding`, `clap_embedding` tables unchanged + - All indexes and projections preserved + +3. **Gradual migration**: + - Single-provider mode works exactly as before + - Multi-provider can be enabled via GUI/API + - No forced migration path + +### 11. File Path Normalization + +To ensure consistent file path matching: + +```python +def normalize_file_path(path: str, base_path: str = "") -> str: + """ + Normalize a file path for cross-provider matching. 
+
+    - Convert to POSIX style (forward slashes)
+    - Make relative to music library root
+    - Lowercase (optional, for case-insensitive filesystems)
+    - Remove leading/trailing whitespace
+    """
+    from pathlib import Path
+    from pathlib import PurePosixPath
+
+    # Convert to Path object
+    p = Path(path)
+
+    # Make relative if absolute and base_path provided
+    if base_path and p.is_absolute():
+        try:
+            p = p.relative_to(base_path)
+        except ValueError:
+            pass  # Not relative to base, keep as-is
+
+    # Convert to POSIX style
+    normalized = PurePosixPath(p).as_posix()
+
+    return normalized.strip()
+
+
+def file_path_hash(normalized_path: str) -> str:
+    """Generate SHA-256 hash of normalized file path."""
+    import hashlib
+    return hashlib.sha256(normalized_path.encode('utf-8')).hexdigest()
+```
+
+## Implementation Order
+
+1. Database schema changes (migration-safe)
+2. Local file provider implementation
+3. Provider configuration storage
+4. Multi-provider dispatcher updates
+5. Setup wizard backend
+6. Setup wizard frontend
+7. Docker Compose simplification
+8. 
Documentation updates diff --git a/tasks/analysis.py b/tasks/analysis.py index 9c3901d7..df4871dd 100644 --- a/tasks/analysis.py +++ b/tasks/analysis.py @@ -911,7 +911,15 @@ def get_missing_mulan_track_ids(track_ids): logger.info(f" - Top Moods: {top_moods}") logger.info(f" - Other Features: {other_features}") - save_track_analysis_and_embedding(item['Id'], item['Name'], item.get('AlbumArtist', 'Unknown'), analysis['tempo'], analysis['key'], analysis['scale'], top_moods, embedding, energy=analysis['energy'], other_features=other_features, album=item.get('Album', None)) + save_track_analysis_and_embedding( + item['Id'], item['Name'], item.get('AlbumArtist', 'Unknown'), + analysis['tempo'], analysis['key'], analysis['scale'], + top_moods, embedding, + energy=analysis['energy'], + other_features=other_features, + album=item.get('Album', None), + file_path=item.get('Path') # For multi-provider track linking + ) track_processed = True # Increment session recycler counter after successful analysis diff --git a/tasks/mediaserver.py b/tasks/mediaserver.py index ae47dad0..ce6c3752 100644 --- a/tasks/mediaserver.py +++ b/tasks/mediaserver.py @@ -1,4 +1,24 @@ # tasks/mediaserver.py +""" +Media Server Dispatcher for AudioMuse-AI + +This module provides a unified interface to multiple media server providers. +It dispatches function calls to the appropriate provider implementation based +on the configured MEDIASERVER_TYPE. + +Supported providers: +- jellyfin: Jellyfin Media Server +- navidrome: Navidrome (Subsonic API) +- lyrion: Lyrion Music Server (formerly LMS) +- mpd: Music Player Daemon +- emby: Emby Media Server +- localfiles: Local file system scanner + +Multi-provider support: +When multi_provider_enabled is true in app_settings, multiple providers can be +configured and used simultaneously. Tracks are linked via file paths, allowing +analysis data to be shared across providers. 
+""" import logging import os @@ -73,10 +93,74 @@ get_top_played_songs as emby_get_top_played_songs, get_last_played_time as emby_get_last_played_time, ) +from tasks.mediaserver_localfiles import ( + get_all_playlists as localfiles_get_all_playlists, + delete_playlist as localfiles_delete_playlist, + get_recent_albums as localfiles_get_recent_albums, + get_tracks_from_album as localfiles_get_tracks_from_album, + download_track as localfiles_download_track, + get_all_songs as localfiles_get_all_songs, + get_playlist_by_name as localfiles_get_playlist_by_name, + create_playlist as localfiles_create_playlist, + create_instant_playlist as localfiles_create_instant_playlist, + get_top_played_songs as localfiles_get_top_played_songs, + get_last_played_time as localfiles_get_last_played_time, + test_connection as localfiles_test_connection, + get_provider_info as localfiles_get_provider_info, +) logger = logging.getLogger(__name__) +# ############################################################################## +# PROVIDER REGISTRY +# ############################################################################## + +PROVIDER_TYPES = { + 'jellyfin': { + 'name': 'Jellyfin', + 'description': 'Jellyfin Media Server - Open source media solution', + 'supports_user_auth': True, + 'supports_play_history': True, + }, + 'navidrome': { + 'name': 'Navidrome', + 'description': 'Navidrome - Modern music server (Subsonic API)', + 'supports_user_auth': True, + 'supports_play_history': True, + }, + 'lyrion': { + 'name': 'Lyrion', + 'description': 'Lyrion Music Server (formerly Logitech Media Server)', + 'supports_user_auth': False, + 'supports_play_history': True, + }, + 'mpd': { + 'name': 'MPD', + 'description': 'Music Player Daemon - Flexible music server', + 'supports_user_auth': False, + 'supports_play_history': False, + }, + 'emby': { + 'name': 'Emby', + 'description': 'Emby Media Server - Personal media server', + 'supports_user_auth': True, + 'supports_play_history': True, + }, + 
'localfiles': { + 'name': 'Local Files', + 'description': 'Scan local directories for audio files', + 'supports_user_auth': False, + 'supports_play_history': False, + }, +} + + +def get_available_provider_types(): + """Return information about all available provider types.""" + return PROVIDER_TYPES.copy() + + # ############################################################################## # PUBLIC API (Dispatcher functions) # ############################################################################## @@ -128,6 +212,7 @@ def get_recent_albums(limit): if config.MEDIASERVER_TYPE == 'lyrion': return lyrion_get_recent_albums(limit) if config.MEDIASERVER_TYPE == 'mpd': return mpd_get_recent_albums(limit) if config.MEDIASERVER_TYPE == 'emby': return emby_get_recent_albums(limit) + if config.MEDIASERVER_TYPE == 'localfiles': return localfiles_get_recent_albums(limit) return [] def get_recent_music_items(limit): @@ -156,6 +241,7 @@ def get_tracks_from_album(album_id): if config.MEDIASERVER_TYPE == 'lyrion': return lyrion_get_tracks_from_album(album_id) if config.MEDIASERVER_TYPE == 'mpd': return mpd_get_tracks_from_album(album_id) if config.MEDIASERVER_TYPE == 'emby': return emby_get_tracks_from_album(album_id) + if config.MEDIASERVER_TYPE == 'localfiles': return localfiles_get_tracks_from_album(album_id) return [] def download_track(temp_dir, item): @@ -167,6 +253,7 @@ def download_track(temp_dir, item): elif config.MEDIASERVER_TYPE == 'lyrion': downloaded_path = lyrion_download_track(temp_dir, item) elif config.MEDIASERVER_TYPE == 'mpd': downloaded_path = mpd_download_track(temp_dir, item) elif config.MEDIASERVER_TYPE == 'emby': downloaded_path = emby_download_track(temp_dir, item) + elif config.MEDIASERVER_TYPE == 'localfiles': downloaded_path = localfiles_download_track(temp_dir, item) # If download failed or returned None, return as is if not downloaded_path: @@ -244,6 +331,7 @@ def get_all_songs(): if config.MEDIASERVER_TYPE == 'lyrion': return 
lyrion_get_all_songs() if config.MEDIASERVER_TYPE == 'mpd': return mpd_get_all_songs() if config.MEDIASERVER_TYPE == 'emby': return emby_get_all_songs() + if config.MEDIASERVER_TYPE == 'localfiles': return localfiles_get_all_songs() return [] def get_playlist_by_name(playlist_name): @@ -254,6 +342,7 @@ def get_playlist_by_name(playlist_name): if config.MEDIASERVER_TYPE == 'lyrion': return lyrion_get_playlist_by_name(playlist_name) if config.MEDIASERVER_TYPE == 'mpd': return mpd_get_playlist_by_name(playlist_name) if config.MEDIASERVER_TYPE == 'emby': return emby_get_playlist_by_name(playlist_name) + if config.MEDIASERVER_TYPE == 'localfiles': return localfiles_get_playlist_by_name(playlist_name) return None def create_playlist(base_name, item_ids): @@ -265,12 +354,13 @@ def create_playlist(base_name, item_ids): elif config.MEDIASERVER_TYPE == 'lyrion': lyrion_create_playlist(base_name, item_ids) elif config.MEDIASERVER_TYPE == 'mpd': mpd_create_playlist(base_name, item_ids) elif config.MEDIASERVER_TYPE == 'emby': emby_create_playlist(base_name, item_ids) + elif config.MEDIASERVER_TYPE == 'localfiles': localfiles_create_playlist(base_name, item_ids) def create_instant_playlist(playlist_name, item_ids, user_creds=None): """Creates an instant playlist. 
Uses user_creds if provided, otherwise admin.""" if not playlist_name: raise ValueError("Playlist name is required.") if not item_ids: raise ValueError("Track IDs are required.") - + if config.MEDIASERVER_TYPE == 'jellyfin': return jellyfin_create_instant_playlist(playlist_name, item_ids, user_creds) if config.MEDIASERVER_TYPE == 'navidrome': @@ -281,6 +371,8 @@ def create_instant_playlist(playlist_name, item_ids, user_creds=None): return mpd_create_instant_playlist(playlist_name, item_ids, user_creds) if config.MEDIASERVER_TYPE == 'emby': return emby_create_instant_playlist(playlist_name, item_ids, user_creds) + if config.MEDIASERVER_TYPE == 'localfiles': + return localfiles_create_instant_playlist(playlist_name, item_ids, user_creds) return None def get_top_played_songs(limit, user_creds=None): @@ -295,6 +387,8 @@ def get_top_played_songs(limit, user_creds=None): return mpd_get_top_played_songs(limit, user_creds) if config.MEDIASERVER_TYPE == 'emby': return emby_get_top_played_songs(limit, user_creds) + if config.MEDIASERVER_TYPE == 'localfiles': + return localfiles_get_top_played_songs(limit, user_creds) return [] def get_last_played_time(item_id, user_creds=None): @@ -309,5 +403,166 @@ def get_last_played_time(item_id, user_creds=None): return mpd_get_last_played_time(item_id, user_creds) if config.MEDIASERVER_TYPE == 'emby': return emby_get_last_played_time(item_id, user_creds) + if config.MEDIASERVER_TYPE == 'localfiles': + return localfiles_get_last_played_time(item_id, user_creds) return None + +# ############################################################################## +# MULTI-PROVIDER SUPPORT FUNCTIONS +# ############################################################################## + +def test_provider_connection(provider_type: str, config_dict: dict = None): + """ + Test connection to a specific provider. + + Args: + provider_type: Type of provider (jellyfin, navidrome, localfiles, etc.) 
+ config_dict: Optional configuration dictionary for the provider + + Returns: + Tuple of (success: bool, message: str) + """ + import requests + + try: + if provider_type == 'localfiles': + return localfiles_test_connection(config_dict) + + elif provider_type == 'jellyfin': + url = config_dict.get('url') if config_dict else config.JELLYFIN_URL + token = config_dict.get('token') if config_dict else config.JELLYFIN_TOKEN + if not url or not token: + return False, "Jellyfin URL and token are required" + resp = requests.get(f"{url.rstrip('/')}/System/Info", + headers={"X-Emby-Token": token}, timeout=10) + if resp.status_code == 200: + return True, f"Connected to Jellyfin at {url}" + return False, f"Jellyfin returned status {resp.status_code}" + + elif provider_type == 'navidrome': + import hashlib + import secrets + url = config_dict.get('url') if config_dict else config.NAVIDROME_URL + user = config_dict.get('user') if config_dict else config.NAVIDROME_USER + password = config_dict.get('password') if config_dict else config.NAVIDROME_PASSWORD + if not url or not user or not password: + return False, "Navidrome URL, user, and password are required" + salt = secrets.token_hex(8) + token = hashlib.md5((password + salt).encode()).hexdigest() + params = {'u': user, 't': token, 's': salt, 'v': '1.16.1', 'c': 'audiomuse', 'f': 'json'} + resp = requests.get(f"{url.rstrip('/')}/rest/ping", params=params, timeout=10) + if resp.status_code == 200: + data = resp.json() + if data.get('subsonic-response', {}).get('status') == 'ok': + return True, f"Connected to Navidrome at {url}" + err = data.get('subsonic-response', {}).get('error', {}).get('message', 'Unknown error') + return False, f"Navidrome error: {err}" + return False, f"Navidrome returned status {resp.status_code}" + + elif provider_type == 'lyrion': + url = config_dict.get('url') if config_dict else config.LYRION_URL + if not url: + return False, "Lyrion URL is required" + resp = 
requests.get(f"{url.rstrip('/')}/status.html", timeout=10) + if resp.status_code == 200: + return True, f"Connected to Lyrion at {url}" + return False, f"Lyrion returned status {resp.status_code}" + + elif provider_type == 'emby': + url = config_dict.get('url') if config_dict else config.EMBY_URL + token = config_dict.get('token') if config_dict else config.EMBY_TOKEN + if not url or not token: + return False, "Emby URL and token are required" + resp = requests.get(f"{url.rstrip('/')}/System/Info", + headers={"X-Emby-Token": token}, timeout=10) + if resp.status_code == 200: + return True, f"Connected to Emby at {url}" + return False, f"Emby returned status {resp.status_code}" + + elif provider_type == 'mpd': + try: + from mpd import MPDClient + host = config_dict.get('host') if config_dict else config.MPD_HOST + port = config_dict.get('port') if config_dict else config.MPD_PORT + password = config_dict.get('password') if config_dict else config.MPD_PASSWORD + client = MPDClient() + client.timeout = 10 + client.connect(host, int(port)) + if password: + client.password(password) + stats = client.stats() + client.close() + client.disconnect() + return True, f"Connected to MPD at {host}:{port} ({stats.get('songs', 0)} songs)" + except Exception as e: + return False, f"MPD connection error: {str(e)}" + + else: + return False, f"Unknown provider type: {provider_type}" + + except requests.RequestException as e: + return False, f"Network error: {str(e)}" + except Exception as e: + return False, f"Connection test failed: {str(e)}" + + +def get_provider_info(provider_type: str): + """Get detailed information about a provider type including config fields.""" + if provider_type == 'localfiles': + return localfiles_get_provider_info() + + # Return basic info for other providers + if provider_type in PROVIDER_TYPES: + info = PROVIDER_TYPES[provider_type].copy() + info['type'] = provider_type + info['config_fields'] = _get_provider_config_fields(provider_type) + return info + + 
return None + + +def _get_provider_config_fields(provider_type: str): + """Get configuration fields for a provider type.""" + fields = { + 'jellyfin': [ + {'name': 'url', 'label': 'Server URL', 'type': 'url', 'required': True, + 'description': 'Jellyfin server URL (e.g., http://192.168.1.100:8096)'}, + {'name': 'user_id', 'label': 'User ID', 'type': 'text', 'required': True, + 'description': 'Jellyfin user ID (found in dashboard)'}, + {'name': 'token', 'label': 'API Token', 'type': 'password', 'required': True, + 'description': 'API key from Jellyfin settings'}, + ], + 'navidrome': [ + {'name': 'url', 'label': 'Server URL', 'type': 'url', 'required': True, + 'description': 'Navidrome server URL (e.g., http://192.168.1.100:4533)'}, + {'name': 'user', 'label': 'Username', 'type': 'text', 'required': True, + 'description': 'Navidrome username'}, + {'name': 'password', 'label': 'Password', 'type': 'password', 'required': True, + 'description': 'Navidrome password'}, + ], + 'lyrion': [ + {'name': 'url', 'label': 'Server URL', 'type': 'url', 'required': True, + 'description': 'Lyrion server URL (e.g., http://192.168.1.100:9000)'}, + ], + 'mpd': [ + {'name': 'host', 'label': 'Host', 'type': 'text', 'required': True, + 'description': 'MPD server hostname or IP', 'default': 'localhost'}, + {'name': 'port', 'label': 'Port', 'type': 'number', 'required': True, + 'description': 'MPD port number', 'default': 6600}, + {'name': 'password', 'label': 'Password', 'type': 'password', 'required': False, + 'description': 'MPD password (if configured)'}, + {'name': 'music_directory', 'label': 'Music Directory', 'type': 'path', 'required': True, + 'description': 'Path to music files on the MPD server'}, + ], + 'emby': [ + {'name': 'url', 'label': 'Server URL', 'type': 'url', 'required': True, + 'description': 'Emby server URL (e.g., http://192.168.1.100:8096)'}, + {'name': 'user_id', 'label': 'User ID', 'type': 'text', 'required': True, + 'description': 'Emby user ID'}, + {'name': 
'token', 'label': 'API Token', 'type': 'password', 'required': True, + 'description': 'API key from Emby settings'}, + ], + } + return fields.get(provider_type, []) + diff --git a/tasks/mediaserver_localfiles.py b/tasks/mediaserver_localfiles.py new file mode 100644 index 00000000..df02dca1 --- /dev/null +++ b/tasks/mediaserver_localfiles.py @@ -0,0 +1,604 @@ +# tasks/mediaserver_localfiles.py +""" +Local File Media Provider for AudioMuse-AI + +This provider scans local directories for audio files and extracts metadata +from embedded tags (ID3 for MP3, Vorbis comments for FLAC/OGG, etc.). + +The item_id for each track is a SHA-256 hash of the normalized relative file path, +ensuring stable, predictable identifiers that won't change unless files move. +""" + +import logging +import os +import hashlib +import shutil +from datetime import datetime +from pathlib import Path, PurePosixPath +from typing import List, Dict, Optional, Tuple +import json + +try: + from mutagen import File as MutagenFile + from mutagen.mp3 import MP3 + from mutagen.flac import FLAC + from mutagen.oggvorbis import OggVorbis + from mutagen.mp4 import MP4 + from mutagen.id3 import ID3 + MUTAGEN_AVAILABLE = True +except ImportError: + MUTAGEN_AVAILABLE = False + +import config + +logger = logging.getLogger(__name__) + +# Supported audio formats +SUPPORTED_FORMATS = {'.mp3', '.flac', '.ogg', '.m4a', '.mp4', '.wav', '.wma', '.aac', '.opus'} + +# ############################################################################## +# CONFIGURATION +# ############################################################################## + +def get_config() -> Dict: + """Get local file provider configuration from environment or defaults.""" + return { + 'music_directory': os.environ.get('LOCALFILES_MUSIC_DIRECTORY', '/music'), + 'supported_formats': os.environ.get('LOCALFILES_FORMATS', ','.join(SUPPORTED_FORMATS)).split(','), + 'scan_subdirectories': os.environ.get('LOCALFILES_SCAN_SUBDIRS', 'true').lower() == 
'true', + 'use_embedded_metadata': os.environ.get('LOCALFILES_USE_METADATA', 'true').lower() == 'true', + 'playlist_directory': os.environ.get('LOCALFILES_PLAYLIST_DIR', '/music/playlists'), + } + + +# ############################################################################## +# UTILITY FUNCTIONS +# ############################################################################## + +def normalize_file_path(path: str, base_path: str = "") -> str: + """ + Normalize a file path for cross-provider matching. + + - Convert to POSIX style (forward slashes) + - Make relative to music library root + - Strip leading/trailing whitespace + """ + p = Path(path) + + # Make relative if absolute and base_path provided + if base_path and p.is_absolute(): + try: + base = Path(base_path) + p = p.relative_to(base) + except ValueError: + pass # Not relative to base, keep as-is + + # Convert to POSIX style + normalized = PurePosixPath(p).as_posix() + + return normalized.strip() + + +def file_path_hash(normalized_path: str) -> str: + """Generate SHA-256 hash of normalized file path for use as item_id.""" + return hashlib.sha256(normalized_path.encode('utf-8')).hexdigest() + + +def extract_metadata(file_path: str) -> Dict: + """ + Extract metadata from an audio file using mutagen. 
+ + Returns a dict with keys: title, artist, album, album_artist, track_number, year, genre + """ + metadata = { + 'title': os.path.splitext(os.path.basename(file_path))[0], # Default to filename + 'artist': 'Unknown Artist', + 'album': 'Unknown Album', + 'album_artist': None, + 'track_number': None, + 'year': None, + 'genre': None, + 'duration': None, + } + + if not MUTAGEN_AVAILABLE: + logger.warning("Mutagen not available, using filename as title") + return metadata + + try: + audio = MutagenFile(file_path, easy=True) + if audio is None: + logger.debug(f"Mutagen couldn't read: {file_path}") + return metadata + + # Extract common tags (easy=True gives us simplified tag access) + if hasattr(audio, 'info') and audio.info: + metadata['duration'] = getattr(audio.info, 'length', None) + + # Handle different tag formats + if isinstance(audio.tags, dict) or hasattr(audio, 'tags'): + tags = audio.tags if isinstance(audio.tags, dict) else dict(audio) + + # Title + if 'title' in tags: + val = tags['title'] + metadata['title'] = val[0] if isinstance(val, list) else str(val) + + # Artist + if 'artist' in tags: + val = tags['artist'] + metadata['artist'] = val[0] if isinstance(val, list) else str(val) + elif 'performer' in tags: + val = tags['performer'] + metadata['artist'] = val[0] if isinstance(val, list) else str(val) + + # Album + if 'album' in tags: + val = tags['album'] + metadata['album'] = val[0] if isinstance(val, list) else str(val) + + # Album Artist + if 'albumartist' in tags: + val = tags['albumartist'] + metadata['album_artist'] = val[0] if isinstance(val, list) else str(val) + elif 'album artist' in tags: + val = tags['album artist'] + metadata['album_artist'] = val[0] if isinstance(val, list) else str(val) + + # Track number + if 'tracknumber' in tags: + val = tags['tracknumber'] + track_str = val[0] if isinstance(val, list) else str(val) + try: + # Handle "1/12" format + metadata['track_number'] = int(track_str.split('/')[0]) + except (ValueError, 
IndexError): + pass + + # Year/Date + if 'date' in tags: + val = tags['date'] + date_str = val[0] if isinstance(val, list) else str(val) + try: + metadata['year'] = int(date_str[:4]) + except (ValueError, IndexError): + pass + elif 'year' in tags: + val = tags['year'] + year_str = val[0] if isinstance(val, list) else str(val) + try: + metadata['year'] = int(year_str) + except ValueError: + pass + + # Genre + if 'genre' in tags: + val = tags['genre'] + metadata['genre'] = val[0] if isinstance(val, list) else str(val) + + except Exception as e: + logger.warning(f"Error extracting metadata from {file_path}: {e}") + + return metadata + + +def _format_song(file_path: str, base_path: str) -> Dict: + """Format a local file into the standard song format used by AudioMuse-AI.""" + normalized_path = normalize_file_path(file_path, base_path) + item_id = file_path_hash(normalized_path) + + metadata = extract_metadata(file_path) + + # Get file stats + try: + stat = os.stat(file_path) + file_size = stat.st_size + file_modified = datetime.fromtimestamp(stat.st_mtime) + except OSError: + file_size = None + file_modified = None + + return { + 'Id': item_id, + 'Name': metadata['title'], + 'Artist': metadata['artist'], + 'AlbumArtist': metadata['album_artist'] or metadata['artist'], + 'Album': metadata['album'], + 'Path': file_path, + 'RelativePath': normalized_path, + 'TrackNumber': metadata['track_number'], + 'Year': metadata['year'], + 'Genre': metadata['genre'], + 'Duration': metadata['duration'], + 'FileSize': file_size, + 'last-modified': file_modified.isoformat() if file_modified else None, + # For compatibility with other providers + 'ArtistId': None, # Local files don't have artist IDs + } + + +# ############################################################################## +# PUBLIC API +# ############################################################################## + +def test_connection(config_override: Dict = None) -> Tuple[bool, str]: + """Test if the local file 
provider can access the music directory. + + Args: + config_override: Optional dict with configuration to test instead of default + """ + if config_override: + cfg = { + 'music_directory': config_override.get('music_directory', '/music'), + 'supported_formats': config_override.get('supported_formats', SUPPORTED_FORMATS), + 'scan_subdirectories': config_override.get('scan_subdirectories', True), + 'playlist_directory': config_override.get('playlist_directory', '/music/playlists'), + } + else: + cfg = get_config() + music_dir = cfg['music_directory'] + + if not os.path.exists(music_dir): + return False, f"Music directory does not exist: {music_dir}" + + if not os.path.isdir(music_dir): + return False, f"Music path is not a directory: {music_dir}" + + if not os.access(music_dir, os.R_OK): + return False, f"Music directory is not readable: {music_dir}" + + # Count files to verify + try: + audio_count = 0 + for root, _, files in os.walk(music_dir): + for f in files: + if os.path.splitext(f)[1].lower() in SUPPORTED_FORMATS: + audio_count += 1 + if audio_count >= 10: # Quick check, don't scan everything + break + if audio_count >= 10: + break + + if audio_count == 0: + return False, f"No audio files found in: {music_dir}" + + return True, f"Found audio files in: {music_dir}" + except Exception as e: + return False, f"Error scanning music directory: {e}" + + +def get_all_songs() -> List[Dict]: + """Fetch all audio files from the music directory.""" + cfg = get_config() + music_dir = cfg['music_directory'] + supported = set(fmt.lower() if fmt.startswith('.') else f'.{fmt.lower()}' + for fmt in cfg['supported_formats']) + scan_subdirs = cfg['scan_subdirectories'] + + all_songs = [] + + if not os.path.isdir(music_dir): + logger.error(f"Music directory not found: {music_dir}") + return [] + + logger.info(f"Scanning local music directory: {music_dir}") + + try: + if scan_subdirs: + for root, _, files in os.walk(music_dir): + for filename in files: + ext = 
os.path.splitext(filename)[1].lower() + if ext in supported: + full_path = os.path.join(root, filename) + try: + song = _format_song(full_path, music_dir) + all_songs.append(song) + except Exception as e: + logger.warning(f"Error processing {full_path}: {e}") + else: + for filename in os.listdir(music_dir): + ext = os.path.splitext(filename)[1].lower() + if ext in supported: + full_path = os.path.join(music_dir, filename) + if os.path.isfile(full_path): + try: + song = _format_song(full_path, music_dir) + all_songs.append(song) + except Exception as e: + logger.warning(f"Error processing {full_path}: {e}") + + logger.info(f"Found {len(all_songs)} audio files in local library") + + except Exception as e: + logger.error(f"Error scanning music directory: {e}", exc_info=True) + + return all_songs + + +def get_recent_albums(limit: int) -> List[Dict]: + """ + Get recently modified albums from the local music directory. + + For local files, we group songs by album and return the most recently + modified albums based on the newest file in each album. 
+ """ + cfg = get_config() + music_dir = cfg['music_directory'] + + all_songs = get_all_songs() + if not all_songs: + return [] + + # Group by album + albums = {} + for song in all_songs: + album_name = song.get('Album', 'Unknown Album') + album_artist = song.get('AlbumArtist', 'Unknown Artist') + album_key = f"{album_artist} - {album_name}" + + if album_key not in albums: + albums[album_key] = { + 'Id': album_key, # Use album name as ID + 'Name': album_name, + 'Artist': album_artist, + 'tracks': [], + 'last_modified': None + } + + albums[album_key]['tracks'].append(song) + + # Track the most recent modification time + mod_time = song.get('last-modified') + if mod_time: + if albums[album_key]['last_modified'] is None or mod_time > albums[album_key]['last_modified']: + albums[album_key]['last_modified'] = mod_time + + # Sort by modification time (most recent first) + sorted_albums = sorted( + albums.values(), + key=lambda a: a.get('last_modified') or '', + reverse=True + ) + + # Return requested limit (0 = all) + if limit == 0: + return sorted_albums + return sorted_albums[:limit] + + +def get_tracks_from_album(album_id: str) -> List[Dict]: + """ + Get all tracks from an album. + + For local files, album_id is "Artist - Album Name" format. + """ + all_songs = get_all_songs() + + # Filter songs matching this album + tracks = [] + for song in all_songs: + album_name = song.get('Album', 'Unknown Album') + album_artist = song.get('AlbumArtist', 'Unknown Artist') + song_album_key = f"{album_artist} - {album_name}" + + if song_album_key == album_id or album_name == album_id: + tracks.append(song) + + # Sort by track number if available + tracks.sort(key=lambda t: (t.get('TrackNumber') or 999, t.get('Name', ''))) + + logger.info(f"Found {len(tracks)} tracks for album '{album_id}'") + return tracks + + +def download_track(temp_dir: str, item: Dict) -> Optional[str]: + """ + 'Download' a track - for local files, we simply copy to temp directory. 
+ + Returns the path to the temporary file. + """ + source_path = item.get('Path') + if not source_path or not os.path.exists(source_path): + logger.error(f"Source file not found: {source_path}") + return None + + try: + # Create a unique filename in temp directory + filename = os.path.basename(source_path) + dest_path = os.path.join(temp_dir, filename) + + # Handle filename collisions + if os.path.exists(dest_path): + name, ext = os.path.splitext(filename) + item_id = item.get('Id', '')[:8] + dest_path = os.path.join(temp_dir, f"{name}_{item_id}{ext}") + + # Copy file to temp directory + shutil.copy2(source_path, dest_path) + logger.info(f"Copied '{item.get('Name', filename)}' to temp directory") + + return dest_path + + except Exception as e: + logger.error(f"Error copying file {source_path}: {e}", exc_info=True) + return None + + +def get_all_playlists() -> List[Dict]: + """Get all M3U playlists from the playlist directory.""" + cfg = get_config() + playlist_dir = cfg['playlist_directory'] + + playlists = [] + + if not os.path.isdir(playlist_dir): + logger.info(f"Playlist directory not found: {playlist_dir}") + return playlists + + try: + for filename in os.listdir(playlist_dir): + if filename.lower().endswith(('.m3u', '.m3u8')): + name = os.path.splitext(filename)[0] + playlists.append({ + 'Id': filename, + 'Name': name, + 'Path': os.path.join(playlist_dir, filename) + }) + except Exception as e: + logger.error(f"Error listing playlists: {e}") + + return playlists + + +def get_playlist_by_name(playlist_name: str) -> Optional[Dict]: + """Find a playlist by name.""" + playlists = get_all_playlists() + for p in playlists: + if p['Name'] == playlist_name: + return p + return None + + +def create_playlist(base_name: str, item_ids: List[str]) -> Optional[str]: + """ + Create an M3U playlist file. + + item_ids are the file path hashes - we need to look up the actual paths. 
+ """ + cfg = get_config() + playlist_dir = cfg['playlist_directory'] + music_dir = cfg['music_directory'] + + # Ensure playlist directory exists + os.makedirs(playlist_dir, exist_ok=True) + + # Build a lookup from item_id to file path + all_songs = get_all_songs() + id_to_path = {song['Id']: song['Path'] for song in all_songs} + + # Resolve paths + paths = [] + for item_id in item_ids: + if item_id in id_to_path: + # Use relative path for portability + full_path = id_to_path[item_id] + try: + rel_path = os.path.relpath(full_path, playlist_dir) + except ValueError: + rel_path = full_path # Different drive on Windows + paths.append(rel_path) + else: + logger.warning(f"Track not found for item_id: {item_id}") + + if not paths: + logger.error("No valid tracks found for playlist") + return None + + # Write M3U file + playlist_name = f"{base_name}_automatic.m3u" + playlist_path = os.path.join(playlist_dir, playlist_name) + + try: + with open(playlist_path, 'w', encoding='utf-8') as f: + f.write("#EXTM3U\n") + for path in paths: + f.write(f"{path}\n") + + logger.info(f"Created playlist '{playlist_name}' with {len(paths)} tracks") + return playlist_name + + except Exception as e: + logger.error(f"Error creating playlist: {e}", exc_info=True) + return None + + +def delete_playlist(playlist_id: str) -> bool: + """Delete an M3U playlist file.""" + cfg = get_config() + playlist_dir = cfg['playlist_directory'] + + playlist_path = os.path.join(playlist_dir, playlist_id) + + if not os.path.exists(playlist_path): + logger.warning(f"Playlist file not found: {playlist_path}") + return False + + try: + os.remove(playlist_path) + logger.info(f"Deleted playlist: {playlist_id}") + return True + except Exception as e: + logger.error(f"Error deleting playlist: {e}") + return False + + +def create_instant_playlist(playlist_name: str, item_ids: List[str], user_creds=None) -> Optional[Dict]: + """Create an instant playlist (same as regular playlist for local files).""" + final_name = 
f"{playlist_name.strip()}_instant" + result = create_playlist(final_name, item_ids) + if result: + return {'Id': result, 'Name': final_name} + return None + + +def get_top_played_songs(limit: int, user_creds=None) -> List[Dict]: + """Not supported for local files - no play history tracking.""" + logger.warning("get_top_played_songs is not supported for local files provider") + return [] + + +def get_last_played_time(item_id: str, user_creds=None): + """Not supported for local files - no play history tracking.""" + logger.warning("get_last_played_time is not supported for local files provider") + return None + + +# ############################################################################## +# PROVIDER INFO +# ############################################################################## + +def get_provider_info() -> Dict: + """Return information about this provider.""" + cfg = get_config() + return { + 'type': 'localfiles', + 'name': 'Local Files', + 'description': 'Scan local directories for audio files', + 'supports_playlists': True, + 'supports_play_history': False, + 'supports_user_auth': False, + 'config_fields': [ + { + 'name': 'music_directory', + 'label': 'Music Directory', + 'type': 'path', + 'required': True, + 'description': 'Path to your music library folder', + 'default': '/music' + }, + { + 'name': 'supported_formats', + 'label': 'Supported Formats', + 'type': 'text', + 'required': False, + 'description': 'Comma-separated list of audio file extensions', + 'default': ','.join(SUPPORTED_FORMATS) + }, + { + 'name': 'scan_subdirectories', + 'label': 'Scan Subdirectories', + 'type': 'boolean', + 'required': False, + 'description': 'Include files in subdirectories', + 'default': True + }, + { + 'name': 'playlist_directory', + 'label': 'Playlist Directory', + 'type': 'path', + 'required': False, + 'description': 'Where to save generated M3U playlists', + 'default': '/music/playlists' + } + ] + } diff --git a/templates/setup.html b/templates/setup.html new 
file mode 100644 index 00000000..08f65f23 --- /dev/null +++ b/templates/setup.html @@ -0,0 +1,1028 @@ +{% extends "includes/layout.html" %} + +{% block headAdditions %} + +{% endblock %} + +{% block content %} +
+
+

AudioMuse-AI Setup

+

Configure your music analysis system

+
+ + +
+
+
1
+ Welcome +
+
+
2
+ Providers +
+
+
3
+ Settings +
+
+
4
+ Complete +
+
+ + +
+
+

Welcome to AudioMuse-AI

+

+ This setup wizard will help you configure AudioMuse-AI to analyze your music library + and create intelligent playlists. Let's get started! +

+ + +
+ +
+

Select Hardware Configuration

+

+ Choose your hardware setup for music analysis. GPU acceleration significantly speeds up + the ML analysis process. +

+ +
+
+
💻
+
CPU Only
+
Works on any system. Slower analysis but no special requirements.
+
+
+
+
NVIDIA GPU
+
Faster analysis with CUDA acceleration. Requires NVIDIA GPU with drivers.
+
+
+
+ +
+
+ +
+
+ + +
+
+

Configure Media Providers

+

+ Select one or more media providers. You can add multiple providers to analyze music + from different sources. Tracks are linked by file path, so the same song in different + providers will share analysis data. +

+ +
+ +
+
+ + + + + +
+ + +
+
+ + +
+
+

Analysis Settings

+

+ Configure optional analysis features. The defaults work well for most users. +

+ +
+ +
+ Allows searching your music using natural language queries like "upbeat summer songs". + Uses additional memory (~750MB). +
+
+ +
+ +
+ Accelerate clustering with NVIDIA RAPIDS cuML. Only available with NVIDIA GPU. +
+
+
+ +
+ + Advanced Settings +
+ +
+
+

Database Configuration

+

+ These settings are typically configured via environment variables in Docker. + Only change if you know what you're doing. +

+ +
+ + +
Database host (set via POSTGRES_HOST env var)
+
+ +
+ + +
Redis connection URL (set via REDIS_URL env var)
+
+
+ +
+

AI Playlist Naming

+

+ Optionally configure an AI provider for creative playlist names. +

+ +
+ + +
AI service for generating creative playlist names
+
+
+
+ +
+ + +
+
+ + +
+
+

Setup Summary

+

+ Review your configuration before completing the setup. +

+ +
+ +
+
+ +
+

What's Next?

+
    +
  • Run the analysis to scan and analyze your music library
  • +
  • Generate intelligent playlists based on audio fingerprints
  • +
  • Explore your music with similarity search and visualization tools
  • +
+
+ +
+ + +
+
+
+{% endblock %} + +{% block bodyAdditions %} + + +{% endblock %} diff --git a/templates/sidebar_navi.html b/templates/sidebar_navi.html index d8e98bb8..1f8932b9 100644 --- a/templates/sidebar_navi.html +++ b/templates/sidebar_navi.html @@ -12,4 +12,5 @@
  • Cleaning
  • Scheduled Tasks
  • +
  • Setup & Providers
  • \ No newline at end of file From 79c89708439e67fe72a542bd02a5bc08b7744b35 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 4 Feb 2026 04:27:53 +0000 Subject: [PATCH 02/33] Add provider-aware API helpers with backward compatibility Implements provider fallback logic for API calls: - When provider_id specified: look up in that provider, fall back to score - When provider_id NOT specified (backward compatible): 1. Try primary provider first 2. Try other enabled providers by priority 3. Fall back to direct score table (legacy mode) New helper functions: - get_track_by_item_id(item_id, provider_id=None) - get_tracks_by_item_ids(item_ids, provider_id=None) - get_primary_provider_id() - get_enabled_provider_ids() - resolve_item_id_to_provider(item_id) - get_item_id_for_provider(file_path_or_track_id, provider_id) - is_multi_provider_mode() - set_primary_provider(provider_id) https://claude.ai/code/session_011AebTWAucDafK4m6uoSSNg --- app_helper.py | 248 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 246 insertions(+), 2 deletions(-) diff --git a/app_helper.py b/app_helper.py index 25a47f90..24f56f20 100644 --- a/app_helper.py +++ b/app_helper.py @@ -1076,5 +1076,249 @@ def cancel_job_and_children_recursive(job_id, task_type_from_db=None, reason="Ta if child_db_info and child_db_info.get('status') not in [TASK_STATUS_SUCCESS, TASK_STATUS_FAILURE, TASK_STATUS_REVOKED]: logger.info(f"Recursively cancelling child job: {child_job_id}") cancelled_count += cancel_job_and_children_recursive(child_job_id, reason="Cancelled due to parent task revocation.") - - return cancelled_count \ No newline at end of file + + return cancelled_count + + +# ############################################################################## +# MULTI-PROVIDER HELPER FUNCTIONS +# ############################################################################## + +def get_primary_provider_id(): + """Get the primary provider ID from app_settings.""" + db = get_db() + with db.cursor() 
as cur: + cur.execute("SELECT value FROM app_settings WHERE key = 'primary_provider_id'") + row = cur.fetchone() + if row and row[0] is not None: + try: + # Value is stored as JSONB, could be int or null + val = row[0] + if isinstance(val, int): + return val + if val is None or val == 'null': + return None + return int(val) + except (ValueError, TypeError): + return None + return None + + +def get_enabled_provider_ids(): + """Get list of enabled provider IDs ordered by priority (highest first).""" + db = get_db() + with db.cursor() as cur: + cur.execute(""" + SELECT id FROM provider + WHERE enabled = TRUE + ORDER BY priority DESC, created_at ASC + """) + return [row[0] for row in cur.fetchall()] + + +def get_track_by_item_id(item_id, provider_id=None): + """ + Look up a track by item_id with provider fallback logic. + + If provider_id is specified: + - Look up in provider_track for that provider first + - Fall back to score table if not in provider_track + + If provider_id is NOT specified (backward compatible mode): + 1. Try the primary provider first + 2. Try other enabled providers in priority order + 3. 
Fall back to direct score table lookup (legacy mode) + + Returns: + dict with track info or None if not found + """ + db = get_db() + + def lookup_in_score(item_id): + """Direct lookup in score table (legacy mode).""" + with db.cursor() as cur: + cur.execute(""" + SELECT item_id, title, author, album, tempo, key, scale, + mood_vector, energy, other_features, file_path, track_id + FROM score WHERE item_id = %s + """, (item_id,)) + row = cur.fetchone() + if row: + return { + 'item_id': row[0], + 'title': row[1], + 'author': row[2], + 'album': row[3], + 'tempo': row[4], + 'key': row[5], + 'scale': row[6], + 'mood_vector': row[7], + 'energy': row[8], + 'other_features': row[9], + 'file_path': row[10], + 'track_id': row[11], + 'provider_id': None # Unknown provider in legacy mode + } + return None + + def lookup_via_provider(item_id, prov_id): + """Look up via provider_track table.""" + with db.cursor() as cur: + # First check provider_track + cur.execute(""" + SELECT pt.item_id, pt.title, pt.artist, pt.album, pt.track_id, + s.tempo, s.key, s.scale, s.mood_vector, s.energy, + s.other_features, s.file_path + FROM provider_track pt + LEFT JOIN score s ON ( + pt.item_id = s.item_id OR + (pt.track_id IS NOT NULL AND pt.track_id = s.track_id) + ) + WHERE pt.provider_id = %s AND pt.item_id = %s + """, (prov_id, item_id)) + row = cur.fetchone() + if row: + return { + 'item_id': row[0], + 'title': row[1], + 'author': row[2], + 'album': row[3], + 'track_id': row[4], + 'tempo': row[5], + 'key': row[6], + 'scale': row[7], + 'mood_vector': row[8], + 'energy': row[9], + 'other_features': row[10], + 'file_path': row[11], + 'provider_id': prov_id + } + return None + + # If provider_id specified, try that provider first then fall back + if provider_id is not None: + result = lookup_via_provider(item_id, provider_id) + if result: + return result + # Fall back to direct score lookup + return lookup_in_score(item_id) + + # No provider specified - use fallback logic + # 1. 
Try primary provider first + primary_id = get_primary_provider_id() + if primary_id: + result = lookup_via_provider(item_id, primary_id) + if result: + return result + + # 2. Try other enabled providers in priority order + enabled_ids = get_enabled_provider_ids() + for prov_id in enabled_ids: + if prov_id == primary_id: + continue # Already tried + result = lookup_via_provider(item_id, prov_id) + if result: + return result + + # 3. Fall back to direct score table lookup (legacy/backward compatible) + return lookup_in_score(item_id) + + +def get_tracks_by_item_ids(item_ids, provider_id=None): + """ + Look up multiple tracks by item_ids with provider fallback logic. + + Args: + item_ids: List of item IDs to look up + provider_id: Optional provider ID to scope the lookup + + Returns: + dict mapping item_id to track info + """ + if not item_ids: + return {} + + results = {} + for item_id in item_ids: + track = get_track_by_item_id(item_id, provider_id) + if track: + results[item_id] = track + + return results + + +def resolve_item_id_to_provider(item_id): + """ + Resolve which provider(s) know about a given item_id. + + Returns: + List of provider_ids that have this item_id, + or empty list if only in score table (legacy) + """ + db = get_db() + with db.cursor() as cur: + cur.execute(""" + SELECT DISTINCT provider_id FROM provider_track + WHERE item_id = %s + """, (item_id,)) + return [row[0] for row in cur.fetchall()] + + +def get_item_id_for_provider(file_path_or_track_id, provider_id): + """ + Get the provider-specific item_id for a track. + + Useful when you have analysis data linked to one provider + and need to find the equivalent track in another provider. 
+ + Args: + file_path_or_track_id: Either file path (str) or track_id (int) + provider_id: The provider to look up in + + Returns: + The item_id for that provider, or None if not found + """ + db = get_db() + with db.cursor() as cur: + if isinstance(file_path_or_track_id, int): + # Lookup by track_id + cur.execute(""" + SELECT item_id FROM provider_track + WHERE provider_id = %s AND track_id = %s + """, (provider_id, file_path_or_track_id)) + else: + # Lookup by file path - need to join through track table + cur.execute(""" + SELECT pt.item_id FROM provider_track pt + JOIN track t ON pt.track_id = t.id + WHERE pt.provider_id = %s AND t.file_path = %s + """, (provider_id, file_path_or_track_id)) + + row = cur.fetchone() + return row[0] if row else None + + +def is_multi_provider_mode(): + """Check if multi-provider mode is enabled.""" + db = get_db() + with db.cursor() as cur: + cur.execute("SELECT value FROM app_settings WHERE key = 'multi_provider_enabled'") + row = cur.fetchone() + if row: + val = row[0] + return val is True or val == True or val == 'true' + return False + + +def set_primary_provider(provider_id): + """Set the primary provider ID.""" + db = get_db() + with db.cursor() as cur: + cur.execute(""" + INSERT INTO app_settings (key, value, category, description, updated_at) + VALUES ('primary_provider_id', %s, 'providers', 'ID of the primary provider', NOW()) + ON CONFLICT (key) DO UPDATE SET + value = EXCLUDED.value, + updated_at = NOW() + """, (str(provider_id) if provider_id is not None else 'null',)) + db.commit() \ No newline at end of file From 6d575fc35058f640df1d84f77f1fd9398c1f574c Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 4 Feb 2026 04:42:17 +0000 Subject: [PATCH 03/33] Add split deployment support for remote workers - Add docker-compose-worker-cpu.yaml for CPU-only remote workers - Update docker-compose-worker-nvidia.yaml with all provider configs - Update docker-compose-server.yaml for server-only deployment - Add worker connection 
settings to .env.example (WORKER_REDIS_URL, WORKER_POSTGRES_HOST) - Add deployment mode selection to setup wizard (unified/split) - Add worker configuration section in setup wizard with connection info - Add server-info API endpoint for automatic IP detection This allows running ML analysis workers on separate machines from the main Flask server, useful for utilizing dedicated GPU servers or distributing workload across multiple workers. https://claude.ai/code/session_011AebTWAucDafK4m6uoSSNg --- app_setup.py | 43 ++++ deployment/.env.example | 30 +++ deployment/docker-compose-server.yaml | 116 ++++++++--- deployment/docker-compose-worker-cpu.yaml | 113 +++++++++++ deployment/docker-compose-worker-nvidia.yaml | 118 ++++++++--- templates/setup.html | 200 ++++++++++++++++++- 6 files changed, 569 insertions(+), 51 deletions(-) create mode 100644 deployment/docker-compose-worker-cpu.yaml diff --git a/app_setup.py b/app_setup.py index db895861..507441be 100644 --- a/app_setup.py +++ b/app_setup.py @@ -680,3 +680,46 @@ def set_primary_provider(): 'message': 'Primary provider set', 'primary_provider_id': provider_id }) + + +@setup_bp.route('/api/setup/server-info', methods=['GET']) +def get_server_info(): + """ + Get server connection information for configuring remote workers. 
+ --- + tags: + - Setup + responses: + 200: + description: Server connection information + """ + import socket + import os + + # Try to get the server's IP address + try: + # Get the hostname and try to resolve it + hostname = socket.gethostname() + host_ip = socket.gethostbyname(hostname) + # If we get a loopback address, try to get a better one + if host_ip.startswith('127.'): + # Try to connect to a public DNS to get our real IP + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(('8.8.8.8', 80)) + host_ip = s.getsockname()[0] + except Exception: + host_ip = hostname # Fall back to hostname + finally: + s.close() + except Exception: + host_ip = 'localhost' + + return jsonify({ + 'host': host_ip, + 'hostname': socket.gethostname() if hasattr(socket, 'gethostname') else 'unknown', + 'redis_port': os.environ.get('REDIS_PORT', '6379'), + 'postgres_port': os.environ.get('POSTGRES_PORT', '5432'), + 'postgres_host': os.environ.get('POSTGRES_HOST', 'postgres'), + 'redis_url': os.environ.get('REDIS_URL', 'redis://redis:6379/0'), + }) diff --git a/deployment/.env.example b/deployment/.env.example index d7e7db62..fc8867a4 100644 --- a/deployment/.env.example +++ b/deployment/.env.example @@ -96,6 +96,36 @@ WORKER_PORT=8029 # Timezone (examples: UTC, Europe/Berlin, America/Los_Angeles) TZ=UTC +# ============================================================================= +# SPLIT DEPLOYMENT (Remote Worker Configuration) +# ============================================================================= +# Use these settings when running workers on a separate machine from the server. +# +# DEPLOYMENT OPTIONS: +# 1. UNIFIED (default) - Server and worker on same machine +# Use: docker-compose-unified.yaml or docker-compose-unified-nvidia.yaml +# Leave these settings at defaults +# +# 2. 
SPLIT - Server and worker on different machines +# Server machine: Run docker-compose-server.yaml +# Worker machine: Run docker-compose-worker-cpu.yaml or docker-compose-worker-nvidia.yaml +# Configure WORKER_* settings below on worker machine +# +# On the WORKER machine, set these to point to the SERVER machine: +# +# Redis URL on server (replace SERVER_IP with your server's IP/hostname) +# Format: redis://[password@]host:port/db +# Examples: +# WORKER_REDIS_URL=redis://192.168.1.100:6379/0 +# WORKER_REDIS_URL=redis://:mypassword@192.168.1.100:6379/0 +WORKER_REDIS_URL=redis://redis:6379/0 +# +# PostgreSQL host on server (replace with your server's IP/hostname) +# Examples: +# WORKER_POSTGRES_HOST=192.168.1.100 +# WORKER_POSTGRES_HOST=my-server.local +WORKER_POSTGRES_HOST=postgres + # ============================================================================= # NVIDIA GPU (for docker-compose-unified-nvidia.yaml) # ============================================================================= diff --git a/deployment/docker-compose-server.yaml b/deployment/docker-compose-server.yaml index 757c3e9a..8c968d5b 100644 --- a/deployment/docker-compose-server.yaml +++ b/deployment/docker-compose-server.yaml @@ -1,13 +1,38 @@ -# AudioMuse-AI Deployment Configuration -# -# SERVER TEMPLATE - Run database, Redis and Flask API -# This can be run on a lightweight server, tested on an N100 mini PC (without heavy CPU or GPU requirements). -# For remote worker setup with CPU/GPU on a separate machine, use docker-compose-worker.yaml on the remote machine -# and configure WORKER_POSTGRES_HOST and WORKER_REDIS_URL in that worker's .env file to point to this server. +# AudioMuse-AI Server-Only Docker Compose +# ============================================================================= +# SERVER TEMPLATE - Runs Flask API + Redis + PostgreSQL (NO workers) +# +# Use this for split deployments where: +# - Server runs on lightweight hardware (N100, Raspberry Pi, NAS, etc.) 
+# - Workers run on separate machines with GPU/better CPU +# +# Quick Start: +# 1. Copy .env.example to .env and configure settings +# 2. Run: docker-compose -f docker-compose-server.yaml up -d +# 3. Note this server's IP address for worker configuration +# 4. On worker machine(s), use docker-compose-worker-cpu.yaml or +# docker-compose-worker-nvidia.yaml with these .env settings: +# WORKER_REDIS_URL=redis://SERVER_IP:6379/0 +# WORKER_POSTGRES_HOST=SERVER_IP +# +# Network Requirements (ports to open for workers): +# - Port 6379: Redis (workers connect here for task queue) +# - Port 5432: PostgreSQL (workers connect here for data) +# - Port 8000: Web UI (optional, only if exposing to users) +# +# Security Note: +# For production, consider: +# - VPN or private network between server and workers +# - Redis password (uncomment command in redis service) +# - PostgreSQL with SSL and strong password +# ============================================================================= version: '3.8' + services: - # Redis service for RQ (task queue) + # --------------------------------------------------------------------------- + # Redis - Task Queue (exposed for remote workers) + # --------------------------------------------------------------------------- redis: image: redis:7-alpine container_name: audiomuse-redis @@ -16,8 +41,17 @@ services: volumes: - redis-data:/data restart: unless-stopped + # Uncomment for password protection: + # command: redis-server --requirepass ${REDIS_PASSWORD:-changeme} + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 30s + timeout: 10s + retries: 3 - # PostgreSQL database service + # --------------------------------------------------------------------------- + # PostgreSQL - Database (exposed for remote workers) + # --------------------------------------------------------------------------- postgres: image: postgres:15-alpine container_name: audiomuse-postgres @@ -30,42 +64,76 @@ services: volumes: - 
postgres-data:/var/lib/postgresql/data restart: unless-stopped + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-audiomuse}"] + interval: 30s + timeout: 10s + retries: 3 - # AudioMuse-AI Flask application service + # --------------------------------------------------------------------------- + # AudioMuse-AI Flask Application (Web UI & API only) + # NOTE: No worker - analysis tasks are sent to remote workers via Redis + # --------------------------------------------------------------------------- audiomuse-ai-flask: - image: ghcr.io/neptunehub/audiomuse-ai:latest-nvidia + image: ghcr.io/neptunehub/audiomuse-ai:latest container_name: audiomuse-ai-flask-app ports: - "${FRONTEND_PORT:-8000}:8000" environment: SERVICE_TYPE: "flask" TZ: "${TZ:-UTC}" - MEDIASERVER_TYPE: "jellyfin" - JELLYFIN_USER_ID: "${JELLYFIN_USER_ID}" - JELLYFIN_TOKEN: "${JELLYFIN_TOKEN}" - JELLYFIN_URL: "${JELLYFIN_URL}" + # Media Server Configuration (configure via GUI or .env) + MEDIASERVER_TYPE: "${MEDIASERVER_TYPE:-localfiles}" + JELLYFIN_URL: "${JELLYFIN_URL:-}" + JELLYFIN_USER_ID: "${JELLYFIN_USER_ID:-}" + JELLYFIN_TOKEN: "${JELLYFIN_TOKEN:-}" + NAVIDROME_URL: "${NAVIDROME_URL:-}" + NAVIDROME_USER: "${NAVIDROME_USER:-}" + NAVIDROME_PASSWORD: "${NAVIDROME_PASSWORD:-}" + LYRION_URL: "${LYRION_URL:-}" + MPD_HOST: "${MPD_HOST:-}" + MPD_PORT: "${MPD_PORT:-6600}" + MPD_PASSWORD: "${MPD_PASSWORD:-}" + MPD_MUSIC_DIRECTORY: "${MPD_MUSIC_DIRECTORY:-/music}" + EMBY_URL: "${EMBY_URL:-}" + EMBY_USER_ID: "${EMBY_USER_ID:-}" + EMBY_TOKEN: "${EMBY_TOKEN:-}" + LOCALFILES_MUSIC_DIRECTORY: "${LOCALFILES_MUSIC_DIRECTORY:-/music}" + LOCALFILES_PLAYLIST_DIR: "${LOCALFILES_PLAYLIST_DIR:-/music/playlists}" + # Database Configuration POSTGRES_USER: ${POSTGRES_USER:-audiomuse} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-audiomusepassword} POSTGRES_DB: ${POSTGRES_DB:-audiomusedb} POSTGRES_HOST: "postgres" POSTGRES_PORT: "${POSTGRES_PORT:-5432}" REDIS_URL: "${REDIS_URL:-redis://redis:6379/0}" - 
AI_MODEL_PROVIDER: "${AI_MODEL_PROVIDER}" - OPENAI_API_KEY: "${OPENAI_API_KEY}" - OPENAI_SERVER_URL: "${OPENAI_SERVER_URL}" - OPENAI_MODEL_NAME: "${OPENAI_MODEL_NAME}" - GEMINI_API_KEY: "${GEMINI_API_KEY}" - MISTRAL_API_KEY: "${MISTRAL_API_KEY}" - CLAP_ENABLED: "${CLAP_ENABLED:-true}" # Enable CLAP text search (set to false for slower systems) + # AI Configuration + AI_MODEL_PROVIDER: "${AI_MODEL_PROVIDER:-NONE}" + OPENAI_API_KEY: "${OPENAI_API_KEY:-}" + OPENAI_SERVER_URL: "${OPENAI_SERVER_URL:-}" + OPENAI_MODEL_NAME: "${OPENAI_MODEL_NAME:-}" + GEMINI_API_KEY: "${GEMINI_API_KEY:-}" + MISTRAL_API_KEY: "${MISTRAL_API_KEY:-}" + # Flask doesn't need CLAP models - workers handle that + CLAP_ENABLED: "false" TEMP_DIR: "/app/temp_audio" volumes: - temp-audio-flask:/app/temp_audio + - ${MUSIC_PATH:-./music}:/music:ro depends_on: - - redis - - postgres + redis: + condition: service_healthy + postgres: + condition: service_healthy restart: unless-stopped +# ============================================================================= +# Volumes +# ============================================================================= volumes: redis-data: + name: audiomuse-redis-data postgres-data: - temp-audio-flask: \ No newline at end of file + name: audiomuse-postgres-data + temp-audio-flask: + name: audiomuse-temp-flask \ No newline at end of file diff --git a/deployment/docker-compose-worker-cpu.yaml b/deployment/docker-compose-worker-cpu.yaml new file mode 100644 index 00000000..ae093b87 --- /dev/null +++ b/deployment/docker-compose-worker-cpu.yaml @@ -0,0 +1,113 @@ +# AudioMuse-AI Worker-Only Docker Compose (CPU) +# ============================================================================= +# WORKER TEMPLATE - Runs RQ workers only, connects to remote server +# +# Use this for split deployments where: +# - Server (Flask + Redis + PostgreSQL) runs on separate machine +# - This worker handles CPU-intensive ML analysis tasks +# +# Prerequisites: +# - Server running 
docker-compose-server.yaml on another machine +# - Network connectivity to server (ports 6379 Redis, 5432 PostgreSQL) +# - Access to same music files (via media server API or shared storage) +# +# Quick Start: +# 1. Copy .env.example to .env on this worker machine +# 2. Configure connection to remote server: +# WORKER_REDIS_URL=redis://SERVER_IP:6379/0 +# WORKER_POSTGRES_HOST=SERVER_IP +# 3. Copy media server credentials from server's .env +# 4. Run: docker-compose -f docker-compose-worker-cpu.yaml up -d +# +# Scaling: +# - Run multiple workers on different machines +# - All workers connect to same Redis queue +# - Tasks automatically distributed across workers +# ============================================================================= + +version: '3.8' + +services: + # --------------------------------------------------------------------------- + # AudioMuse-AI Worker (CPU-only, connects to remote server) + # --------------------------------------------------------------------------- + audiomuse-ai-worker: + image: ghcr.io/neptunehub/audiomuse-ai:latest + container_name: audiomuse-ai-worker-cpu + environment: + SERVICE_TYPE: "worker" + TZ: "${TZ:-UTC}" + # ======================================================================= + # REMOTE SERVER CONNECTION (REQUIRED) + # ======================================================================= + # Redis URL - point to server machine + # Format: redis://[password@]host:port/db + # Examples: + # redis://192.168.1.100:6379/0 + # redis://:mypassword@192.168.1.100:6379/0 + REDIS_URL: "${WORKER_REDIS_URL:-redis://redis:6379/0}" + # PostgreSQL host - IP or hostname of server machine + POSTGRES_HOST: "${WORKER_POSTGRES_HOST:-postgres}" + POSTGRES_PORT: "${POSTGRES_PORT:-5432}" + POSTGRES_USER: ${POSTGRES_USER:-audiomuse} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-audiomusepassword} + POSTGRES_DB: ${POSTGRES_DB:-audiomusedb} + # ======================================================================= + # MEDIA SERVER 
CONFIGURATION (must match server) + # ======================================================================= + MEDIASERVER_TYPE: "${MEDIASERVER_TYPE:-localfiles}" + JELLYFIN_URL: "${JELLYFIN_URL:-}" + JELLYFIN_USER_ID: "${JELLYFIN_USER_ID:-}" + JELLYFIN_TOKEN: "${JELLYFIN_TOKEN:-}" + NAVIDROME_URL: "${NAVIDROME_URL:-}" + NAVIDROME_USER: "${NAVIDROME_USER:-}" + NAVIDROME_PASSWORD: "${NAVIDROME_PASSWORD:-}" + LYRION_URL: "${LYRION_URL:-}" + MPD_HOST: "${MPD_HOST:-}" + MPD_PORT: "${MPD_PORT:-6600}" + MPD_PASSWORD: "${MPD_PASSWORD:-}" + MPD_MUSIC_DIRECTORY: "${MPD_MUSIC_DIRECTORY:-/music}" + EMBY_URL: "${EMBY_URL:-}" + EMBY_USER_ID: "${EMBY_USER_ID:-}" + EMBY_TOKEN: "${EMBY_TOKEN:-}" + LOCALFILES_MUSIC_DIRECTORY: "${LOCALFILES_MUSIC_DIRECTORY:-/music}" + LOCALFILES_PLAYLIST_DIR: "${LOCALFILES_PLAYLIST_DIR:-/music/playlists}" + # ======================================================================= + # AI CONFIGURATION (optional) + # ======================================================================= + AI_MODEL_PROVIDER: "${AI_MODEL_PROVIDER:-NONE}" + OPENAI_API_KEY: "${OPENAI_API_KEY:-}" + OPENAI_SERVER_URL: "${OPENAI_SERVER_URL:-}" + OPENAI_MODEL_NAME: "${OPENAI_MODEL_NAME:-}" + GEMINI_API_KEY: "${GEMINI_API_KEY:-}" + MISTRAL_API_KEY: "${MISTRAL_API_KEY:-}" + # ======================================================================= + # WORKER FEATURES + # ======================================================================= + # CLAP text search - uses ~750MB RAM + CLAP_ENABLED: "${CLAP_ENABLED:-true}" + # GPU clustering - disabled for CPU worker + USE_GPU_CLUSTERING: "false" + # Worker tuning + RQ_MAX_JOBS: "${RQ_MAX_JOBS:-50}" + RQ_MAX_JOBS_HIGH: "${RQ_MAX_JOBS_HIGH:-100}" + RQ_LOGGING_LEVEL: "${RQ_LOGGING_LEVEL:-INFO}" + TEMP_DIR: "/app/temp_audio" + volumes: + - temp-audio-worker:/app/temp_audio + # Mount music directory if using local files provider + - ${MUSIC_PATH:-./music}:/music:ro + restart: unless-stopped + # Optional: Limit CPU usage + # deploy: + 
# resources: + # limits: + # cpus: '4' + # memory: 8G + +# ============================================================================= +# Volumes +# ============================================================================= +volumes: + temp-audio-worker: + name: audiomuse-temp-worker-cpu diff --git a/deployment/docker-compose-worker-nvidia.yaml b/deployment/docker-compose-worker-nvidia.yaml index 4b6021e0..18d7cf90 100644 --- a/deployment/docker-compose-worker-nvidia.yaml +++ b/deployment/docker-compose-worker-nvidia.yaml @@ -1,50 +1,118 @@ -# AudioMuse-AI Deployment Configuration +# AudioMuse-AI Worker-Only Docker Compose (NVIDIA GPU) +# ============================================================================= +# WORKER TEMPLATE - Runs RQ workers only with GPU acceleration # -# WORKER TEMPLATE - Run this for heavy CPU/GPU tasks like analysis and clustering, connected to a lightweight server with Jellyfin and AudioMuse-AI Flask application and databases. -# This configuration is intended for deployment on a server with NVIDIA GPU support. +# Use this for split deployments where: +# - Server (Flask + Redis + PostgreSQL) runs on separate machine +# - This worker handles GPU-accelerated ML analysis tasks # +# Prerequisites: +# - Server running docker-compose-server.yaml on another machine +# - Network connectivity to server (ports 6379 Redis, 5432 PostgreSQL) +# - Access to same music files (via media server API or shared storage) +# - NVIDIA GPU with docker nvidia-runtime installed +# +# Quick Start: +# 1. Copy .env.example to .env on this worker machine +# 2. Configure connection to remote server: +# WORKER_REDIS_URL=redis://SERVER_IP:6379/0 +# WORKER_POSTGRES_HOST=SERVER_IP +# 3. Copy media server credentials from server's .env +# 4. 
Run: docker-compose -f docker-compose-worker-nvidia.yaml up -d +# +# Scaling: +# - Run multiple workers on different machines +# - All workers connect to same Redis queue +# - Tasks automatically distributed across workers +# ============================================================================= version: '3.8' + services: - # AudioMuse-AI Worker service (GPU-dependent) + # --------------------------------------------------------------------------- + # AudioMuse-AI Worker (NVIDIA GPU, connects to remote server) + # --------------------------------------------------------------------------- audiomuse-ai-worker: image: ghcr.io/neptunehub/audiomuse-ai:latest-nvidia - container_name: audiomuse-ai-worker-instance - ports: - - "${WORKER_PORT:-8029}:8000" # Expose worker API + container_name: audiomuse-ai-worker-nvidia environment: SERVICE_TYPE: "worker" TZ: "${TZ:-UTC}" - MEDIASERVER_TYPE: "jellyfin" - JELLYFIN_USER_ID: "${JELLYFIN_USER_ID}" - JELLYFIN_TOKEN: "${JELLYFIN_TOKEN}" - JELLYFIN_URL: "${JELLYFIN_URL}" + # ======================================================================= + # REMOTE SERVER CONNECTION (REQUIRED) + # ======================================================================= + # Redis URL - point to server machine + # Format: redis://[password@]host:port/db + # Examples: + # redis://192.168.1.100:6379/0 + # redis://:mypassword@192.168.1.100:6379/0 + REDIS_URL: "${WORKER_REDIS_URL:-redis://redis:6379/0}" + # PostgreSQL host - IP or hostname of server machine + POSTGRES_HOST: "${WORKER_POSTGRES_HOST:-postgres}" + POSTGRES_PORT: "${POSTGRES_PORT:-5432}" POSTGRES_USER: ${POSTGRES_USER:-audiomuse} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-audiomusepassword} POSTGRES_DB: ${POSTGRES_DB:-audiomusedb} - POSTGRES_HOST: "${WORKER_POSTGRES_HOST:-postgres}" # Replace via WORKER_POSTGRES_HOST in .env when running remotely - POSTGRES_PORT: "${POSTGRES_PORT:-5432}" - REDIS_URL: "${WORKER_REDIS_URL:-redis://redis:6379/0}" # Set WORKER_REDIS_URL in .env for 
remote connections - AI_MODEL_PROVIDER: "${AI_MODEL_PROVIDER}" - OPENAI_API_KEY: "${OPENAI_API_KEY}" - OPENAI_SERVER_URL: "${OPENAI_SERVER_URL}" - OPENAI_MODEL_NAME: "${OPENAI_MODEL_NAME}" - GEMINI_API_KEY: "${GEMINI_API_KEY}" - MISTRAL_API_KEY: "${MISTRAL_API_KEY}" - CLAP_ENABLED: "${CLAP_ENABLED:-true}" # Enable CLAP text search (set to false for slower systems) + # ======================================================================= + # MEDIA SERVER CONFIGURATION (must match server) + # ======================================================================= + MEDIASERVER_TYPE: "${MEDIASERVER_TYPE:-localfiles}" + JELLYFIN_URL: "${JELLYFIN_URL:-}" + JELLYFIN_USER_ID: "${JELLYFIN_USER_ID:-}" + JELLYFIN_TOKEN: "${JELLYFIN_TOKEN:-}" + NAVIDROME_URL: "${NAVIDROME_URL:-}" + NAVIDROME_USER: "${NAVIDROME_USER:-}" + NAVIDROME_PASSWORD: "${NAVIDROME_PASSWORD:-}" + LYRION_URL: "${LYRION_URL:-}" + MPD_HOST: "${MPD_HOST:-}" + MPD_PORT: "${MPD_PORT:-6600}" + MPD_PASSWORD: "${MPD_PASSWORD:-}" + MPD_MUSIC_DIRECTORY: "${MPD_MUSIC_DIRECTORY:-/music}" + EMBY_URL: "${EMBY_URL:-}" + EMBY_USER_ID: "${EMBY_USER_ID:-}" + EMBY_TOKEN: "${EMBY_TOKEN:-}" + LOCALFILES_MUSIC_DIRECTORY: "${LOCALFILES_MUSIC_DIRECTORY:-/music}" + LOCALFILES_PLAYLIST_DIR: "${LOCALFILES_PLAYLIST_DIR:-/music/playlists}" + # ======================================================================= + # AI CONFIGURATION (optional) + # ======================================================================= + AI_MODEL_PROVIDER: "${AI_MODEL_PROVIDER:-NONE}" + OPENAI_API_KEY: "${OPENAI_API_KEY:-}" + OPENAI_SERVER_URL: "${OPENAI_SERVER_URL:-}" + OPENAI_MODEL_NAME: "${OPENAI_MODEL_NAME:-}" + GEMINI_API_KEY: "${GEMINI_API_KEY:-}" + MISTRAL_API_KEY: "${MISTRAL_API_KEY:-}" + # ======================================================================= + # WORKER FEATURES (GPU-enabled) + # ======================================================================= + # CLAP text search - uses ~750MB RAM + CLAP_ENABLED: 
"${CLAP_ENABLED:-true}" + # GPU clustering - enabled for NVIDIA worker + USE_GPU_CLUSTERING: "${USE_GPU_CLUSTERING:-true}" + # Worker tuning + RQ_MAX_JOBS: "${RQ_MAX_JOBS:-50}" + RQ_MAX_JOBS_HIGH: "${RQ_MAX_JOBS_HIGH:-100}" + RQ_LOGGING_LEVEL: "${RQ_LOGGING_LEVEL:-INFO}" TEMP_DIR: "/app/temp_audio" - NVIDIA_VISIBLE_DEVICES: "0" + # NVIDIA GPU settings + NVIDIA_VISIBLE_DEVICES: "${NVIDIA_GPU_ID:-0}" NVIDIA_DRIVER_CAPABILITIES: "compute,utility" - USE_GPU_CLUSTERING: "${USE_GPU_CLUSTERING:-true}" volumes: - temp-audio-worker:/app/temp_audio + # Mount music directory if using local files provider + - ${MUSIC_PATH:-./music}:/music:ro restart: unless-stopped deploy: resources: reservations: devices: - driver: nvidia - device_ids: ["0"] + device_ids: ["${NVIDIA_GPU_ID:-0}"] capabilities: [gpu] + +# ============================================================================= +# Volumes +# ============================================================================= volumes: - temp-audio-worker: \ No newline at end of file + temp-audio-worker: + name: audiomuse-temp-worker-nvidia diff --git a/templates/setup.html b/templates/setup.html index 08f65f23..81f4189b 100644 --- a/templates/setup.html +++ b/templates/setup.html @@ -452,6 +452,47 @@ font-size: 0.85rem; cursor: pointer; } + + /* Worker Connection Info */ + .connection-info-box { + display: flex; + align-items: center; + gap: 0.5rem; + background: var(--bg-primary); + padding: 0.75rem; + border-radius: 4px; + border: 1px solid var(--border-color); + } + + .connection-info-box code { + flex: 1; + font-family: monospace; + font-size: 0.95rem; + word-break: break-all; + } + + .copy-btn { + padding: 0.25rem 0.75rem; + border: 1px solid var(--border-color); + border-radius: 4px; + background: var(--bg-secondary); + cursor: pointer; + font-size: 0.85rem; + } + + .copy-btn:hover { + background: var(--bg-primary); + } + + .copy-btn.copied { + background: rgba(40, 167, 69, 0.2); + border-color: #28a745; + color: #28a745; + } 
+ + .worker-connection-info .config-field { + margin-bottom: 1.25rem; + } {% endblock %} @@ -496,6 +537,26 @@

    Welcome to AudioMuse-AI

    +
    +

    Select Deployment Mode

    +

    + Choose how you want to deploy AudioMuse-AI. +

    + +
    +
    +
    📦
    +
    Unified
    +
    Server and worker on the same machine. Best for most users.
    +
    +
    +
    🖧
    +
    Split
    +
    Run workers on separate machines. For distributed setups or dedicated GPU servers.
    +
    +
    +
    +

    Select Hardware Configuration

    @@ -588,6 +649,53 @@

    Analysis Settings

    + +
    Advanced Settings @@ -676,17 +784,21 @@

    What's Next?

    // State let currentStep = 1; let selectedHardware = 'cpu'; + let selectedDeployment = 'unified'; let selectedProviders = []; let providerConfigs = {}; let providerTypes = []; let existingProviders = []; + let serverInfo = { host: window.location.hostname, port: window.location.port || '8000' }; // Initialize document.addEventListener('DOMContentLoaded', async function() { await loadSetupStatus(); await loadProviderTypes(); + await loadServerInfo(); renderProviderGrid(); setupHardwareOptions(); + setupDeploymentOptions(); }); async function loadSetupStatus() { @@ -882,9 +994,9 @@

    Existing Installation Detected

    } function setupHardwareOptions() { - document.querySelectorAll('.hardware-option').forEach(option => { + document.querySelectorAll('.hardware-option[data-hardware]').forEach(option => { option.onclick = function() { - document.querySelectorAll('.hardware-option').forEach(o => o.classList.remove('selected')); + document.querySelectorAll('.hardware-option[data-hardware]').forEach(o => o.classList.remove('selected')); this.classList.add('selected'); selectedHardware = this.dataset.hardware; @@ -896,10 +1008,85 @@

    Existing Installation Detected

    gpuClustering.disabled = true; gpuClustering.checked = false; } + + // Update worker compose file recommendation + updateWorkerComposeRecommendation(); }; }); } + function setupDeploymentOptions() { + document.querySelectorAll('.hardware-option[data-deployment]').forEach(option => { + option.onclick = function() { + document.querySelectorAll('.hardware-option[data-deployment]').forEach(o => o.classList.remove('selected')); + this.classList.add('selected'); + selectedDeployment = this.dataset.deployment; + + // Show/hide worker config section + updateWorkerConfigVisibility(); + }; + }); + } + + async function loadServerInfo() { + try { + const response = await fetch('/api/setup/server-info'); + if (response.ok) { + const data = await response.json(); + serverInfo = data; + } + } catch (err) { + console.log('Could not load server info, using defaults'); + } + updateWorkerConnectionInfo(); + } + + function updateWorkerConfigVisibility() { + const workerSection = document.getElementById('worker-config-section'); + if (selectedDeployment === 'split') { + workerSection.style.display = 'block'; + updateWorkerConnectionInfo(); + } else { + workerSection.style.display = 'none'; + } + } + + function updateWorkerConnectionInfo() { + const serverIp = serverInfo.host || window.location.hostname; + const redisPort = serverInfo.redis_port || '6379'; + const postgresPort = serverInfo.postgres_port || '5432'; + + document.getElementById('worker-redis-url').textContent = `redis://${serverIp}:${redisPort}/0`; + document.getElementById('worker-postgres-host').textContent = serverIp; + updateWorkerComposeRecommendation(); + } + + function updateWorkerComposeRecommendation() { + const composeFile = selectedHardware === 'nvidia' + ? 
'docker-compose-worker-nvidia.yaml' + : 'docker-compose-worker-cpu.yaml'; + document.getElementById('worker-compose-file').textContent = composeFile; + document.getElementById('worker-compose-help').innerHTML = + `Run on worker machine: docker-compose -f ${composeFile} up -d`; + } + + function copyToClipboard(elementId) { + const element = document.getElementById(elementId); + const text = element.textContent; + + navigator.clipboard.writeText(text).then(() => { + const btn = element.parentElement.querySelector('.copy-btn'); + btn.textContent = 'Copied!'; + btn.classList.add('copied'); + setTimeout(() => { + btn.textContent = 'Copy'; + btn.classList.remove('copied'); + }, 2000); + }).catch(err => { + console.error('Failed to copy:', err); + }); + } + function toggleAdvanced() { const content = document.getElementById('advanced-settings'); const arrow = document.getElementById('advanced-arrow'); @@ -958,7 +1145,15 @@

    Existing Installation Detected

    return p ? p.name : type; }).join(', '); + let deploymentInfo = selectedDeployment === 'unified' + ? 'Unified (server + worker on same machine)' + : 'Split (workers on separate machines)'; + summary.innerHTML = ` +
    + Deployment Mode + ${deploymentInfo} +
    Hardware ${selectedHardware === 'nvidia' ? 'NVIDIA GPU' : 'CPU Only'} @@ -1005,6 +1200,7 @@

    Existing Installation Detected

    method: 'PUT', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ + deployment_type: selectedDeployment, hardware_type: selectedHardware, clap_enabled: document.getElementById('setting-clap-enabled').checked, gpu_clustering: document.getElementById('setting-gpu-clustering').checked, From 9b2d8215bf60763f262be22360c0eb22d7968084 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 4 Feb 2026 04:46:47 +0000 Subject: [PATCH 04/33] Clarify unified compose files for CPU and NVIDIA in setup wizard - Add Docker Compose file recommendation to setup summary - Show correct unified file based on hardware selection - Clarify CPU vs NVIDIA image usage in .env.example comments https://claude.ai/code/session_011AebTWAucDafK4m6uoSSNg --- deployment/.env.example | 10 ++++++---- templates/setup.html | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/deployment/.env.example b/deployment/.env.example index fc8867a4..bc5b2bc1 100644 --- a/deployment/.env.example +++ b/deployment/.env.example @@ -103,12 +103,14 @@ TZ=UTC # # DEPLOYMENT OPTIONS: # 1. UNIFIED (default) - Server and worker on same machine -# Use: docker-compose-unified.yaml or docker-compose-unified-nvidia.yaml -# Leave these settings at defaults +# CPU: docker-compose-unified.yaml (uses :latest image) +# NVIDIA: docker-compose-unified-nvidia.yaml (uses :latest-nvidia image) +# Leave WORKER_* settings at defaults # # 2. 
SPLIT - Server and worker on different machines -# Server machine: Run docker-compose-server.yaml -# Worker machine: Run docker-compose-worker-cpu.yaml or docker-compose-worker-nvidia.yaml +# Server machine: docker-compose-server.yaml +# Worker machine: docker-compose-worker-cpu.yaml (CPU) or +# docker-compose-worker-nvidia.yaml (NVIDIA GPU) # Configure WORKER_* settings below on worker machine # # On the WORKER machine, set these to point to the SERVER machine: diff --git a/templates/setup.html b/templates/setup.html index 81f4189b..451b21f4 100644 --- a/templates/setup.html +++ b/templates/setup.html @@ -1149,6 +1149,19 @@

    Existing Installation Detected

    ? 'Unified (server + worker on same machine)' : 'Split (workers on separate machines)'; + // Determine recommended compose files + let composeFiles = []; + if (selectedDeployment === 'unified') { + composeFiles.push(selectedHardware === 'nvidia' + ? 'docker-compose-unified-nvidia.yaml' + : 'docker-compose-unified.yaml'); + } else { + composeFiles.push('docker-compose-server.yaml (this machine)'); + composeFiles.push(selectedHardware === 'nvidia' + ? 'docker-compose-worker-nvidia.yaml (workers)' + : 'docker-compose-worker-cpu.yaml (workers)'); + } + summary.innerHTML = `
    Deployment Mode @@ -1158,6 +1171,10 @@

    Existing Installation Detected

    Hardware ${selectedHardware === 'nvidia' ? 'NVIDIA GPU' : 'CPU Only'}
    +
    + Docker Compose Files + ${composeFiles.join('
    ')}
    +
    Providers ${providerNames || 'None selected'} From 9287f36ee975850ed89b6fecee0d99dcfd895227 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 4 Feb 2026 07:27:31 +0000 Subject: [PATCH 05/33] Add dedicated Settings page for configuration management - Create settings.html with collapsible sections for: - Media Providers (list, add, edit, delete, set primary) - Deployment (unified/split mode, hardware type, worker info) - Analysis (CLAP, GPU clustering) - AI Integration (playlist naming provider) - Add /settings route in app_setup.py - Update sidebar navigation with Settings link - Keep Setup Wizard as separate option for initial configuration The Settings page provides quick access to modify individual settings without going through the full setup wizard flow. https://claude.ai/code/session_011AebTWAucDafK4m6uoSSNg --- app_setup.py | 6 + templates/settings.html | 1184 +++++++++++++++++++++++++++++++++++ templates/sidebar_navi.html | 3 +- 3 files changed, 1192 insertions(+), 1 deletion(-) create mode 100644 templates/settings.html diff --git a/app_setup.py b/app_setup.py index 507441be..f857f5e8 100644 --- a/app_setup.py +++ b/app_setup.py @@ -278,6 +278,12 @@ def setup_page(): return render_template('setup.html', title='AudioMuse-AI - Setup', active='setup') +@setup_bp.route('/settings') +def settings_page(): + """Render the settings page.""" + return render_template('settings.html', title='AudioMuse-AI - Settings', active='settings') + + @setup_bp.route('/api/setup/status', methods=['GET']) def get_setup_status(): """ diff --git a/templates/settings.html b/templates/settings.html new file mode 100644 index 00000000..1116c7c1 --- /dev/null +++ b/templates/settings.html @@ -0,0 +1,1184 @@ +{% extends "includes/layout.html" %} + +{% block headAdditions %} + +{% endblock %} + +{% block content %} +
    +
    +

    Settings

    +

    Configure AudioMuse-AI settings and manage providers

    +
    + + +
    +
    +

    📦 Media Providers

    + +
    +
    +

    + Configure your music library sources. Multiple providers can share analysis data through file path linking. +

    + +
    + + +
    Used by default for API calls without provider specification
    +
    + +
    + +
    + + +
    +
    + + +
    +
    +

    Deployment

    + +
    +
    +

    + Configure how AudioMuse-AI is deployed across your infrastructure. +

    + +
    + +
    +
    +
    📦
    +
    Unified
    +
    Server and worker on same machine
    +
    +
    +
    🖧
    +
    Split
    +
    Workers on separate machines
    +
    +
    +
    + +
    + +
    +
    +
    💻
    +
    CPU Only
    +
    Standard processing
    +
    +
    +
    +
    NVIDIA GPU
    +
    CUDA acceleration
    +
    +
    +
    + + +
    +
    + + +
    +
    +

    📊 Analysis

    + +
    +
    +

    + Configure music analysis features and performance options. +

    + +
    + +
    Search music using natural language. Uses ~750MB additional memory.
    +
    + +
    + +
    Accelerate clustering with NVIDIA RAPIDS cuML. Requires NVIDIA GPU.
    +
    +
    +
    + + +
    +
    +

    🤖 AI Integration

    + +
    +
    +

    + Configure AI providers for creative playlist naming and other features. +

    + +
    + + +
    AI service for generating creative playlist names
    +
    +
    +
    + + +
    + +
    +
    + + + + + +
    +{% endblock %} + +{% block bodyAdditions %} + + +{% endblock %} diff --git a/templates/sidebar_navi.html b/templates/sidebar_navi.html index 1f8932b9..ef5ea47d 100644 --- a/templates/sidebar_navi.html +++ b/templates/sidebar_navi.html @@ -12,5 +12,6 @@
  • Cleaning
  • Scheduled Tasks
  • -
  • Setup & Providers
  • +
  • Settings
  • +
  • Setup Wizard
  • \ No newline at end of file From ab3b39ac6064c3088f2f913979d1948172967c01 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 4 Feb 2026 07:44:00 +0000 Subject: [PATCH 06/33] Add multi-provider playlist creation with provider selector dropdown - Add provider selector dropdown to all playlist forms - Create shared provider-selector.js component for consistent UI - Add /api/providers/enabled endpoint for fetching available providers - Update create_playlist_from_ids to support provider_ids parameter - Add create_playlist_multi_provider for creating on multiple providers - Add get_all_playlists_multi_provider with deduplication - Update all templates: similarity, path, clap_search, mulan_search, sonic_fingerprint, alchemy, artist_similarity, map, chat - Support 'all' option to create playlist on all enabled providers https://claude.ai/code/session_011AebTWAucDafK4m6uoSSNg --- app_chat.py | 26 +++-- app_voyager.py | 50 +++++++-- static/provider-selector.js | 179 ++++++++++++++++++++++++++++++ tasks/mediaserver.py | 185 +++++++++++++++++++++++++++++++ tasks/voyager_manager.py | 41 ++++++- templates/alchemy.html | 8 +- templates/artist_similarity.html | 23 ++-- templates/chat.html | 6 + templates/clap_search.html | 27 +++-- templates/map.html | 8 +- templates/mulan_search.html | 18 ++- templates/path.html | 19 +++- templates/similarity.html | 21 +++- templates/sonic_fingerprint.html | 6 + 14 files changed, 563 insertions(+), 54 deletions(-) create mode 100644 static/provider-selector.js diff --git a/app_chat.py b/app_chat.py index 4f8a5242..4ee17e84 100644 --- a/app_chat.py +++ b/app_chat.py @@ -750,7 +750,7 @@ def create_media_server_playlist_api(): API endpoint to create a playlist on the configured media server. 
""" # Local import to break circular dependency at startup - from tasks.mediaserver import create_instant_playlist + from tasks.voyager_manager import create_playlist_from_ids data = request.get_json() if not data or 'playlist_name' not in data or 'item_ids' not in data: @@ -758,6 +758,7 @@ def create_media_server_playlist_api(): user_playlist_name = data.get('playlist_name') item_ids = data.get('item_ids') # This will be a list of strings + provider_ids = data.get('provider_ids') # Can be 'all', int, or list of ints if not user_playlist_name.strip(): return jsonify({"message": "Error: Playlist name cannot be empty."}), 400 @@ -765,14 +766,21 @@ def create_media_server_playlist_api(): return jsonify({"message": "Error: No songs provided to create the playlist."}), 400 try: - # MODIFIED: Call the simplified create_instant_playlist function - created_playlist_info = create_instant_playlist(user_playlist_name, item_ids) - - if not created_playlist_info: - raise Exception("Media server did not return playlist information after creation.") - - # The created_playlist_info is the full JSON response from the media server - return jsonify({"message": f"Successfully created playlist '{user_playlist_name}' on the media server with ID: {created_playlist_info.get('Id')}"}), 200 + # Use the voyager_manager function that supports multi-provider + result = create_playlist_from_ids(user_playlist_name, item_ids, provider_ids=provider_ids) + + # Handle multi-provider result (dict) vs single provider result (string) + if isinstance(result, dict): + # Multi-provider response + success_count = sum(1 for r in result.values() if r.get('success')) + total_count = len(result) + return jsonify({ + "message": f"Playlist '{user_playlist_name}' created on {success_count}/{total_count} provider(s).", + "results": result + }), 200 + else: + # Single provider response + return jsonify({"message": f"Successfully created playlist '{user_playlist_name}' on the media server with ID: {result}"}), 200 
except Exception as e: # Log detailed error on the server diff --git a/app_voyager.py b/app_voyager.py index b1837a58..9896d0c0 100644 --- a/app_voyager.py +++ b/app_voyager.py @@ -321,6 +321,8 @@ def create_media_server_playlist(): items: type: string description: A list of track Item IDs to add to the playlist. + provider_ids: + description: Provider(s) to create playlist on. Can be 'all', a single ID, or array of IDs. responses: 201: description: Playlist created successfully. @@ -341,6 +343,7 @@ def create_media_server_playlist(): playlist_name = data.get('playlist_name') track_ids_raw = data.get('track_ids', []) + provider_ids = data.get('provider_ids') # Can be 'all', int, or list of ints if not playlist_name: return jsonify({"error": "Missing 'playlist_name'"}), 400 @@ -364,15 +367,46 @@ def create_media_server_playlist(): user_creds = data.get('user_creds') if isinstance(data, dict) else None try: - new_playlist_id = create_playlist_from_ids(playlist_name, final_track_ids, user_creds=user_creds) - - logger.info(f"Successfully created playlist '{playlist_name}' with ID {new_playlist_id}.") - - return jsonify({ - "message": f"Playlist '{playlist_name}' created successfully!", - "playlist_id": new_playlist_id - }), 201 + result = create_playlist_from_ids(playlist_name, final_track_ids, user_creds=user_creds, provider_ids=provider_ids) + + # Handle multi-provider result (dict) vs single provider result (string) + if isinstance(result, dict): + # Multi-provider response + success_count = sum(1 for r in result.values() if r.get('success')) + total_count = len(result) + logger.info(f"Created playlist '{playlist_name}' on {success_count}/{total_count} providers.") + return jsonify({ + "message": f"Playlist '{playlist_name}' created on {success_count}/{total_count} provider(s).", + "results": result + }), 201 + else: + # Single provider response (backward compatible) + logger.info(f"Successfully created playlist '{playlist_name}' with ID {result}.") + return 
jsonify({ + "message": f"Playlist '{playlist_name}' created successfully!", + "playlist_id": result + }), 201 except Exception as e: logger.error(f"Failed to create media server playlist '{playlist_name}': {e}", exc_info=True) return jsonify({"error": "An error occurred while creating the playlist on the media server."}), 500 + + +@voyager_bp.route('/api/providers/enabled', methods=['GET']) +def get_enabled_providers(): + """ + Get list of enabled providers for playlist creation dropdown. + --- + tags: + - Providers + responses: + 200: + description: List of enabled providers + """ + try: + from tasks.mediaserver import get_enabled_providers_for_playlists + providers = get_enabled_providers_for_playlists() + return jsonify(providers), 200 + except Exception as e: + logger.error(f"Failed to get enabled providers: {e}", exc_info=True) + return jsonify([]), 200 diff --git a/static/provider-selector.js b/static/provider-selector.js new file mode 100644 index 00000000..754cfdb4 --- /dev/null +++ b/static/provider-selector.js @@ -0,0 +1,179 @@ +/** + * Provider Selector Component for Multi-Provider Playlist Support + * + * Usage: + * 1. Include this script in your template + * 2. Add a container div:
    + * 3. Call initProviderSelector() after DOM is loaded + * 4. Get selected value with getSelectedProviders() when creating playlist + */ + +let _providers = []; +let _selectedProviderValue = null; // null = primary/default, 'all' = all providers, number = specific provider + +/** + * Initialize the provider selector component. + * Fetches enabled providers and renders the dropdown. + * + * @param {string} containerId - ID of the container element + * @param {object} options - Configuration options + * @param {boolean} options.showAllOption - Whether to show "All Providers" option (default: true) + * @param {boolean} options.showLabel - Whether to show label (default: true) + * @param {string} options.labelText - Label text (default: "Save to:") + */ +async function initProviderSelector(containerId = 'provider-selector-container', options = {}) { + const container = document.getElementById(containerId); + if (!container) { + console.warn(`Provider selector container '${containerId}' not found`); + return; + } + + const showAllOption = options.showAllOption !== false; + const showLabel = options.showLabel !== false; + const labelText = options.labelText || 'Save to:'; + + try { + const response = await fetch('/api/providers/enabled'); + _providers = await response.json(); + } catch (err) { + console.error('Failed to load providers:', err); + _providers = []; + } + + // Only show selector if there are multiple providers or showAllOption is true + if (_providers.length <= 1 && !showAllOption) { + container.style.display = 'none'; + return; + } + + // Build the selector HTML + let html = '
    '; + + if (showLabel) { + html += ``; + } + + html += '
    '; + + container.innerHTML = html; + + // Add event listener + const select = document.getElementById('provider-select'); + if (select) { + select.addEventListener('change', function() { + const value = this.value; + if (value === '') { + _selectedProviderValue = null; + } else if (value === 'all') { + _selectedProviderValue = 'all'; + } else { + _selectedProviderValue = parseInt(value, 10); + } + }); + } +} + +/** + * Get the currently selected provider value. + * + * @returns {null|string|number} null for primary, 'all' for all, or provider ID + */ +function getSelectedProviders() { + return _selectedProviderValue; +} + +/** + * Get the list of loaded providers. + * + * @returns {Array} List of provider objects + */ +function getProviderList() { + return _providers; +} + +/** + * Check if multiple providers are available. + * + * @returns {boolean} + */ +function hasMultipleProviders() { + return _providers.length > 1; +} + +/** + * Add provider_ids to a playlist creation payload. 
+ * + * @param {object} payload - The existing payload object + * @returns {object} Payload with provider_ids added if applicable + */ +function addProviderToPayload(payload) { + const selected = getSelectedProviders(); + if (selected !== null) { + payload.provider_ids = selected; + } + return payload; +} + +// CSS styles for the provider selector +const providerSelectorStyles = ` +.provider-selector { + display: flex; + align-items: center; + gap: 0.5rem; + margin-bottom: 0.75rem; +} + +.provider-selector label { + font-weight: 500; + font-size: 0.9rem; + white-space: nowrap; +} + +.provider-select { + padding: 0.4rem 0.75rem; + border: 1px solid var(--border-color, #ccc); + border-radius: 4px; + background: var(--bg-primary, #fff); + color: var(--text-color, #333); + font-size: 0.9rem; + min-width: 150px; +} + +.provider-select:focus { + outline: none; + border-color: var(--primary-color, #007bff); +} + +/* Compact variant for inline use */ +.provider-selector.compact { + margin-bottom: 0; +} + +.provider-selector.compact label { + font-size: 0.85rem; +} + +.provider-selector.compact .provider-select { + padding: 0.3rem 0.5rem; + font-size: 0.85rem; + min-width: 120px; +} +`; + +// Inject styles when script loads +(function() { + const styleEl = document.createElement('style'); + styleEl.textContent = providerSelectorStyles; + document.head.appendChild(styleEl); +})(); diff --git a/tasks/mediaserver.py b/tasks/mediaserver.py index ce6c3752..6fd99418 100644 --- a/tasks/mediaserver.py +++ b/tasks/mediaserver.py @@ -566,3 +566,188 @@ def _get_provider_config_fields(provider_type: str): } return fields.get(provider_type, []) + +# ############################################################################## +# MULTI-PROVIDER PLAYLIST FUNCTIONS +# ############################################################################## + +def get_all_playlists_multi_provider(provider_ids=None): + """ + Get playlists from multiple providers with deduplication. 
+ + Args: + provider_ids: List of provider IDs to query, or None for all enabled providers + + Returns: + List of playlists with provider info, deduplicated by name + """ + from app_helper import get_providers, get_provider_by_id + + all_playlists = [] + seen_names = {} # Track playlist names to detect duplicates + + # Get providers to query + if provider_ids is None: + providers = get_providers(enabled_only=True) + else: + providers = [get_provider_by_id(pid) for pid in provider_ids if get_provider_by_id(pid)] + + for provider in providers: + try: + provider_type = provider['provider_type'] + playlists = _get_playlists_for_provider_type(provider_type) + + for playlist in playlists: + playlist_name = playlist.get('Name') or playlist.get('name', '') + playlist_id = playlist.get('Id') or playlist.get('id', '') + + # Add provider info to playlist + playlist['provider_id'] = provider['id'] + playlist['provider_type'] = provider_type + playlist['provider_name'] = provider.get('name', provider_type) + + # Check for duplicates by name + if playlist_name in seen_names: + # Mark as duplicate + playlist['is_duplicate'] = True + playlist['duplicate_of_provider'] = seen_names[playlist_name] + else: + playlist['is_duplicate'] = False + seen_names[playlist_name] = provider['id'] + + all_playlists.append(playlist) + + except Exception as e: + logger.warning(f"Failed to get playlists from provider {provider.get('name', 'unknown')}: {e}") + continue + + return all_playlists + + +def _get_playlists_for_provider_type(provider_type): + """Get playlists for a specific provider type using current config.""" + if provider_type == 'jellyfin': + return jellyfin_get_all_playlists() + elif provider_type == 'navidrome': + return navidrome_get_all_playlists() + elif provider_type == 'lyrion': + return lyrion_get_all_playlists() + elif provider_type == 'mpd': + return mpd_get_all_playlists() + elif provider_type == 'emby': + return emby_get_all_playlists() + elif provider_type == 'localfiles': 
+ return localfiles_get_all_playlists() + return [] + + +def create_playlist_multi_provider(playlist_name, item_ids, provider_ids=None, user_creds=None): + """ + Create a playlist on one or more providers. + + Args: + playlist_name: Name of the playlist to create + item_ids: List of track IDs to add + provider_ids: List of provider IDs to create playlist on, + 'all' for all enabled providers, + or None for the primary/default provider + user_creds: Optional user credentials for providers that support them + + Returns: + Dict with results for each provider: {provider_id: {'success': bool, 'playlist_id': str, 'error': str}} + """ + from app_helper import get_providers, get_provider_by_id, get_primary_provider_id + + if not playlist_name: + raise ValueError("Playlist name is required") + if not item_ids: + raise ValueError("Track IDs are required") + + results = {} + + # Determine which providers to use + if provider_ids == 'all': + providers = get_providers(enabled_only=True) + elif provider_ids is None: + # Use primary provider or fall back to current config + primary_id = get_primary_provider_id() + if primary_id: + provider = get_provider_by_id(primary_id) + providers = [provider] if provider else [] + else: + # Fall back to creating on current configured provider + try: + created = create_instant_playlist(playlist_name, item_ids, user_creds=user_creds) + return {'default': {'success': True, 'playlist_id': created.get('Id') if created else None}} + except Exception as e: + return {'default': {'success': False, 'error': str(e)}} + else: + # Specific provider IDs + if isinstance(provider_ids, (list, tuple)): + providers = [get_provider_by_id(pid) for pid in provider_ids if get_provider_by_id(pid)] + else: + provider = get_provider_by_id(provider_ids) + providers = [provider] if provider else [] + + # Create playlist on each provider + for provider in providers: + provider_id = provider['id'] + provider_type = provider['provider_type'] + + try: + # For now, use the 
dispatcher which uses current config + # In the future, we may want provider-specific config + created = _create_playlist_for_provider_type(provider_type, playlist_name, item_ids, user_creds) + + results[provider_id] = { + 'success': True, + 'playlist_id': created.get('Id') or created.get('id') if created else None, + 'provider_name': provider.get('name', provider_type) + } + except Exception as e: + logger.error(f"Failed to create playlist on provider {provider.get('name')}: {e}") + results[provider_id] = { + 'success': False, + 'error': str(e), + 'provider_name': provider.get('name', provider_type) + } + + return results + + +def _create_playlist_for_provider_type(provider_type, playlist_name, item_ids, user_creds=None): + """Create playlist on a specific provider type.""" + if provider_type == 'jellyfin': + return jellyfin_create_instant_playlist(playlist_name, item_ids, user_creds) + elif provider_type == 'navidrome': + return navidrome_create_instant_playlist(playlist_name, item_ids, user_creds) + elif provider_type == 'lyrion': + return lyrion_create_instant_playlist(playlist_name, item_ids) + elif provider_type == 'mpd': + return mpd_create_instant_playlist(playlist_name, item_ids, user_creds) + elif provider_type == 'emby': + return emby_create_instant_playlist(playlist_name, item_ids, user_creds) + elif provider_type == 'localfiles': + return localfiles_create_instant_playlist(playlist_name, item_ids, user_creds) + else: + raise ValueError(f"Unknown provider type: {provider_type}") + + +def get_enabled_providers_for_playlists(): + """ + Get list of enabled providers for use in playlist dropdowns. 
+ + Returns: + List of dicts with 'id', 'name', 'type' for each enabled provider + """ + from app_helper import get_providers + + providers = get_providers(enabled_only=True) + return [ + { + 'id': p['id'], + 'name': p.get('name') or p['provider_type'], + 'type': p['provider_type'] + } + for p in providers + ] diff --git a/tasks/voyager_manager.py b/tasks/voyager_manager.py index ac0f9909..30fedf14 100644 --- a/tasks/voyager_manager.py +++ b/tasks/voyager_manager.py @@ -1647,15 +1647,46 @@ def search_tracks_by_title_and_artist(title_query: str, artist_query: str, limit return results -def create_playlist_from_ids(playlist_name: str, track_ids: list, user_creds: dict = None): +def create_playlist_from_ids(playlist_name: str, track_ids: list, user_creds: dict = None, provider_ids=None): """ - Creates a new playlist on the configured media server with the provided name and track IDs. + Creates a new playlist on the configured media server(s) with the provided name and track IDs. + + Args: + playlist_name: Name of the playlist + track_ids: List of track IDs + user_creds: Optional user credentials + provider_ids: Provider(s) to create playlist on: + - None: Use primary provider or default config + - 'all': Create on all enabled providers + - int: Single provider ID + - list[int]: Multiple provider IDs + + Returns: + If single provider: playlist_id (str) + If multiple providers: dict of {provider_id: {'success': bool, 'playlist_id': str, 'error': str}} """ try: - # Use the mediaserver dispatcher (imported at module top) to create the playlist. - # This avoids importing app_external which may not export the helper. 
+ from tasks.mediaserver import create_playlist_multi_provider + + # Use multi-provider function if provider_ids is specified as 'all' or a list + if provider_ids == 'all' or isinstance(provider_ids, (list, tuple)): + return create_playlist_multi_provider(playlist_name, track_ids, provider_ids, user_creds) + + # Single provider specified + if provider_ids is not None: + results = create_playlist_multi_provider(playlist_name, track_ids, provider_ids, user_creds) + # Extract single result + if results: + result = list(results.values())[0] + if result.get('success'): + return result.get('playlist_id') + else: + raise Exception(result.get('error', 'Playlist creation failed')) + raise Exception("No provider found") + + # Default: use existing single-provider logic for backward compatibility created_playlist = create_instant_playlist(playlist_name, track_ids, user_creds=user_creds) - + if not created_playlist: raise Exception("Playlist creation failed. The media server did not return a playlist object.") diff --git a/templates/alchemy.html b/templates/alchemy.html index 0286d19a..4aca6d41 100644 --- a/templates/alchemy.html +++ b/templates/alchemy.html @@ -143,6 +143,7 @@

    Create a Playlist from Results

    +
    @@ -150,8 +151,11 @@

    Create a Playlist from Results

    {% endblock %} {% block bodyAdditions %} + + {% endblock %} From 0342a9d7d6ab8bb7efe0b1b73645f540d9f3e369 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 5 Feb 2026 09:26:44 +0000 Subject: [PATCH 14/33] Add comprehensive testing and comparison suite for dual-instance comparison Builds a complete testing framework that connects to two AudioMuse-AI instances (e.g., main vs feature branch) via API, PostgreSQL, Docker, and compares results across all dimensions: - Database comparator: schema validation (17 tables), row counts, data quality (NULL rates, duplicates, mood_vector format), embedding integrity (dimensions, coverage, NaN checks), referential integrity, score distributions, playlist quality, index presence, task health, provider config, and app settings comparison - API comparator: tests 30+ endpoints including config, playlists, search, similarity, map, CLAP, alchemy, path finding, sonic fingerprint, artist similarity, setup/providers, cron, external API, and error handling - comparing status codes, response shapes, keys, and list lengths between instances - Docker comparator: container health/status, restart counts, resource usage (memory/CPU), log error pattern analysis (tracebacks, OOM, timeouts, DB errors), warning detection, and service connectivity tests for Redis and PostgreSQL - Performance comparator: endpoint latency benchmarks (p50/p95/p99/mean) with warmup, concurrent load testing with configurable users, database query performance benchmarks for 8 critical queries - Existing test integration: discovers and runs all 17 unit test files, 2 integration tests, and 8 E2E API tests from the existing test suite, with per-file results and instance-specific E2E execution - HTML report generator: self-contained dark-themed report with status badges, per-category expandable sections, filterable tables, side-by- side instance comparison, and visual performance bar charts - CLI with full argument support, YAML config files, environment variable 
configuration, test category selection (--only/--skip), and --discover mode listing all 27 available tests https://claude.ai/code/session_0122SF3fSXM3e2dNqaJB5NDn --- testing_suite/__init__.py | 9 + testing_suite/__main__.py | 5 + testing_suite/comparators/__init__.py | 0 testing_suite/comparators/api_comparator.py | 760 +++++++++++ testing_suite/comparators/db_comparator.py | 1142 +++++++++++++++++ .../comparators/docker_comparator.py | 541 ++++++++ .../comparators/performance_comparator.py | 447 +++++++ testing_suite/comparison_config.example.yaml | 91 ++ testing_suite/config.py | 211 +++ testing_suite/orchestrator.py | 163 +++ testing_suite/reports/__init__.py | 0 testing_suite/reports/html_report.py | 359 ++++++ testing_suite/requirements.txt | 7 + testing_suite/run_comparison.py | 325 +++++ testing_suite/test_runner/__init__.py | 0 testing_suite/test_runner/existing_tests.py | 469 +++++++ testing_suite/utils.py | 434 +++++++ 17 files changed, 4963 insertions(+) create mode 100644 testing_suite/__init__.py create mode 100644 testing_suite/__main__.py create mode 100644 testing_suite/comparators/__init__.py create mode 100644 testing_suite/comparators/api_comparator.py create mode 100644 testing_suite/comparators/db_comparator.py create mode 100644 testing_suite/comparators/docker_comparator.py create mode 100644 testing_suite/comparators/performance_comparator.py create mode 100644 testing_suite/comparison_config.example.yaml create mode 100644 testing_suite/config.py create mode 100644 testing_suite/orchestrator.py create mode 100644 testing_suite/reports/__init__.py create mode 100644 testing_suite/reports/html_report.py create mode 100644 testing_suite/requirements.txt create mode 100644 testing_suite/run_comparison.py create mode 100644 testing_suite/test_runner/__init__.py create mode 100644 testing_suite/test_runner/existing_tests.py create mode 100644 testing_suite/utils.py diff --git a/testing_suite/__init__.py b/testing_suite/__init__.py new file mode 
100644 index 00000000..9f60c65e --- /dev/null +++ b/testing_suite/__init__.py @@ -0,0 +1,9 @@ +# AudioMuse-AI Testing & Comparison Suite +# Compares two instances (e.g., main branch vs feature branch) across: +# - API endpoints (results, response shapes, status codes) +# - Database quality (schema, data integrity, embeddings, track counts) +# - Docker container health and logs +# - Performance benchmarks (latency, throughput) +# - Existing unit and integration tests + +__version__ = "1.0.0" diff --git a/testing_suite/__main__.py b/testing_suite/__main__.py new file mode 100644 index 00000000..018ad6bd --- /dev/null +++ b/testing_suite/__main__.py @@ -0,0 +1,5 @@ +"""Allow running the testing suite as a module: python -m testing_suite""" +from testing_suite.run_comparison import main +import sys + +sys.exit(main()) diff --git a/testing_suite/comparators/__init__.py b/testing_suite/comparators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/testing_suite/comparators/api_comparator.py b/testing_suite/comparators/api_comparator.py new file mode 100644 index 00000000..fecb265d --- /dev/null +++ b/testing_suite/comparators/api_comparator.py @@ -0,0 +1,760 @@ +""" +API Comparison Module for AudioMuse-AI Testing Suite. 
+ +Tests all API endpoints on both instances and compares: + - HTTP status codes + - Response shapes (keys, types, list lengths) + - Response content (values where deterministic) + - Error handling and edge cases + - Endpoint availability + - Task lifecycle (start -> poll -> success) +""" + +import json +import logging +import time +import warnings +from typing import Any, Dict, List, Optional, Tuple + +from testing_suite.config import ComparisonConfig, InstanceConfig +from testing_suite.utils import ( + ComparisonReport, TestResult, TestStatus, + http_get, http_post, timed_request, wait_for_task_success, pct_diff +) + +logger = logging.getLogger(__name__) + + +class APIComparator: + """Tests and compares API endpoints across two AudioMuse-AI instances.""" + + def __init__(self, config: ComparisonConfig): + self.config = config + self.url_a = config.instance_a.api_url.rstrip('/') + self.url_b = config.instance_b.api_url.rstrip('/') + self.name_a = config.instance_a.name + self.name_b = config.instance_b.name + self.timeout = config.api_timeout + self.retries = config.api_retries + self.retry_delay = config.api_retry_delay + + def run_all(self, report: ComparisonReport): + """Run all API comparison tests.""" + logger.info("Starting API comparison tests...") + + # Check connectivity first + alive_a = self._check_alive(self.url_a) + alive_b = self._check_alive(self.url_b) + + report.add_result(TestResult( + category="api", + name="Instance A connectivity", + status=TestStatus.PASS if alive_a else TestStatus.ERROR, + message=f"{self.url_a}: {'reachable' if alive_a else 'unreachable'}", + instance_a_value=alive_a, + )) + report.add_result(TestResult( + category="api", + name="Instance B connectivity", + status=TestStatus.PASS if alive_b else TestStatus.ERROR, + message=f"{self.url_b}: {'reachable' if alive_b else 'unreachable'}", + instance_b_value=alive_b, + )) + + if not alive_a and not alive_b: + report.add_result(TestResult( + category="api", + name="API Tests", + 
status=TestStatus.ERROR, + message="Neither instance is reachable; skipping all API tests", + )) + return + + # Run endpoint tests + self._test_config_endpoint(report, alive_a, alive_b) + self._test_playlists_endpoint(report, alive_a, alive_b) + self._test_active_tasks_endpoint(report, alive_a, alive_b) + self._test_last_task_endpoint(report, alive_a, alive_b) + self._test_search_tracks_endpoint(report, alive_a, alive_b) + self._test_similar_tracks_endpoint(report, alive_a, alive_b) + self._test_max_distance_endpoint(report, alive_a, alive_b) + self._test_map_endpoint(report, alive_a, alive_b) + self._test_map_cache_status(report, alive_a, alive_b) + self._test_clap_stats(report, alive_a, alive_b) + self._test_clap_warmup_status(report, alive_a, alive_b) + self._test_clap_top_queries(report, alive_a, alive_b) + self._test_setup_status(report, alive_a, alive_b) + self._test_setup_providers(report, alive_a, alive_b) + self._test_setup_settings(report, alive_a, alive_b) + self._test_setup_server_info(report, alive_a, alive_b) + self._test_provider_types(report, alive_a, alive_b) + self._test_providers_enabled(report, alive_a, alive_b) + self._test_cron_entries(report, alive_a, alive_b) + self._test_waveform_endpoint(report, alive_a, alive_b) + self._test_find_path_endpoint(report, alive_a, alive_b) + self._test_sonic_fingerprint(report, alive_a, alive_b) + self._test_alchemy_endpoint(report, alive_a, alive_b) + self._test_artist_projections(report, alive_a, alive_b) + self._test_search_artists(report, alive_a, alive_b) + self._test_external_search(report, alive_a, alive_b) + self._test_chat_config_defaults(report, alive_a, alive_b) + self._test_error_handling(report, alive_a, alive_b) + self._test_collection_last_task(report, alive_a, alive_b) + + logger.info("API comparison tests complete.") + + # ------------------------------------------------------------------ + # Connectivity + # ------------------------------------------------------------------ + + def 
_check_alive(self, url: str) -> bool: + """Check if an instance is reachable.""" + try: + resp = http_get(f"{url}/api/config", timeout=15, retries=2, retry_delay=1) + return resp.status_code == 200 + except Exception: + return False + + # ------------------------------------------------------------------ + # Helper: compare GET endpoint on both instances + # ------------------------------------------------------------------ + + def _compare_get(self, report: ComparisonReport, path: str, test_name: str, + params: dict = None, alive_a: bool = True, alive_b: bool = True, + expected_status: int = 200, check_keys: list = None, + compare_list_length: bool = False): + """ + Hit a GET endpoint on both instances and compare the results. + Adds test results to the report. + """ + t0 = time.time() + resp_a = resp_b = None + data_a = data_b = None + + try: + if alive_a: + resp_a, lat_a = timed_request("GET", f"{self.url_a}{path}", + params=params, timeout=self.timeout, + retries=self.retries, retry_delay=self.retry_delay) + if alive_b: + resp_b, lat_b = timed_request("GET", f"{self.url_b}{path}", + params=params, timeout=self.timeout, + retries=self.retries, retry_delay=self.retry_delay) + except Exception as e: + report.add_result(TestResult( + category="api", + name=f"{test_name}: request", + status=TestStatus.ERROR, + message=str(e), + duration_seconds=time.time() - t0, + )) + return None, None + + # Status code comparison + status_a = resp_a.status_code if resp_a else None + status_b = resp_b.status_code if resp_b else None + + if alive_a and alive_b: + if status_a == expected_status and status_b == expected_status: + status = TestStatus.PASS + msg = f"Both returned {expected_status}" + elif status_a == status_b: + status = TestStatus.WARN + msg = f"Both returned {status_a} (expected {expected_status})" + else: + status = TestStatus.FAIL + msg = f"Status codes differ: A={status_a}, B={status_b}" + + report.add_result(TestResult( + category="api", + name=f"{test_name}: 
status code", + status=status, + message=msg, + instance_a_value=status_a, + instance_b_value=status_b, + duration_seconds=time.time() - t0, + details={"latency_a": lat_a if alive_a else None, + "latency_b": lat_b if alive_b else None}, + )) + + # Parse JSON response + try: + if resp_a and resp_a.status_code == expected_status: + data_a = resp_a.json() + if resp_b and resp_b.status_code == expected_status: + data_b = resp_b.json() + except Exception as e: + report.add_result(TestResult( + category="api", + name=f"{test_name}: JSON parse", + status=TestStatus.ERROR, + message=f"JSON parse error: {e}", + duration_seconds=time.time() - t0, + )) + return data_a, data_b + + # Key comparison (if both have JSON data) + if data_a is not None and data_b is not None: + if isinstance(data_a, dict) and isinstance(data_b, dict): + keys_a = set(data_a.keys()) + keys_b = set(data_b.keys()) + if keys_a == keys_b: + key_status = TestStatus.PASS + key_msg = f"Same keys: {sorted(keys_a)}" + else: + key_status = TestStatus.FAIL + missing_b = keys_a - keys_b + missing_a = keys_b - keys_a + key_msg = f"Keys differ: only_A={missing_b}, only_B={missing_a}" + + report.add_result(TestResult( + category="api", + name=f"{test_name}: response shape", + status=key_status, + message=key_msg, + instance_a_value=sorted(keys_a), + instance_b_value=sorted(keys_b), + duration_seconds=time.time() - t0, + )) + + if isinstance(data_a, list) and isinstance(data_b, list): + if compare_list_length: + len_a = len(data_a) + len_b = len(data_b) + if len_a == len_b: + l_status = TestStatus.PASS + l_msg = f"Same list length: {len_a}" + else: + diff = pct_diff(len_a, len_b) + l_status = TestStatus.WARN if diff <= 20 else TestStatus.FAIL + l_msg = f"List lengths differ: A={len_a}, B={len_b} ({diff:.1f}%)" + + report.add_result(TestResult( + category="api", + name=f"{test_name}: list length", + status=l_status, + message=l_msg, + instance_a_value=len_a, + instance_b_value=len_b, + duration_seconds=time.time() - 
t0, + )) + + # Check specific keys exist + if check_keys and isinstance(data_a, dict) and isinstance(data_b, dict): + for key in check_keys: + has_a = key in data_a + has_b = key in data_b + if has_a and has_b: + kstatus = TestStatus.PASS + elif has_a or has_b: + kstatus = TestStatus.FAIL + else: + kstatus = TestStatus.WARN + + report.add_result(TestResult( + category="api", + name=f"{test_name}: key '{key}'", + status=kstatus, + message=f"A has '{key}': {has_a}, B has '{key}': {has_b}", + duration_seconds=time.time() - t0, + )) + + return data_a, data_b + + def _compare_post(self, report: ComparisonReport, path: str, test_name: str, + json_data: dict = None, alive_a: bool = True, alive_b: bool = True, + expected_status: int = 200, check_keys: list = None): + """Hit a POST endpoint on both instances and compare results.""" + t0 = time.time() + resp_a = resp_b = None + data_a = data_b = None + + try: + if alive_a: + resp_a, lat_a = timed_request("POST", f"{self.url_a}{path}", + json_data=json_data, timeout=self.timeout, + retries=self.retries, retry_delay=self.retry_delay) + if alive_b: + resp_b, lat_b = timed_request("POST", f"{self.url_b}{path}", + json_data=json_data, timeout=self.timeout, + retries=self.retries, retry_delay=self.retry_delay) + except Exception as e: + report.add_result(TestResult( + category="api", + name=f"{test_name}: request", + status=TestStatus.ERROR, + message=str(e), + duration_seconds=time.time() - t0, + )) + return None, None + + status_a = resp_a.status_code if resp_a else None + status_b = resp_b.status_code if resp_b else None + + if alive_a and alive_b: + if status_a == expected_status and status_b == expected_status: + status = TestStatus.PASS + msg = f"Both returned {expected_status}" + elif status_a == status_b: + status = TestStatus.WARN + msg = f"Both returned {status_a} (expected {expected_status})" + else: + status = TestStatus.FAIL + msg = f"Status codes differ: A={status_a}, B={status_b}" + + report.add_result(TestResult( + 
category="api", + name=f"{test_name}: status code", + status=status, + message=msg, + instance_a_value=status_a, + instance_b_value=status_b, + duration_seconds=time.time() - t0, + details={"latency_a": lat_a if alive_a else None, + "latency_b": lat_b if alive_b else None}, + )) + + try: + if resp_a and resp_a.status_code == expected_status: + data_a = resp_a.json() + if resp_b and resp_b.status_code == expected_status: + data_b = resp_b.json() + except Exception: + pass + + if data_a is not None and data_b is not None and isinstance(data_a, dict) and isinstance(data_b, dict): + keys_a = set(data_a.keys()) + keys_b = set(data_b.keys()) + if keys_a == keys_b: + report.add_result(TestResult( + category="api", name=f"{test_name}: response shape", + status=TestStatus.PASS, message=f"Same keys: {sorted(keys_a)}", + duration_seconds=time.time() - t0, + )) + else: + report.add_result(TestResult( + category="api", name=f"{test_name}: response shape", + status=TestStatus.FAIL, + message=f"Keys differ: only_A={keys_a - keys_b}, only_B={keys_b - keys_a}", + duration_seconds=time.time() - t0, + )) + + if check_keys: + for key in check_keys: + has_a = key in data_a + has_b = key in data_b + report.add_result(TestResult( + category="api", name=f"{test_name}: key '{key}'", + status=TestStatus.PASS if (has_a and has_b) else TestStatus.FAIL, + message=f"A: {has_a}, B: {has_b}", + duration_seconds=time.time() - t0, + )) + + return data_a, data_b + + # ------------------------------------------------------------------ + # Individual endpoint tests + # ------------------------------------------------------------------ + + def _test_config_endpoint(self, report, alive_a, alive_b): + self._compare_get(report, "/api/config", "GET /api/config", + alive_a=alive_a, alive_b=alive_b) + + def _test_playlists_endpoint(self, report, alive_a, alive_b): + self._compare_get(report, "/api/playlists", "GET /api/playlists", + alive_a=alive_a, alive_b=alive_b, + compare_list_length=True) + + def 
_test_active_tasks_endpoint(self, report, alive_a, alive_b): + self._compare_get(report, "/api/active_tasks", "GET /api/active_tasks", + alive_a=alive_a, alive_b=alive_b) + + def _test_last_task_endpoint(self, report, alive_a, alive_b): + self._compare_get(report, "/api/last_task", "GET /api/last_task", + alive_a=alive_a, alive_b=alive_b) + + def _test_search_tracks_endpoint(self, report, alive_a, alive_b): + params = { + "artist": self.config.test_track_artist_1, + "title": self.config.test_track_title_1, + } + data_a, data_b = self._compare_get( + report, "/api/search_tracks", "GET /api/search_tracks", + params=params, alive_a=alive_a, alive_b=alive_b, + compare_list_length=True) + + # Validate response has expected track fields + t0 = time.time() + for label, data in [("A", data_a), ("B", data_b)]: + if data and isinstance(data, list) and data: + track = data[0] + expected = {"item_id", "title"} + present = expected.intersection(track.keys()) + if present == expected: + report.add_result(TestResult( + category="api", + name=f"GET /api/search_tracks: {label} track fields", + status=TestStatus.PASS, + message=f"Track has required fields: {expected}", + duration_seconds=time.time() - t0, + )) + else: + report.add_result(TestResult( + category="api", + name=f"GET /api/search_tracks: {label} track fields", + status=TestStatus.FAIL, + message=f"Missing fields: {expected - present}. 
Has: {set(track.keys())}", + duration_seconds=time.time() - t0, + )) + + def _test_similar_tracks_endpoint(self, report, alive_a, alive_b): + params = { + "title": self.config.test_track_title_1, + "artist": self.config.test_track_artist_1, + "n": 5, + } + data_a, data_b = self._compare_get( + report, "/api/similar_tracks", "GET /api/similar_tracks", + params=params, alive_a=alive_a, alive_b=alive_b, + compare_list_length=True) + + # Validate result tracks have item_id + t0 = time.time() + for label, data in [("A", data_a), ("B", data_b)]: + if data and isinstance(data, list): + has_ids = all('item_id' in t for t in data) + report.add_result(TestResult( + category="api", + name=f"GET /api/similar_tracks: {label} item_ids present", + status=TestStatus.PASS if has_ids else TestStatus.FAIL, + message=f"All tracks have item_id: {has_ids} ({len(data)} tracks)", + duration_seconds=time.time() - t0, + )) + + def _test_max_distance_endpoint(self, report, alive_a, alive_b): + self._compare_get(report, "/api/max_distance", "GET /api/max_distance", + alive_a=alive_a, alive_b=alive_b) + + def _test_map_endpoint(self, report, alive_a, alive_b): + data_a, data_b = self._compare_get( + report, "/api/map", "GET /api/map", + params={"percent": 10}, alive_a=alive_a, alive_b=alive_b, + check_keys=["items"]) + + # Validate items structure + t0 = time.time() + for label, data in [("A", data_a), ("B", data_b)]: + if data and isinstance(data, dict) and "items" in data: + items = data["items"] + if isinstance(items, list) and items: + report.add_result(TestResult( + category="api", + name=f"GET /api/map: {label} items non-empty", + status=TestStatus.PASS, + message=f"{len(items)} items returned", + duration_seconds=time.time() - t0, + )) + else: + report.add_result(TestResult( + category="api", + name=f"GET /api/map: {label} items non-empty", + status=TestStatus.WARN, + message=f"Empty items list", + duration_seconds=time.time() - t0, + )) + + def _test_map_cache_status(self, report, 
alive_a, alive_b): + self._compare_get(report, "/api/map_cache_status", "GET /api/map_cache_status", + alive_a=alive_a, alive_b=alive_b) + + def _test_clap_stats(self, report, alive_a, alive_b): + self._compare_get(report, "/api/clap/stats", "GET /api/clap/stats", + alive_a=alive_a, alive_b=alive_b) + + def _test_clap_warmup_status(self, report, alive_a, alive_b): + self._compare_get(report, "/api/clap/warmup/status", "GET /api/clap/warmup/status", + alive_a=alive_a, alive_b=alive_b) + + def _test_clap_top_queries(self, report, alive_a, alive_b): + self._compare_get(report, "/api/clap/top_queries", "GET /api/clap/top_queries", + alive_a=alive_a, alive_b=alive_b, + compare_list_length=True) + + def _test_setup_status(self, report, alive_a, alive_b): + self._compare_get(report, "/api/setup/status", "GET /api/setup/status", + alive_a=alive_a, alive_b=alive_b) + + def _test_setup_providers(self, report, alive_a, alive_b): + self._compare_get(report, "/api/setup/providers", "GET /api/setup/providers", + alive_a=alive_a, alive_b=alive_b, + compare_list_length=True) + + def _test_setup_settings(self, report, alive_a, alive_b): + self._compare_get(report, "/api/setup/settings", "GET /api/setup/settings", + alive_a=alive_a, alive_b=alive_b) + + def _test_setup_server_info(self, report, alive_a, alive_b): + self._compare_get(report, "/api/setup/server-info", "GET /api/setup/server-info", + alive_a=alive_a, alive_b=alive_b) + + def _test_provider_types(self, report, alive_a, alive_b): + self._compare_get(report, "/api/setup/providers/types", "GET /api/setup/providers/types", + alive_a=alive_a, alive_b=alive_b) + + def _test_providers_enabled(self, report, alive_a, alive_b): + self._compare_get(report, "/api/providers/enabled", "GET /api/providers/enabled", + alive_a=alive_a, alive_b=alive_b) + + def _test_cron_entries(self, report, alive_a, alive_b): + self._compare_get(report, "/api/cron", "GET /api/cron", + alive_a=alive_a, alive_b=alive_b, + compare_list_length=True) + + 
def _test_waveform_endpoint(self, report, alive_a, alive_b): + # Waveform needs a track query param - test without to verify error handling + self._compare_get(report, "/api/waveform", "GET /api/waveform (no params)", + alive_a=alive_a, alive_b=alive_b, + expected_status=400) + + def _test_find_path_endpoint(self, report, alive_a, alive_b): + """Test /api/find_path by first finding two track IDs.""" + t0 = time.time() + try: + # Find track IDs from both instances + id_a_start = self._find_track_id(self.url_a, + self.config.test_track_artist_1, + self.config.test_track_title_1) if alive_a else None + id_a_end = self._find_track_id(self.url_a, + self.config.test_track_artist_2, + self.config.test_track_title_2) if alive_a else None + id_b_start = self._find_track_id(self.url_b, + self.config.test_track_artist_1, + self.config.test_track_title_1) if alive_b else None + id_b_end = self._find_track_id(self.url_b, + self.config.test_track_artist_2, + self.config.test_track_title_2) if alive_b else None + + for label, url, start_id, end_id in [ + ("A", self.url_a, id_a_start, id_a_end), + ("B", self.url_b, id_b_start, id_b_end), + ]: + if not start_id or not end_id: + report.add_result(TestResult( + category="api", + name=f"GET /api/find_path ({label}): track lookup", + status=TestStatus.SKIP, + message="Could not find test tracks", + duration_seconds=time.time() - t0, + )) + continue + + resp, lat = timed_request("GET", f"{url}/api/find_path", + params={"start_song_id": start_id, + "end_song_id": end_id, + "max_steps": 10}, + timeout=self.timeout, retries=self.retries, + retry_delay=self.retry_delay) + + if resp.status_code == 200: + data = resp.json() + path = data.get('path', data) if isinstance(data, dict) else data + path_len = len(path) if isinstance(path, list) else 0 + report.add_result(TestResult( + category="api", + name=f"GET /api/find_path ({label})", + status=TestStatus.PASS if path_len > 0 else TestStatus.WARN, + message=f"Path length: {path_len}, latency: 
{lat:.2f}s", + instance_a_value=path_len if label == "A" else None, + instance_b_value=path_len if label == "B" else None, + duration_seconds=time.time() - t0, + details={"latency": lat}, + )) + else: + report.add_result(TestResult( + category="api", + name=f"GET /api/find_path ({label})", + status=TestStatus.FAIL, + message=f"Status {resp.status_code}: {resp.text[:200]}", + duration_seconds=time.time() - t0, + )) + except Exception as e: + report.add_result(TestResult( + category="api", + name="GET /api/find_path", + status=TestStatus.ERROR, + message=str(e), + duration_seconds=time.time() - t0, + )) + + def _test_sonic_fingerprint(self, report, alive_a, alive_b): + """Test sonic fingerprint generation on both instances.""" + payload = {"n": 1, "jellyfin_user_identifier": "admin", "jellyfin_token": ""} + self._compare_post(report, "/api/sonic_fingerprint/generate", + "POST /api/sonic_fingerprint/generate", + json_data=payload, alive_a=alive_a, alive_b=alive_b) + + def _test_alchemy_endpoint(self, report, alive_a, alive_b): + """Test song alchemy (requires track IDs).""" + t0 = time.time() + try: + for label, url, alive in [("A", self.url_a, alive_a), ("B", self.url_b, alive_b)]: + if not alive: + continue + + add_id = self._find_track_id(url, self.config.test_track_artist_1, + self.config.test_track_title_1) + sub_id = self._find_track_id(url, self.config.test_track_artist_2, + self.config.test_track_title_2) + + if not add_id or not sub_id: + report.add_result(TestResult( + category="api", + name=f"POST /api/alchemy ({label}): track lookup", + status=TestStatus.SKIP, + message="Could not find test tracks for alchemy", + duration_seconds=time.time() - t0, + )) + continue + + payload = { + "items": [ + {"id": add_id, "op": "ADD"}, + {"id": sub_id, "op": "SUBTRACT"}, + ], + "n": 5, + "temperature": 1, + "subtract_distance": 0.2, + } + + resp, lat = timed_request("POST", f"{url}/api/alchemy", + json_data=payload, timeout=self.timeout, + retries=self.retries, 
retry_delay=self.retry_delay) + + if resp.status_code == 200: + data = resp.json() + expected_keys = {"results", "projection"} + has_keys = expected_keys.issubset(data.keys()) if isinstance(data, dict) else False + results_count = len(data.get("results", [])) if isinstance(data, dict) else 0 + + report.add_result(TestResult( + category="api", + name=f"POST /api/alchemy ({label})", + status=TestStatus.PASS if has_keys and results_count > 0 else TestStatus.WARN, + message=f"Has expected keys: {has_keys}, results: {results_count}, latency: {lat:.2f}s", + duration_seconds=time.time() - t0, + details={"latency": lat, "result_count": results_count}, + )) + else: + report.add_result(TestResult( + category="api", + name=f"POST /api/alchemy ({label})", + status=TestStatus.FAIL, + message=f"Status {resp.status_code}: {resp.text[:200]}", + duration_seconds=time.time() - t0, + )) + except Exception as e: + report.add_result(TestResult( + category="api", + name="POST /api/alchemy", + status=TestStatus.ERROR, + message=str(e), + duration_seconds=time.time() - t0, + )) + + def _test_artist_projections(self, report, alive_a, alive_b): + self._compare_get(report, "/api/artist_projections", "GET /api/artist_projections", + alive_a=alive_a, alive_b=alive_b) + + def _test_search_artists(self, report, alive_a, alive_b): + self._compare_get(report, "/api/search_artists", "GET /api/search_artists", + params={"q": "Red Hot"}, alive_a=alive_a, alive_b=alive_b, + compare_list_length=True) + + def _test_external_search(self, report, alive_a, alive_b): + self._compare_get(report, "/external/search", "GET /external/search", + params={"q": "piano"}, alive_a=alive_a, alive_b=alive_b, + compare_list_length=True) + + def _test_chat_config_defaults(self, report, alive_a, alive_b): + self._compare_get(report, "/chat/api/config_defaults", "GET /chat/api/config_defaults", + alive_a=alive_a, alive_b=alive_b) + + def _test_collection_last_task(self, report, alive_a, alive_b): + self._compare_get(report, 
"/api/collection/last_task", "GET /api/collection/last_task", + alive_a=alive_a, alive_b=alive_b) + + # ------------------------------------------------------------------ + # Error handling tests + # ------------------------------------------------------------------ + + def _test_error_handling(self, report, alive_a, alive_b): + """Test that both instances handle errors consistently.""" + error_cases = [ + ("/api/status/nonexistent_task_id_12345", "Nonexistent task status", 200), + ("/api/track", "Track without item_id", 400), + ("/api/similar_tracks", "Similar tracks without params", 400), + ] + + for path, desc, expected_status in error_cases: + t0 = time.time() + try: + resp_a = http_get(f"{self.url_a}{path}", timeout=15, retries=1) if alive_a else None + resp_b = http_get(f"{self.url_b}{path}", timeout=15, retries=1) if alive_b else None + + status_a = resp_a.status_code if resp_a else None + status_b = resp_b.status_code if resp_b else None + + if alive_a and alive_b: + if status_a == status_b: + report.add_result(TestResult( + category="api", + name=f"Error handling: {desc}", + status=TestStatus.PASS, + message=f"Consistent error codes: {status_a}", + instance_a_value=status_a, + instance_b_value=status_b, + duration_seconds=time.time() - t0, + )) + else: + report.add_result(TestResult( + category="api", + name=f"Error handling: {desc}", + status=TestStatus.WARN, + message=f"Different error codes: A={status_a}, B={status_b}", + instance_a_value=status_a, + instance_b_value=status_b, + duration_seconds=time.time() - t0, + )) + except Exception as e: + report.add_result(TestResult( + category="api", + name=f"Error handling: {desc}", + status=TestStatus.ERROR, + message=str(e), + duration_seconds=time.time() - t0, + )) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _find_track_id(self, base_url: str, artist: str, title: str) -> Optional[str]: + 
"""Find a track's item_id via the search API.""" + try: + resp = http_get(f"{base_url}/api/search_tracks", + params={"artist": artist, "title": title}, + timeout=30, retries=2, retry_delay=1) + if resp.status_code == 200: + results = resp.json() + if isinstance(results, list) and results: + # Try exact match first + for track in results: + track_artist = track.get("author") or track.get("artist") or "" + if track_artist.lower() == artist.lower() and \ + track.get("title", "").lower() == title.lower(): + return track["item_id"] + # Fallback to first result + return results[0].get("item_id") + except Exception as e: + logger.debug(f"Track search failed: {e}") + return None diff --git a/testing_suite/comparators/db_comparator.py b/testing_suite/comparators/db_comparator.py new file mode 100644 index 00000000..41d5e5ec --- /dev/null +++ b/testing_suite/comparators/db_comparator.py @@ -0,0 +1,1142 @@ +""" +Database Comparison Module for AudioMuse-AI Testing Suite. + +Compares two PostgreSQL instances across: + - Schema presence and structure (all expected tables and columns) + - Row counts and data volume + - Data quality (NULL rates, value distributions, outliers) + - Embedding integrity (dimensions, NaN checks, storage sizes) + - Index and constraint validation + - Cross-table referential integrity + - Score/analysis value distributions + - Playlist quality metrics +""" + +import json +import logging +import struct +import time +from typing import Any, Dict, List, Optional, Tuple + +from testing_suite.config import ComparisonConfig, InstanceConfig +from testing_suite.utils import ( + ComparisonReport, TestResult, TestStatus, + pg_query, pg_query_dict, pg_scalar, pct_diff +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Expected schema definition (ground truth) +# --------------------------------------------------------------------------- + +EXPECTED_TABLES = { + "score": [ + "item_id", 
"title", "author", "album", "album_artist", + "tempo", "key", "scale", "mood_vector", "energy", + "other_features", "year", "rating", "file_path", "track_id" + ], + "embedding": ["item_id", "embedding"], + "clap_embedding": ["item_id", "embedding"], + "playlist": ["id", "playlist_name", "item_id", "title", "author"], + "task_status": [ + "id", "task_id", "parent_task_id", "task_type", "sub_type_identifier", + "status", "progress", "details", "timestamp", "start_time", "end_time" + ], + "voyager_index_data": [ + "index_name", "index_data", "id_map_json", "embedding_dimension", "created_at" + ], + "artist_index_data": [ + "index_name", "index_data", "artist_map_json", "gmm_params_json", "created_at" + ], + "map_projection_data": [ + "index_name", "projection_data", "id_map_json", "embedding_dimension", "created_at" + ], + "artist_component_projection": [ + "index_name", "projection_data", "artist_component_map_json", "created_at" + ], + "cron": ["id", "name", "task_type", "cron_expr", "enabled", "last_run", "created_at"], + "artist_mapping": ["artist_name", "artist_id"], + "text_search_queries": ["id", "query_text", "score", "rank", "created_at"], + "provider": [ + "id", "provider_type", "name", "config", "enabled", + "priority", "created_at", "updated_at" + ], + "track": [ + "id", "file_path_hash", "file_path", "normalized_path", + "file_size", "file_modified", "created_at", "updated_at" + ], + "provider_track": [ + "id", "provider_id", "track_id", "item_id", + "title", "artist", "album", "last_synced" + ], + "app_settings": ["key", "value", "category", "description", "updated_at"], +} + +# Critical columns that should not be NULL in the score table +SCORE_CRITICAL_COLUMNS = ["item_id", "title", "author", "tempo", "key", "scale", "mood_vector"] + +# Columns to check for statistical distribution in score +SCORE_NUMERIC_COLUMNS = ["tempo", "energy"] + + +def _safe_dsn_connect(dsn: str, instance_name: str) -> bool: + """Test if we can connect to the database.""" + try: 
+ pg_scalar(dsn, "SELECT 1") + return True + except Exception as e: + logger.warning(f"Cannot connect to {instance_name} database: {e}") + return False + + +class DatabaseComparator: + """Compares two PostgreSQL database instances.""" + + def __init__(self, config: ComparisonConfig): + self.config = config + self.dsn_a = config.instance_a.pg_dsn + self.dsn_b = config.instance_b.pg_dsn + self.name_a = config.instance_a.name + self.name_b = config.instance_b.name + + def run_all(self, report: ComparisonReport): + """Run all database comparison tests and add results to report.""" + logger.info("Starting database comparison tests...") + + # Check connectivity first + can_a = _safe_dsn_connect(self.dsn_a, self.name_a) + can_b = _safe_dsn_connect(self.dsn_b, self.name_b) + + if not can_a and not can_b: + report.add_result(TestResult( + category="database", + name="DB Connectivity", + status=TestStatus.ERROR, + message="Cannot connect to either database instance" + )) + return + + if not can_a or not can_b: + report.add_result(TestResult( + category="database", + name="DB Connectivity", + status=TestStatus.WARN, + message=f"Only connected to {'A' if can_a else 'B'} instance" + )) + + # Run test suites + if can_a and can_b: + self._test_schema_comparison(report) + self._test_row_counts(report) + self._test_data_quality(report) + self._test_embedding_integrity(report) + self._test_referential_integrity(report) + self._test_score_distributions(report) + self._test_playlist_quality(report) + self._test_index_data_presence(report) + self._test_task_status_health(report) + self._test_provider_config(report) + self._test_app_settings(report) + elif can_a or can_b: + # Single-instance validation + dsn = self.dsn_a if can_a else self.dsn_b + name = self.name_a if can_a else self.name_b + self._test_single_instance_schema(report, dsn, name) + self._test_single_instance_quality(report, dsn, name) + + logger.info("Database comparison tests complete.") + + # 
    # ------------------------------------------------------------------
    # Schema comparison
    # ------------------------------------------------------------------

    def _test_schema_comparison(self, report: ComparisonReport):
        """Compare table existence and column structure between instances.

        For every table in EXPECTED_TABLES, emits one result per table:
        SKIP when the table is absent from both sides, FAIL when it exists
        on only one side or the column sets differ, WARN when the columns
        match each other but some expected columns are missing, PASS otherwise.
        """
        for table_name, expected_cols in EXPECTED_TABLES.items():
            t0 = time.time()
            try:
                # _get_table_columns is expected to return None when the
                # table does not exist (vs. [] for an empty column list).
                cols_a = self._get_table_columns(self.dsn_a, table_name)
                cols_b = self._get_table_columns(self.dsn_b, table_name)

                table_exists_a = cols_a is not None
                table_exists_b = cols_b is not None

                if not table_exists_a and not table_exists_b:
                    # Optional tables like mulan_embedding may not exist
                    report.add_result(TestResult(
                        category="database",
                        name=f"Schema: {table_name} existence",
                        status=TestStatus.SKIP,
                        message=f"Table '{table_name}' does not exist in either instance",
                        duration_seconds=time.time() - t0,
                    ))
                    continue

                if table_exists_a != table_exists_b:
                    report.add_result(TestResult(
                        category="database",
                        name=f"Schema: {table_name} existence",
                        status=TestStatus.FAIL,
                        message=f"Table '{table_name}' exists in {'A only' if table_exists_a else 'B only'}",
                        instance_a_value=table_exists_a,
                        instance_b_value=table_exists_b,
                        duration_seconds=time.time() - t0,
                    ))
                    continue

                # Compare columns (A vs B, order-insensitive)
                set_a = set(cols_a)
                set_b = set(cols_b)
                missing_in_b = set_a - set_b
                missing_in_a = set_b - set_a

                if set_a == set_b:
                    status = TestStatus.PASS
                    msg = f"Columns match ({len(set_a)} columns)"
                else:
                    status = TestStatus.FAIL
                    msg = f"Column mismatch: missing_in_B={missing_in_b}, missing_in_A={missing_in_a}"

                # Also check against expected columns; only downgrades
                # PASS to WARN, never upgrades a FAIL.
                expected_set = set(expected_cols)
                missing_expected_a = expected_set - set_a
                missing_expected_b = expected_set - set_b

                if missing_expected_a or missing_expected_b:
                    if status == TestStatus.PASS:
                        status = TestStatus.WARN
                    msg += f" | Expected cols missing: A={missing_expected_a or 'none'}, B={missing_expected_b or 'none'}"

                report.add_result(TestResult(
                    category="database",
                    name=f"Schema: {table_name} columns",
                    status=status,
                    message=msg,
                    instance_a_value=sorted(set_a),
                    instance_b_value=sorted(set_b),
                    duration_seconds=time.time() - t0,
                ))
            except Exception as e:
                report.add_result(TestResult(
                    category="database",
                    name=f"Schema: {table_name}",
                    status=TestStatus.ERROR,
                    message=str(e),
                    duration_seconds=time.time() - t0,
                ))

    # ------------------------------------------------------------------
    # Row counts
    # ------------------------------------------------------------------

    def _test_row_counts(self, report: ComparisonReport):
        """Compare row counts across all tables.

        Equal counts PASS; a relative difference within
        config.db_row_count_tolerance_pct is a WARN, beyond it a FAIL.
        """
        tables_to_count = [
            "score", "embedding", "clap_embedding", "playlist",
            "task_status", "voyager_index_data", "artist_index_data",
            "map_projection_data", "cron", "artist_mapping",
            "text_search_queries", "provider", "track", "provider_track",
            "app_settings"
        ]
        for table_name in tables_to_count:
            t0 = time.time()
            try:
                # _safe_count returns None when the table does not exist.
                count_a = self._safe_count(self.dsn_a, table_name)
                count_b = self._safe_count(self.dsn_b, table_name)

                if count_a is None and count_b is None:
                    report.add_result(TestResult(
                        category="database",
                        name=f"Row Count: {table_name}",
                        status=TestStatus.SKIP,
                        message="Table does not exist in either instance",
                        duration_seconds=time.time() - t0,
                    ))
                    continue

                # Guard against both-zero to avoid a 0/0 percentage.
                diff_pct = pct_diff(count_a or 0, count_b or 0) if (count_a or count_b) else 0

                if count_a == count_b:
                    status = TestStatus.PASS
                    msg = f"Both have {count_a} rows"
                elif diff_pct <= self.config.db_row_count_tolerance_pct:
                    status = TestStatus.WARN
                    msg = f"A={count_a}, B={count_b} (diff {diff_pct:.1f}% within tolerance)"
                else:
                    status = TestStatus.FAIL
                    msg = f"A={count_a}, B={count_b} (diff {diff_pct:.1f}% exceeds {self.config.db_row_count_tolerance_pct}%)"

                report.add_result(TestResult(
                    category="database",
                    name=f"Row Count: {table_name}",
                    status=status,
                    message=msg,
                    instance_a_value=count_a,
                    instance_b_value=count_b,
                    diff=diff_pct,
                    duration_seconds=time.time() - t0,
                ))
            except Exception as e:
                report.add_result(TestResult(
                    category="database",
                    name=f"Row Count: {table_name}",
                    status=TestStatus.ERROR,
                    message=str(e),
                    duration_seconds=time.time() - t0,
                ))

    # ------------------------------------------------------------------
    # Data quality checks
    # ------------------------------------------------------------------

    def _test_data_quality(self, report: ComparisonReport):
        """Check NULL rates and data quality in the score table.

        Covers: NULL percentage of each critical score column, duplicate
        item_ids, and mood_vector format validity ("label:value" pairs).
        """
        for col in SCORE_CRITICAL_COLUMNS:
            t0 = time.time()
            try:
                # _null_percentage is expected to return None when the
                # table/column is unavailable on that side.
                null_pct_a = self._null_percentage(self.dsn_a, "score", col)
                null_pct_b = self._null_percentage(self.dsn_b, "score", col)

                if null_pct_a is None and null_pct_b is None:
                    continue

                threshold = self.config.db_score_null_threshold_pct
                problems = []
                if null_pct_a is not None and null_pct_a > threshold:
                    problems.append(f"A has {null_pct_a:.1f}% NULLs")
                if null_pct_b is not None and null_pct_b > threshold:
                    problems.append(f"B has {null_pct_b:.1f}% NULLs")

                if problems:
                    status = TestStatus.FAIL
                    msg = f"score.{col}: " + "; ".join(problems) + f" (threshold {threshold}%)"
                else:
                    # NOTE(review): if exactly one side returned None, this
                    # f-string formats None with :.1f and raises TypeError,
                    # which the except below reports as ERROR - confirm
                    # whether that is intended.
                    status = TestStatus.PASS
                    msg = f"score.{col}: A={null_pct_a:.1f}% NULL, B={null_pct_b:.1f}% NULL (OK)"

                report.add_result(TestResult(
                    category="database",
                    name=f"Data Quality: score.{col} NULLs",
                    status=status,
                    message=msg,
                    instance_a_value=null_pct_a,
                    instance_b_value=null_pct_b,
                    duration_seconds=time.time() - t0,
                ))
            except Exception as e:
                report.add_result(TestResult(
                    category="database",
                    name=f"Data Quality: score.{col}",
                    status=TestStatus.ERROR,
                    message=str(e),
                    duration_seconds=time.time() - t0,
                ))

        # Check for duplicate item_ids in score
        t0 = time.time()
        try:
            dupes_a = pg_scalar(self.dsn_a,
                "SELECT COUNT(*) FROM (SELECT item_id FROM score GROUP BY item_id HAVING COUNT(*) > 1) sub")
            dupes_b = pg_scalar(self.dsn_b,
                "SELECT COUNT(*) FROM (SELECT item_id FROM score GROUP BY item_id HAVING COUNT(*) > 1) sub")

            if (dupes_a or 0) == 0 and (dupes_b or 0) == 0:
                status = TestStatus.PASS
                msg = "No duplicate item_ids in either instance"
            else:
                status = TestStatus.FAIL
                msg = f"Duplicate item_ids: A={dupes_a}, B={dupes_b}"

            report.add_result(TestResult(
                category="database",
                name="Data Quality: score duplicate item_ids",
                status=status,
                message=msg,
                instance_a_value=dupes_a,
                instance_b_value=dupes_b,
                duration_seconds=time.time() - t0,
            ))
        except Exception as e:
            report.add_result(TestResult(
                category="database",
                name="Data Quality: score duplicates",
                status=TestStatus.ERROR,
                message=str(e),
                duration_seconds=time.time() - t0,
            ))

        # Check mood_vector format validity: a valid vector is expected to
        # contain at least one "label:value" pair (i.e. a colon).
        t0 = time.time()
        try:
            invalid_moods_a = pg_scalar(self.dsn_a, """
                SELECT COUNT(*) FROM score
                WHERE mood_vector IS NOT NULL
                AND mood_vector NOT LIKE '%:%'
            """)
            invalid_moods_b = pg_scalar(self.dsn_b, """
                SELECT COUNT(*) FROM score
                WHERE mood_vector IS NOT NULL
                AND mood_vector NOT LIKE '%:%'
            """)

            if (invalid_moods_a or 0) == 0 and (invalid_moods_b or 0) == 0:
                status = TestStatus.PASS
                msg = "All mood_vectors have valid format"
            else:
                status = TestStatus.WARN
                msg = f"Invalid mood_vector format: A={invalid_moods_a}, B={invalid_moods_b}"

            report.add_result(TestResult(
                category="database",
                name="Data Quality: mood_vector format",
                status=status,
                message=msg,
                instance_a_value=invalid_moods_a,
                instance_b_value=invalid_moods_b,
                duration_seconds=time.time() - t0,
            ))
        except Exception as e:
            report.add_result(TestResult(
                category="database",
                name="Data Quality: mood_vector format",
                status=TestStatus.ERROR,
                message=str(e),
                duration_seconds=time.time() - t0,
            ))
    # ------------------------------------------------------------------
    # Embedding integrity
    # ------------------------------------------------------------------

    def _test_embedding_integrity(self, report: ComparisonReport):
        """Check embedding dimensions, storage, and coverage.

        For each embedding table: coverage vs the score table (>=95% PASS,
        >=80% WARN, else FAIL), absence of NULL embeddings, and approximate
        dimension inferred from average byte length (float32 = 4 bytes/dim).
        """
        # NOTE(review): expected_dim is unpacked but never compared against
        # the approximate dimension below - confirm whether the dimension
        # check was meant to use it.
        for emb_table, expected_dim in [
            ("embedding", self.config.db_embedding_dimension_expected),
            ("clap_embedding", self.config.db_clap_dimension_expected),
        ]:
            t0 = time.time()
            try:
                count_a = self._safe_count(self.dsn_a, emb_table)
                count_b = self._safe_count(self.dsn_b, emb_table)
                score_count_a = self._safe_count(self.dsn_a, "score")
                score_count_b = self._safe_count(self.dsn_b, "score")

                if count_a is None and count_b is None:
                    report.add_result(TestResult(
                        category="database",
                        name=f"Embedding: {emb_table} existence",
                        status=TestStatus.SKIP,
                        message=f"Table {emb_table} does not exist",
                        duration_seconds=time.time() - t0,
                    ))
                    continue

                # Coverage check.
                # NOTE(review): if the table exists on only one side,
                # count_a/count_b can be None here and the division raises
                # (reported as ERROR by the except) - confirm intent.
                coverage_a = (count_a / score_count_a * 100) if score_count_a else 0
                coverage_b = (count_b / score_count_b * 100) if score_count_b else 0

                if coverage_a >= 95 and coverage_b >= 95:
                    status = TestStatus.PASS
                elif coverage_a >= 80 and coverage_b >= 80:
                    status = TestStatus.WARN
                else:
                    status = TestStatus.FAIL

                report.add_result(TestResult(
                    category="database",
                    name=f"Embedding: {emb_table} coverage",
                    status=status,
                    message=f"Coverage: A={coverage_a:.1f}% ({count_a}/{score_count_a}), "
                            f"B={coverage_b:.1f}% ({count_b}/{score_count_b})",
                    instance_a_value=coverage_a,
                    instance_b_value=coverage_b,
                    duration_seconds=time.time() - t0,
                ))

                # NULL embedding check
                null_emb_a = pg_scalar(self.dsn_a,
                    f"SELECT COUNT(*) FROM {emb_table} WHERE embedding IS NULL")
                null_emb_b = pg_scalar(self.dsn_b,
                    f"SELECT COUNT(*) FROM {emb_table} WHERE embedding IS NULL")

                if (null_emb_a or 0) == 0 and (null_emb_b or 0) == 0:
                    status = TestStatus.PASS
                    msg = "No NULL embeddings"
                else:
                    status = TestStatus.FAIL
                    msg = f"NULL embeddings: A={null_emb_a}, B={null_emb_b}"

                report.add_result(TestResult(
                    category="database",
                    name=f"Embedding: {emb_table} NULL check",
                    status=status,
                    message=msg,
                    instance_a_value=null_emb_a,
                    instance_b_value=null_emb_b,
                    duration_seconds=time.time() - t0,
                ))

                # Average embedding size (proxy for dimension check)
                avg_size_a = pg_scalar(self.dsn_a,
                    f"SELECT AVG(octet_length(embedding)) FROM {emb_table} WHERE embedding IS NOT NULL")
                avg_size_b = pg_scalar(self.dsn_b,
                    f"SELECT AVG(octet_length(embedding)) FROM {emb_table} WHERE embedding IS NOT NULL")

                if avg_size_a and avg_size_b:
                    # float32 = 4 bytes per dimension
                    approx_dim_a = int(float(avg_size_a) / 4) if avg_size_a else 0
                    approx_dim_b = int(float(avg_size_b) / 4) if avg_size_b else 0

                    if approx_dim_a == approx_dim_b:
                        status = TestStatus.PASS
                    else:
                        status = TestStatus.FAIL

                    report.add_result(TestResult(
                        category="database",
                        name=f"Embedding: {emb_table} avg dimension",
                        status=status,
                        message=f"Approx dimensions: A~{approx_dim_a}, B~{approx_dim_b} "
                                f"(avg bytes: A={float(avg_size_a):.0f}, B={float(avg_size_b):.0f})",
                        instance_a_value=approx_dim_a,
                        instance_b_value=approx_dim_b,
                        duration_seconds=time.time() - t0,
                    ))

            except Exception as e:
                report.add_result(TestResult(
                    category="database",
                    name=f"Embedding: {emb_table}",
                    status=TestStatus.ERROR,
                    message=str(e),
                    duration_seconds=time.time() - t0,
                ))

    # ------------------------------------------------------------------
    # Referential integrity
    # ------------------------------------------------------------------

    def _test_referential_integrity(self, report: ComparisonReport):
        """Check foreign key relationships are intact.

        Verifies that embedding rows reference existing score rows and that
        provider_track rows reference existing provider rows; the latter is
        SKIPped when the multi-provider tables are absent.
        """
        # Embeddings should all reference valid score rows
        for emb_table in ["embedding", "clap_embedding"]:
            t0 = time.time()
            try:
                orphans_a = pg_scalar(self.dsn_a, f"""
                    SELECT COUNT(*) FROM {emb_table} e
                    LEFT JOIN score s ON e.item_id = s.item_id
                    WHERE s.item_id IS NULL
                """)
                orphans_b = pg_scalar(self.dsn_b, f"""
                    SELECT COUNT(*) FROM {emb_table} e
                    LEFT JOIN score s ON e.item_id = s.item_id
                    WHERE s.item_id IS NULL
                """)

                if orphans_a is None and orphans_b is None:
                    continue

                if (orphans_a or 0) == 0 and (orphans_b or 0) == 0:
                    status = TestStatus.PASS
                    msg = f"No orphaned rows in {emb_table}"
                else:
                    status = TestStatus.FAIL
                    msg = f"Orphaned {emb_table} rows: A={orphans_a}, B={orphans_b}"

                report.add_result(TestResult(
                    category="database",
                    name=f"Referential: {emb_table} -> score",
                    status=status,
                    message=msg,
                    instance_a_value=orphans_a,
                    instance_b_value=orphans_b,
                    duration_seconds=time.time() - t0,
                ))
            except Exception as e:
                report.add_result(TestResult(
                    category="database",
                    name=f"Referential: {emb_table} -> score",
                    status=TestStatus.ERROR,
                    message=str(e),
                    duration_seconds=time.time() - t0,
                ))

        # provider_track -> provider and track references
        t0 = time.time()
        try:
            orphan_provider_a = pg_scalar(self.dsn_a, """
                SELECT COUNT(*) FROM provider_track pt
                LEFT JOIN provider p ON pt.provider_id = p.id
                WHERE p.id IS NULL
            """)
            orphan_provider_b = pg_scalar(self.dsn_b, """
                SELECT COUNT(*) FROM provider_track pt
                LEFT JOIN provider p ON pt.provider_id = p.id
                WHERE p.id IS NULL
            """)

            if (orphan_provider_a or 0) == 0 and (orphan_provider_b or 0) == 0:
                status = TestStatus.PASS
                msg = "No orphaned provider_track -> provider rows"
            else:
                status = TestStatus.FAIL
                msg = f"Orphaned provider refs: A={orphan_provider_a}, B={orphan_provider_b}"

            report.add_result(TestResult(
                category="database",
                name="Referential: provider_track -> provider",
                status=status,
                message=msg,
                instance_a_value=orphan_provider_a,
                instance_b_value=orphan_provider_b,
                duration_seconds=time.time() - t0,
            ))
        except Exception as e:
            # Tables may not exist in some deployments; detect the Postgres
            # "relation ... does not exist" error by message text.
            if "does not exist" in str(e).lower() or "relation" in str(e).lower():
                report.add_result(TestResult(
                    category="database",
                    name="Referential: provider_track -> provider",
                    status=TestStatus.SKIP,
                    message="Multi-provider tables not present",
                    duration_seconds=time.time() - t0,
                ))
            else:
                report.add_result(TestResult(
                    category="database",
                    name="Referential: provider_track -> provider",
                    status=TestStatus.ERROR,
                    message=str(e),
                    duration_seconds=time.time() - t0,
                ))

    # ------------------------------------------------------------------
    # Score distributions
    # ------------------------------------------------------------------

    def _test_score_distributions(self, report: ComparisonReport):
        """Compare statistical distributions of score columns.

        Numeric columns: mean difference <=10% PASS, <=25% WARN, else FAIL.
        Key values: identical sets PASS, otherwise WARN with the symmetric
        difference.
        """
        for col in SCORE_NUMERIC_COLUMNS:
            t0 = time.time()
            try:
                stats_a = self._get_column_stats(self.dsn_a, "score", col)
                stats_b = self._get_column_stats(self.dsn_b, "score", col)

                if not stats_a or not stats_b:
                    continue

                # Compare means
                mean_diff = pct_diff(stats_a['avg'], stats_b['avg']) if stats_a['avg'] and stats_b['avg'] else 0

                if mean_diff <= 10:
                    status = TestStatus.PASS
                elif mean_diff <= 25:
                    status = TestStatus.WARN
                else:
                    status = TestStatus.FAIL

                report.add_result(TestResult(
                    category="database",
                    name=f"Distribution: score.{col}",
                    status=status,
                    message=(
                        f"A: min={stats_a['min']:.3f}, max={stats_a['max']:.3f}, "
                        f"avg={stats_a['avg']:.3f}, stddev={stats_a['stddev']:.3f} | "
                        f"B: min={stats_b['min']:.3f}, max={stats_b['max']:.3f}, "
                        f"avg={stats_b['avg']:.3f}, stddev={stats_b['stddev']:.3f} | "
                        f"Mean diff: {mean_diff:.1f}%"
                    ),
                    instance_a_value=stats_a,
                    instance_b_value=stats_b,
                    diff=mean_diff,
                    duration_seconds=time.time() - t0,
                ))
            except Exception as e:
                report.add_result(TestResult(
                    category="database",
                    name=f"Distribution: score.{col}",
                    status=TestStatus.ERROR,
                    message=str(e),
                    duration_seconds=time.time() - t0,
                ))

        # Key distribution comparison
        t0 = time.time()
        try:
            keys_a = pg_query_dict(self.dsn_a,
                "SELECT key, COUNT(*) as cnt FROM score WHERE key IS NOT NULL GROUP BY key ORDER BY cnt DESC")
            keys_b = pg_query_dict(self.dsn_b,
                "SELECT key, COUNT(*) as cnt FROM score WHERE key IS NOT NULL GROUP BY key ORDER BY cnt DESC")

            keys_set_a = set(r['key'] for r in keys_a)
            keys_set_b = set(r['key'] for r in keys_b)

            if keys_set_a == keys_set_b:
                status = TestStatus.PASS
                msg = f"Same key values detected ({len(keys_set_a)} keys)"
            else:
                status = TestStatus.WARN
                diff_keys = keys_set_a.symmetric_difference(keys_set_b)
                msg = f"Key distribution differs: unique to one side: {diff_keys}"

            report.add_result(TestResult(
                category="database",
                name="Distribution: score.key values",
                status=status,
                message=msg,
                # Only the 12 most frequent keys are recorded per side.
                instance_a_value=[r['key'] for r in keys_a[:12]],
                instance_b_value=[r['key'] for r in keys_b[:12]],
                duration_seconds=time.time() - t0,
            ))
        except Exception as e:
            report.add_result(TestResult(
                category="database",
                name="Distribution: score.key values",
                status=TestStatus.ERROR,
                message=str(e),
                duration_seconds=time.time() - t0,
            ))

    # ------------------------------------------------------------------
    # Playlist quality
    # ------------------------------------------------------------------

    def _test_playlist_quality(self, report: ComparisonReport):
        """Check playlist table quality.

        Covers: distinct playlist counts (<=20% diff tolerated as WARN),
        average tracks per playlist, and NULL item_id rows.
        """
        t0 = time.time()
        try:
            # Distinct playlists
            pl_count_a = pg_scalar(self.dsn_a,
                "SELECT COUNT(DISTINCT playlist_name) FROM playlist")
            pl_count_b = pg_scalar(self.dsn_b,
                "SELECT COUNT(DISTINCT playlist_name) FROM playlist")

            diff = pct_diff(pl_count_a or 0, pl_count_b or 0)

            if pl_count_a == pl_count_b:
                status = TestStatus.PASS
            elif diff <= 20:
                status = TestStatus.WARN
            else:
                status = TestStatus.FAIL

            report.add_result(TestResult(
                category="database",
                name="Playlist: distinct count",
                status=status,
                message=f"Distinct playlists: A={pl_count_a}, B={pl_count_b} (diff {diff:.1f}%)",
                instance_a_value=pl_count_a,
                instance_b_value=pl_count_b,
                diff=diff,
                duration_seconds=time.time() - t0,
            ))

            # Average tracks per playlist
            avg_tracks_a = pg_scalar(self.dsn_a, """
                SELECT AVG(cnt) FROM (
                    SELECT COUNT(*) as cnt FROM playlist GROUP BY playlist_name
                ) sub
            """)
            avg_tracks_b = pg_scalar(self.dsn_b, """
                SELECT AVG(cnt) FROM (
                    SELECT COUNT(*) as cnt FROM playlist GROUP BY playlist_name
                ) sub
            """)

            if avg_tracks_a and avg_tracks_b:
                diff = pct_diff(float(avg_tracks_a), float(avg_tracks_b))
                status = TestStatus.PASS if diff <= 20 else TestStatus.WARN

                report.add_result(TestResult(
                    category="database",
                    name="Playlist: avg tracks per playlist",
                    status=status,
                    message=f"Avg tracks/playlist: A={float(avg_tracks_a):.1f}, B={float(avg_tracks_b):.1f}",
                    instance_a_value=float(avg_tracks_a),
                    instance_b_value=float(avg_tracks_b),
                    diff=diff,
                    duration_seconds=time.time() - t0,
                ))

            # Playlists with NULL item_ids
            null_items_a = pg_scalar(self.dsn_a,
                "SELECT COUNT(*) FROM playlist WHERE item_id IS NULL")
            null_items_b = pg_scalar(self.dsn_b,
                "SELECT COUNT(*) FROM playlist WHERE item_id IS NULL")

            if (null_items_a or 0) == 0 and (null_items_b or 0) == 0:
                status = TestStatus.PASS
                msg = "No NULL item_ids in playlists"
            else:
                status = TestStatus.WARN
                msg = f"NULL item_ids in playlist: A={null_items_a}, B={null_items_b}"

            report.add_result(TestResult(
                category="database",
                name="Playlist: NULL item_ids",
                status=status,
                message=msg,
                instance_a_value=null_items_a,
                instance_b_value=null_items_b,
                duration_seconds=time.time() - t0,
            ))

        except Exception as e:
            report.add_result(TestResult(
                category="database",
                name="Playlist quality",
                status=TestStatus.ERROR,
                message=str(e),
                duration_seconds=time.time() - t0,
            ))
Index data presence + # ------------------------------------------------------------------ + + def _test_index_data_presence(self, report: ComparisonReport): + """Check that Voyager/Artist indexes and map projections are present.""" + for table, desc in [ + ("voyager_index_data", "Voyager HNSW index"), + ("artist_index_data", "Artist GMM index"), + ("map_projection_data", "Map projection"), + ("artist_component_projection", "Artist projection"), + ]: + t0 = time.time() + try: + count_a = self._safe_count(self.dsn_a, table) + count_b = self._safe_count(self.dsn_b, table) + + if count_a is None and count_b is None: + report.add_result(TestResult( + category="database", + name=f"Index: {desc}", + status=TestStatus.SKIP, + message=f"Table {table} does not exist", + duration_seconds=time.time() - t0, + )) + continue + + if (count_a or 0) > 0 and (count_b or 0) > 0: + status = TestStatus.PASS + msg = f"Present in both: A={count_a}, B={count_b}" + elif (count_a or 0) > 0 or (count_b or 0) > 0: + status = TestStatus.WARN + msg = f"Only in {'A' if count_a else 'B'}: A={count_a}, B={count_b}" + else: + status = TestStatus.WARN + msg = "Empty in both instances (may need rebuild)" + + report.add_result(TestResult( + category="database", + name=f"Index: {desc}", + status=status, + message=msg, + instance_a_value=count_a, + instance_b_value=count_b, + duration_seconds=time.time() - t0, + )) + except Exception as e: + report.add_result(TestResult( + category="database", + name=f"Index: {desc}", + status=TestStatus.ERROR, + message=str(e), + duration_seconds=time.time() - t0, + )) + + # ------------------------------------------------------------------ + # Task status health + # ------------------------------------------------------------------ + + def _test_task_status_health(self, report: ComparisonReport): + """Check task_status table for stuck or failed tasks.""" + t0 = time.time() + try: + # Failed tasks + failed_a = pg_scalar(self.dsn_a, + "SELECT COUNT(*) FROM task_status 
WHERE status = 'FAILURE'") + failed_b = pg_scalar(self.dsn_b, + "SELECT COUNT(*) FROM task_status WHERE status = 'FAILURE'") + + report.add_result(TestResult( + category="database", + name="Tasks: failed count", + status=TestStatus.PASS if (failed_a or 0) == (failed_b or 0) else TestStatus.WARN, + message=f"Failed tasks: A={failed_a}, B={failed_b}", + instance_a_value=failed_a, + instance_b_value=failed_b, + duration_seconds=time.time() - t0, + )) + + # Stuck tasks (STARTED more than 2 hours ago) + stuck_a = pg_scalar(self.dsn_a, """ + SELECT COUNT(*) FROM task_status + WHERE status IN ('STARTED', 'PROGRESS') + AND start_time < EXTRACT(EPOCH FROM NOW()) - 7200 + """) + stuck_b = pg_scalar(self.dsn_b, """ + SELECT COUNT(*) FROM task_status + WHERE status IN ('STARTED', 'PROGRESS') + AND start_time < EXTRACT(EPOCH FROM NOW()) - 7200 + """) + + if (stuck_a or 0) == 0 and (stuck_b or 0) == 0: + status = TestStatus.PASS + msg = "No stuck tasks" + else: + status = TestStatus.WARN + msg = f"Stuck tasks (>2hr): A={stuck_a}, B={stuck_b}" + + report.add_result(TestResult( + category="database", + name="Tasks: stuck check", + status=status, + message=msg, + instance_a_value=stuck_a, + instance_b_value=stuck_b, + duration_seconds=time.time() - t0, + )) + + # Success rate + total_a = self._safe_count(self.dsn_a, "task_status") or 1 + total_b = self._safe_count(self.dsn_b, "task_status") or 1 + success_a = pg_scalar(self.dsn_a, + "SELECT COUNT(*) FROM task_status WHERE status = 'SUCCESS'") or 0 + success_b = pg_scalar(self.dsn_b, + "SELECT COUNT(*) FROM task_status WHERE status = 'SUCCESS'") or 0 + + rate_a = success_a / total_a * 100 + rate_b = success_b / total_b * 100 + + report.add_result(TestResult( + category="database", + name="Tasks: success rate", + status=TestStatus.PASS if abs(rate_a - rate_b) < 10 else TestStatus.WARN, + message=f"Success rate: A={rate_a:.1f}%, B={rate_b:.1f}%", + instance_a_value=rate_a, + instance_b_value=rate_b, + diff=abs(rate_a - rate_b), + 
duration_seconds=time.time() - t0, + )) + + except Exception as e: + report.add_result(TestResult( + category="database", + name="Tasks health", + status=TestStatus.ERROR, + message=str(e), + duration_seconds=time.time() - t0, + )) + + # ------------------------------------------------------------------ + # Provider config + # ------------------------------------------------------------------ + + def _test_provider_config(self, report: ComparisonReport): + """Compare provider configurations.""" + t0 = time.time() + try: + providers_a = pg_query_dict(self.dsn_a, + "SELECT provider_type, name, enabled, priority FROM provider ORDER BY id") + providers_b = pg_query_dict(self.dsn_b, + "SELECT provider_type, name, enabled, priority FROM provider ORDER BY id") + + if not providers_a and not providers_b: + report.add_result(TestResult( + category="database", + name="Provider: configuration", + status=TestStatus.SKIP, + message="No providers configured in either instance", + duration_seconds=time.time() - t0, + )) + return + + types_a = set(p['provider_type'] for p in providers_a) + types_b = set(p['provider_type'] for p in providers_b) + + if types_a == types_b: + status = TestStatus.PASS + msg = f"Same provider types: {types_a}" + else: + status = TestStatus.WARN + msg = f"Provider types differ: A={types_a}, B={types_b}" + + report.add_result(TestResult( + category="database", + name="Provider: configuration match", + status=status, + message=msg, + instance_a_value=[dict(p) for p in providers_a], + instance_b_value=[dict(p) for p in providers_b], + duration_seconds=time.time() - t0, + )) + except Exception as e: + if "does not exist" in str(e).lower(): + report.add_result(TestResult( + category="database", + name="Provider: configuration", + status=TestStatus.SKIP, + message="Provider table not present", + duration_seconds=time.time() - t0, + )) + else: + report.add_result(TestResult( + category="database", + name="Provider: configuration", + status=TestStatus.ERROR, + 
message=str(e), + duration_seconds=time.time() - t0, + )) + + # ------------------------------------------------------------------ + # App settings + # ------------------------------------------------------------------ + + def _test_app_settings(self, report: ComparisonReport): + """Compare app_settings between instances.""" + t0 = time.time() + try: + settings_a = pg_query_dict(self.dsn_a, + "SELECT key, value, category FROM app_settings ORDER BY key") + settings_b = pg_query_dict(self.dsn_b, + "SELECT key, value, category FROM app_settings ORDER BY key") + + keys_a = set(s['key'] for s in settings_a) + keys_b = set(s['key'] for s in settings_b) + + if keys_a == keys_b: + status = TestStatus.PASS + msg = f"Same settings keys ({len(keys_a)} settings)" + else: + missing_b = keys_a - keys_b + missing_a = keys_b - keys_a + status = TestStatus.WARN + msg = f"Settings differ: missing_in_B={missing_b}, missing_in_A={missing_a}" + + report.add_result(TestResult( + category="database", + name="App Settings: key comparison", + status=status, + message=msg, + instance_a_value=sorted(keys_a), + instance_b_value=sorted(keys_b), + duration_seconds=time.time() - t0, + )) + except Exception as e: + if "does not exist" in str(e).lower(): + report.add_result(TestResult( + category="database", + name="App Settings", + status=TestStatus.SKIP, + message="app_settings table not present", + duration_seconds=time.time() - t0, + )) + else: + report.add_result(TestResult( + category="database", + name="App Settings", + status=TestStatus.ERROR, + message=str(e), + duration_seconds=time.time() - t0, + )) + + # ------------------------------------------------------------------ + # Single-instance tests (when only one DB is available) + # ------------------------------------------------------------------ + + def _test_single_instance_schema(self, report: ComparisonReport, dsn: str, name: str): + """Validate schema for a single instance.""" + for table_name, expected_cols in 
EXPECTED_TABLES.items(): + t0 = time.time() + cols = self._get_table_columns(dsn, table_name) + if cols is None: + report.add_result(TestResult( + category="database", + name=f"Schema ({name}): {table_name}", + status=TestStatus.SKIP, + message=f"Table does not exist in {name}", + duration_seconds=time.time() - t0, + )) + else: + missing = set(expected_cols) - set(cols) + status = TestStatus.PASS if not missing else TestStatus.WARN + report.add_result(TestResult( + category="database", + name=f"Schema ({name}): {table_name}", + status=status, + message=f"Columns: {sorted(cols)}. Missing expected: {missing or 'none'}", + duration_seconds=time.time() - t0, + )) + + def _test_single_instance_quality(self, report: ComparisonReport, dsn: str, name: str): + """Validate data quality for a single instance.""" + for col in SCORE_CRITICAL_COLUMNS: + t0 = time.time() + try: + null_pct = self._null_percentage(dsn, "score", col) + if null_pct is not None: + status = TestStatus.PASS if null_pct <= self.config.db_score_null_threshold_pct else TestStatus.FAIL + report.add_result(TestResult( + category="database", + name=f"Quality ({name}): score.{col} NULLs", + status=status, + message=f"{null_pct:.1f}% NULL", + duration_seconds=time.time() - t0, + )) + except Exception: + pass + + # ------------------------------------------------------------------ + # Helper methods + # ------------------------------------------------------------------ + + def _get_table_columns(self, dsn: str, table_name: str) -> Optional[List[str]]: + """Get column names for a table, or None if table doesn't exist.""" + try: + rows = pg_query(dsn, + "SELECT column_name FROM information_schema.columns WHERE table_name = %s ORDER BY ordinal_position", + (table_name,)) + if not rows: + return None + return [r[0] for r in rows] + except Exception: + return None + + def _safe_count(self, dsn: str, table_name: str) -> Optional[int]: + """Get row count for a table, or None if table doesn't exist.""" + try: + return 
pg_scalar(dsn, f"SELECT COUNT(*) FROM {table_name}") + except Exception: + return None + + def _null_percentage(self, dsn: str, table: str, column: str) -> Optional[float]: + """Get percentage of NULL values in a column.""" + try: + total = pg_scalar(dsn, f"SELECT COUNT(*) FROM {table}") + if not total: + return None + nulls = pg_scalar(dsn, f"SELECT COUNT(*) FROM {table} WHERE {column} IS NULL") + return (nulls / total) * 100 + except Exception: + return None + + def _get_column_stats(self, dsn: str, table: str, column: str) -> Optional[dict]: + """Get min/max/avg/stddev for a numeric column.""" + try: + rows = pg_query(dsn, f""" + SELECT MIN({column}), MAX({column}), AVG({column}), STDDEV({column}) + FROM {table} + WHERE {column} IS NOT NULL + """) + if rows and rows[0][0] is not None: + return { + 'min': float(rows[0][0]), + 'max': float(rows[0][1]), + 'avg': float(rows[0][2]), + 'stddev': float(rows[0][3]) if rows[0][3] else 0.0, + } + except Exception: + pass + return None diff --git a/testing_suite/comparators/docker_comparator.py b/testing_suite/comparators/docker_comparator.py new file mode 100644 index 00000000..2ee95834 --- /dev/null +++ b/testing_suite/comparators/docker_comparator.py @@ -0,0 +1,541 @@ +""" +Docker Comparison Module for AudioMuse-AI Testing Suite. 
+ +Compares two Docker deployments across: + - Container health and status + - Resource usage (memory, CPU) + - Log analysis (error rates, warning patterns) + - Service connectivity (Redis, PostgreSQL, Flask, Worker) + - Container uptime and restart counts + - Log-based error pattern detection +""" + +import json +import logging +import re +import time +from collections import Counter +from typing import Any, Dict, List, Optional, Tuple + +from testing_suite.config import ComparisonConfig, InstanceConfig +from testing_suite.utils import ( + ComparisonReport, TestResult, TestStatus, + docker_exec, docker_logs, docker_inspect, pct_diff +) + +logger = logging.getLogger(__name__) + +# Log patterns to search for +ERROR_PATTERNS = [ + (r"(?i)traceback \(most recent call last\)", "Python Traceback"), + (r"(?i)error|exception", "Error/Exception"), + (r"(?i)out of memory|oom|killed", "OOM/Memory Kill"), + (r"(?i)connection refused|connection reset|broken pipe", "Connection Error"), + (r"(?i)timeout|timed out", "Timeout"), + (r"(?i)permission denied|access denied", "Permission Error"), + (r"(?i)disk full|no space left", "Disk Space"), + (r"(?i)segmentation fault|segfault|core dump", "Crash/Segfault"), + (r"(?i)worker .* died|worker .* killed", "Worker Death"), + (r"(?i)database .* error|psycopg2\..*error", "Database Error"), + (r"(?i)redis\..*error|redis connection", "Redis Error"), +] + +WARNING_PATTERNS = [ + (r"(?i)deprecat", "Deprecation Warning"), + (r"(?i)warning", "Warning"), + (r"(?i)retry|retrying", "Retry Attempt"), + (r"(?i)slow query|slow request", "Slow Operation"), + (r"(?i)memory usage|memory pressure", "Memory Pressure"), +] + + +def _get_ssh_params(instance: InstanceConfig) -> dict: + """Extract SSH parameters from instance config.""" + return { + "ssh_host": instance.ssh_host, + "ssh_user": instance.ssh_user, + "ssh_key": instance.ssh_key, + "ssh_port": instance.ssh_port, + } + + +class DockerComparator: + """Compares Docker deployment health across two 
AudioMuse-AI instances.""" + + def __init__(self, config: ComparisonConfig): + self.config = config + self.inst_a = config.instance_a + self.inst_b = config.instance_b + + def run_all(self, report: ComparisonReport): + """Run all Docker comparison tests.""" + logger.info("Starting Docker comparison tests...") + + # Test each container type on both instances + containers = [ + ("flask", "docker_flask_container", "Flask App Server"), + ("worker", "docker_worker_container", "RQ Worker"), + ("postgres", "docker_postgres_container", "PostgreSQL"), + ("redis", "docker_redis_container", "Redis"), + ] + + for container_key, attr_name, description in containers: + name_a = getattr(self.inst_a, attr_name) + name_b = getattr(self.inst_b, attr_name) + + self._test_container_health(report, name_a, name_b, description, + self.inst_a, self.inst_b) + self._test_container_resource_usage(report, name_a, name_b, description, + self.inst_a, self.inst_b) + + # Log analysis for flask and worker containers + for attr_name, description in [ + ("docker_flask_container", "Flask"), + ("docker_worker_container", "Worker"), + ]: + name_a = getattr(self.inst_a, attr_name) + name_b = getattr(self.inst_b, attr_name) + self._test_log_error_analysis(report, name_a, name_b, description, + self.inst_a, self.inst_b) + + # Service connectivity tests + self._test_redis_connectivity(report) + self._test_postgres_connectivity(report) + + logger.info("Docker comparison tests complete.") + + # ------------------------------------------------------------------ + # Container health + # ------------------------------------------------------------------ + + def _test_container_health(self, report: ComparisonReport, + name_a: str, name_b: str, description: str, + inst_a: InstanceConfig, inst_b: InstanceConfig): + """Check container status, uptime, and restart count.""" + t0 = time.time() + + info_a = docker_inspect(name_a, **_get_ssh_params(inst_a)) + info_b = docker_inspect(name_b, **_get_ssh_params(inst_b)) + 
+ # Container running status + running_a = self._is_running(info_a) + running_b = self._is_running(info_b) + + if running_a is None and running_b is None: + report.add_result(TestResult( + category="docker", + name=f"{description}: container status", + status=TestStatus.SKIP, + message="Cannot inspect containers (Docker not available or containers not found)", + duration_seconds=time.time() - t0, + )) + return + + if running_a and running_b: + status = TestStatus.PASS + msg = "Both containers running" + elif running_a or running_b: + status = TestStatus.FAIL + msg = f"Only {'A' if running_a else 'B'} is running" + else: + status = TestStatus.FAIL + msg = "Neither container is running" + + report.add_result(TestResult( + category="docker", + name=f"{description}: container status", + status=status, + message=msg, + instance_a_value=f"running={running_a}", + instance_b_value=f"running={running_b}", + duration_seconds=time.time() - t0, + )) + + # Restart count + restarts_a = self._get_restart_count(info_a) + restarts_b = self._get_restart_count(info_b) + + if restarts_a is not None or restarts_b is not None: + if (restarts_a or 0) == 0 and (restarts_b or 0) == 0: + r_status = TestStatus.PASS + r_msg = "No restarts on either instance" + elif (restarts_a or 0) > 5 or (restarts_b or 0) > 5: + r_status = TestStatus.FAIL + r_msg = f"High restart count: A={restarts_a}, B={restarts_b}" + else: + r_status = TestStatus.WARN + r_msg = f"Restarts: A={restarts_a}, B={restarts_b}" + + report.add_result(TestResult( + category="docker", + name=f"{description}: restart count", + status=r_status, + message=r_msg, + instance_a_value=restarts_a, + instance_b_value=restarts_b, + duration_seconds=time.time() - t0, + )) + + # Health check status + health_a = self._get_health_status(info_a) + health_b = self._get_health_status(info_b) + + if health_a or health_b: + if health_a == health_b: + h_status = TestStatus.PASS + h_msg = f"Same health status: {health_a}" + else: + h_status = 
TestStatus.WARN + h_msg = f"Health differs: A={health_a}, B={health_b}" + + report.add_result(TestResult( + category="docker", + name=f"{description}: health check", + status=h_status, + message=h_msg, + instance_a_value=health_a, + instance_b_value=health_b, + duration_seconds=time.time() - t0, + )) + + # ------------------------------------------------------------------ + # Resource usage + # ------------------------------------------------------------------ + + def _test_container_resource_usage(self, report: ComparisonReport, + name_a: str, name_b: str, description: str, + inst_a: InstanceConfig, inst_b: InstanceConfig): + """Compare memory and CPU usage between containers.""" + t0 = time.time() + + stats_a = self._get_container_stats(name_a, inst_a) + stats_b = self._get_container_stats(name_b, inst_b) + + if not stats_a and not stats_b: + report.add_result(TestResult( + category="docker", + name=f"{description}: resource usage", + status=TestStatus.SKIP, + message="Cannot get container stats", + duration_seconds=time.time() - t0, + )) + return + + # Memory usage + mem_a = stats_a.get('memory_mb') if stats_a else None + mem_b = stats_b.get('memory_mb') if stats_b else None + + if mem_a is not None and mem_b is not None: + diff = pct_diff(mem_a, mem_b) + if diff <= 20: + status = TestStatus.PASS + elif diff <= 50: + status = TestStatus.WARN + else: + status = TestStatus.FAIL + + report.add_result(TestResult( + category="docker", + name=f"{description}: memory usage", + status=status, + message=f"A={mem_a:.1f}MB, B={mem_b:.1f}MB (diff {diff:.1f}%)", + instance_a_value=mem_a, + instance_b_value=mem_b, + diff=diff, + duration_seconds=time.time() - t0, + )) + + # CPU usage + cpu_a = stats_a.get('cpu_pct') if stats_a else None + cpu_b = stats_b.get('cpu_pct') if stats_b else None + + if cpu_a is not None and cpu_b is not None: + report.add_result(TestResult( + category="docker", + name=f"{description}: CPU usage", + status=TestStatus.PASS, + 
message=f"A={cpu_a:.1f}%, B={cpu_b:.1f}%", + instance_a_value=cpu_a, + instance_b_value=cpu_b, + duration_seconds=time.time() - t0, + )) + + # ------------------------------------------------------------------ + # Log error analysis + # ------------------------------------------------------------------ + + def _test_log_error_analysis(self, report: ComparisonReport, + name_a: str, name_b: str, description: str, + inst_a: InstanceConfig, inst_b: InstanceConfig): + """Analyze container logs for error patterns.""" + t0 = time.time() + + # Fetch logs + logs_a_stdout, logs_a_stderr, rc_a = docker_logs( + name_a, tail=2000, **_get_ssh_params(inst_a), timeout=30) + logs_b_stdout, logs_b_stderr, rc_b = docker_logs( + name_b, tail=2000, **_get_ssh_params(inst_b), timeout=30) + + if rc_a != 0 and rc_b != 0: + report.add_result(TestResult( + category="docker", + name=f"{description}: log analysis", + status=TestStatus.SKIP, + message="Cannot fetch logs from either container", + duration_seconds=time.time() - t0, + )) + return + + # Combine stdout + stderr for analysis + logs_a = (logs_a_stdout or '') + '\n' + (logs_a_stderr or '') + logs_b = (logs_b_stdout or '') + '\n' + (logs_b_stderr or '') + + # Error pattern matching + errors_a = self._count_patterns(logs_a, ERROR_PATTERNS) + errors_b = self._count_patterns(logs_b, ERROR_PATTERNS) + + total_errors_a = sum(errors_a.values()) + total_errors_b = sum(errors_b.values()) + + if total_errors_a == 0 and total_errors_b == 0: + report.add_result(TestResult( + category="docker", + name=f"{description}: error patterns", + status=TestStatus.PASS, + message="No error patterns detected in recent logs", + duration_seconds=time.time() - t0, + )) + else: + # Compare error counts + if total_errors_a <= total_errors_b: + status = TestStatus.WARN + else: + status = TestStatus.WARN + + if total_errors_a > 50 or total_errors_b > 50: + status = TestStatus.FAIL + + report.add_result(TestResult( + category="docker", + name=f"{description}: error 
count", + status=status, + message=f"Errors in last 2000 log lines: A={total_errors_a}, B={total_errors_b}", + instance_a_value=total_errors_a, + instance_b_value=total_errors_b, + duration_seconds=time.time() - t0, + details={"errors_a": dict(errors_a), "errors_b": dict(errors_b)}, + )) + + # Detailed per-pattern breakdown + all_patterns = set(errors_a.keys()) | set(errors_b.keys()) + for pattern_name in sorted(all_patterns): + cnt_a = errors_a.get(pattern_name, 0) + cnt_b = errors_b.get(pattern_name, 0) + + if cnt_a == 0 and cnt_b == 0: + continue + + if cnt_a == cnt_b: + p_status = TestStatus.WARN + elif cnt_a > cnt_b * 2 or cnt_b > cnt_a * 2: + p_status = TestStatus.FAIL + else: + p_status = TestStatus.WARN + + report.add_result(TestResult( + category="docker", + name=f"{description}: {pattern_name}", + status=p_status, + message=f"A={cnt_a}, B={cnt_b}", + instance_a_value=cnt_a, + instance_b_value=cnt_b, + duration_seconds=time.time() - t0, + )) + + # Warning pattern matching + warnings_a = self._count_patterns(logs_a, WARNING_PATTERNS) + warnings_b = self._count_patterns(logs_b, WARNING_PATTERNS) + + total_warnings_a = sum(warnings_a.values()) + total_warnings_b = sum(warnings_b.values()) + + report.add_result(TestResult( + category="docker", + name=f"{description}: warning count", + status=TestStatus.PASS if total_warnings_a < 100 and total_warnings_b < 100 else TestStatus.WARN, + message=f"Warnings: A={total_warnings_a}, B={total_warnings_b}", + instance_a_value=total_warnings_a, + instance_b_value=total_warnings_b, + duration_seconds=time.time() - t0, + details={"warnings_a": dict(warnings_a), "warnings_b": dict(warnings_b)}, + )) + + # Check for Python tracebacks specifically (important indicator) + tb_count_a = logs_a.count("Traceback (most recent call last)") + tb_count_b = logs_b.count("Traceback (most recent call last)") + + if tb_count_a == 0 and tb_count_b == 0: + tb_status = TestStatus.PASS + tb_msg = "No Python tracebacks in recent logs" + elif 
tb_count_a > 10 or tb_count_b > 10: + tb_status = TestStatus.FAIL + tb_msg = f"Tracebacks: A={tb_count_a}, B={tb_count_b}" + else: + tb_status = TestStatus.WARN + tb_msg = f"Tracebacks: A={tb_count_a}, B={tb_count_b}" + + report.add_result(TestResult( + category="docker", + name=f"{description}: Python tracebacks", + status=tb_status, + message=tb_msg, + instance_a_value=tb_count_a, + instance_b_value=tb_count_b, + duration_seconds=time.time() - t0, + )) + + # ------------------------------------------------------------------ + # Service connectivity + # ------------------------------------------------------------------ + + def _test_redis_connectivity(self, report: ComparisonReport): + """Test Redis connectivity from within the Flask container.""" + t0 = time.time() + for label, inst in [("A", self.inst_a), ("B", self.inst_b)]: + container = inst.docker_flask_container + stdout, stderr, rc = docker_exec( + container, "python -c \"from redis import Redis; r=Redis.from_url('" + + inst.redis_url + "'); print(r.ping())\"", + **_get_ssh_params(inst), timeout=15, + ) + if rc == 0 and "True" in stdout: + status = TestStatus.PASS + msg = "Redis ping successful" + elif rc == -2: + status = TestStatus.SKIP + msg = "Docker not available" + else: + status = TestStatus.WARN + msg = f"Redis ping failed: {stderr[:200] if stderr else stdout[:200]}" + + report.add_result(TestResult( + category="docker", + name=f"Redis connectivity ({label})", + status=status, + message=msg, + duration_seconds=time.time() - t0, + )) + + def _test_postgres_connectivity(self, report: ComparisonReport): + """Test PostgreSQL connectivity from within the Flask container.""" + t0 = time.time() + for label, inst in [("A", self.inst_a), ("B", self.inst_b)]: + container = inst.docker_flask_container + stdout, stderr, rc = docker_exec( + container, + f"python -c \"import psycopg2; c=psycopg2.connect('{inst.pg_dsn}', connect_timeout=5); " + f"cur=c.cursor(); cur.execute('SELECT 1'); print(cur.fetchone()[0]); 
c.close()\"", + **_get_ssh_params(inst), timeout=15, + ) + if rc == 0 and "1" in stdout: + status = TestStatus.PASS + msg = "PostgreSQL SELECT 1 successful" + elif rc == -2: + status = TestStatus.SKIP + msg = "Docker not available" + else: + status = TestStatus.WARN + msg = f"PostgreSQL test failed: {stderr[:200] if stderr else stdout[:200]}" + + report.add_result(TestResult( + category="docker", + name=f"PostgreSQL connectivity ({label})", + status=status, + message=msg, + duration_seconds=time.time() - t0, + )) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _is_running(self, info: Optional[dict]) -> Optional[bool]: + """Check if container is running from inspect data.""" + if not info: + return None + state = info.get("State", {}) + return state.get("Running", False) + + def _get_restart_count(self, info: Optional[dict]) -> Optional[int]: + """Get container restart count from inspect data.""" + if not info: + return None + return info.get("RestartCount", 0) + + def _get_health_status(self, info: Optional[dict]) -> Optional[str]: + """Get container health status from inspect data.""" + if not info: + return None + state = info.get("State", {}) + health = state.get("Health", {}) + return health.get("Status") if health else None + + def _get_container_stats(self, name: str, inst: InstanceConfig) -> Optional[dict]: + """Get container resource stats via docker stats --no-stream.""" + ssh_params = _get_ssh_params(inst) + cmd_parts = ["docker", "stats", "--no-stream", "--format", + "{{.MemUsage}}|||{{.CPUPerc}}", name] + + if ssh_params.get("ssh_host"): + import subprocess + ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", + "-p", str(ssh_params.get("ssh_port", 22))] + if ssh_params.get("ssh_key"): + ssh_cmd += ["-i", ssh_params["ssh_key"]] + host = f"{ssh_params['ssh_user']}@{ssh_params['ssh_host']}" \ + if ssh_params.get("ssh_user") else 
ssh_params["ssh_host"] + ssh_cmd.append(host) + ssh_cmd.append(" ".join(cmd_parts)) + full_cmd = ssh_cmd + else: + full_cmd = cmd_parts + + try: + import subprocess + proc = subprocess.run(full_cmd, capture_output=True, text=True, timeout=15) + if proc.returncode == 0 and proc.stdout.strip(): + parts = proc.stdout.strip().split("|||") + if len(parts) == 2: + mem_str = parts[0].strip() + cpu_str = parts[1].strip().rstrip('%') + + # Parse memory (e.g., "512MiB / 16GiB") + mem_match = re.search(r'([\d.]+)(Ki|Mi|Gi|B)', mem_str) + memory_mb = 0.0 + if mem_match: + val = float(mem_match.group(1)) + unit = mem_match.group(2) + if unit == "Gi": + memory_mb = val * 1024 + elif unit == "Mi": + memory_mb = val + elif unit == "Ki": + memory_mb = val / 1024 + else: + memory_mb = val / (1024 * 1024) + + cpu_pct = float(cpu_str) if cpu_str else 0.0 + + return {"memory_mb": memory_mb, "cpu_pct": cpu_pct} + except Exception as e: + logger.debug(f"Stats fetch failed for {name}: {e}") + return None + + def _count_patterns(self, log_text: str, patterns: list) -> Counter: + """Count occurrences of each pattern in log text.""" + counts = Counter() + for pattern, name in patterns: + matches = re.findall(pattern, log_text) + if matches: + counts[name] = len(matches) + return counts diff --git a/testing_suite/comparators/performance_comparator.py b/testing_suite/comparators/performance_comparator.py new file mode 100644 index 00000000..21a573d1 --- /dev/null +++ b/testing_suite/comparators/performance_comparator.py @@ -0,0 +1,447 @@ +""" +Performance Comparison Module for AudioMuse-AI Testing Suite. 
"""
Performance Comparison Module for AudioMuse-AI Testing Suite.

Benchmarks and compares performance between two instances:
  - API endpoint latency (p50, p95, p99, mean, max)
  - Throughput under concurrent load
  - Database query performance
  - Search/similarity response times
  - Memory-intensive operations (map, alchemy, clustering)
  - Warmup vs steady-state performance
"""

import logging
import statistics
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple

from testing_suite.config import ComparisonConfig, InstanceConfig
from testing_suite.utils import (
    ComparisonReport, TestResult, TestStatus,
    http_get, http_post, timed_request, pct_diff, format_duration,
    pg_query, pg_scalar
)

logger = logging.getLogger(__name__)


def _percentile(data: List[float], pct: float) -> float:
    """Return the *pct*-th percentile of *data* with linear interpolation."""
    if not data:
        return 0.0
    ordered = sorted(data)
    rank = (len(ordered) - 1) * (pct / 100.0)
    lo = int(rank)
    hi = lo + 1
    if hi >= len(ordered):
        return ordered[-1]
    frac = rank - lo
    return ordered[lo] + frac * (ordered[hi] - ordered[lo])


def _latency_stats(latencies: List[float]) -> dict:
    """Summarise latency samples into count/mean/median/p95/p99/min/max/stddev."""
    if not latencies:
        return {"count": 0, "mean": 0, "median": 0, "p95": 0, "p99": 0,
                "min": 0, "max": 0, "stddev": 0}
    # stdev requires at least two samples.
    spread = statistics.stdev(latencies) if len(latencies) > 1 else 0.0
    return {
        "count": len(latencies),
        "mean": statistics.mean(latencies),
        "median": statistics.median(latencies),
        "p95": _percentile(latencies, 95),
        "p99": _percentile(latencies, 99),
        "min": min(latencies),
        "max": max(latencies),
        "stddev": spread,
    }


class PerformanceComparator:
    """Benchmarks and compares performance between two instances."""

    def __init__(self, config: ComparisonConfig):
        """Cache per-instance URLs/names and benchmark sizing from config."""
        self.config = config
        self.url_a = config.instance_a.api_url.rstrip('/')
        self.url_b = config.instance_b.api_url.rstrip('/')
        self.name_a = config.instance_a.name
        self.name_b = config.instance_b.name
        self.warmup_n = config.perf_warmup_requests
        self.bench_n = config.perf_benchmark_requests
        self.concurrent = config.perf_concurrent_users

    def run_all(self, report: ComparisonReport):
        """Run all performance comparison tests against both instances."""
        logger.info("Starting performance comparison tests...")

        # Check connectivity up front; unreachable sides are skipped below.
        alive_a = self._check_alive(self.url_a)
        alive_b = self._check_alive(self.url_b)

        if not alive_a and not alive_b:
            report.add_result(TestResult(
                category="performance",
                name="Connectivity",
                status=TestStatus.ERROR,
                message="Neither instance reachable; skipping performance tests",
            ))
            return

        # Endpoint benchmark matrix:
        # (path, method, params/json, description, expected_status)
        benchmarks = [
            ("/api/config", "GET", None, "Config endpoint", 200),
            ("/api/playlists", "GET", None, "Playlists list", 200),
            ("/api/active_tasks", "GET", None, "Active tasks", 200),
            ("/api/last_task", "GET", None, "Last task", 200),
            ("/api/search_tracks?artist=Red+Hot&title=By", "GET", None, "Track search", 200),
            ("/api/similar_tracks?title=By+the+Way&artist=Red+Hot+Chili+Peppers&n=5", "GET", None, "Similar tracks", 200),
            ("/api/map?percent=10", "GET", None, "Map visualization", 200),
            ("/api/map_cache_status", "GET", None, "Map cache status", 200),
            ("/api/clap/stats", "GET", None, "CLAP stats", 200),
            ("/api/clap/top_queries", "GET", None, "CLAP top queries", 200),
            ("/api/setup/status", "GET", None, "Setup status", 200),
            ("/api/setup/providers", "GET", None, "Providers list", 200),
            ("/api/setup/settings", "GET", None, "App settings", 200),
            ("/api/cron", "GET", None, "Cron entries", 200),
            ("/api/search_artists?q=Red", "GET", None, "Artist search", 200),
            ("/external/search?q=piano", "GET", None, "External search", 200),
        ]

        # Latency benchmark per endpoint.
        for path, method, data, desc, expected_status in benchmarks:
            self._benchmark_endpoint(report, path, method, data, desc,
                                     expected_status, alive_a, alive_b)

        # Concurrent load test on a few key endpoints.
        self._concurrent_load_test(report, alive_a, alive_b)

        # Database query performance.
        self._benchmark_db_queries(report)

        logger.info("Performance comparison tests complete.")
self._benchmark_endpoint(report, path, method, data, desc, + expected_status, alive_a, alive_b) + + # Concurrent load test on a few key endpoints + self._concurrent_load_test(report, alive_a, alive_b) + + # Database query performance + self._benchmark_db_queries(report) + + logger.info("Performance comparison tests complete.") + + # ------------------------------------------------------------------ + # Endpoint latency benchmark + # ------------------------------------------------------------------ + + def _benchmark_endpoint(self, report: ComparisonReport, path: str, + method: str, data: Any, description: str, + expected_status: int, alive_a: bool, alive_b: bool): + """Benchmark a single endpoint on both instances.""" + t0 = time.time() + latencies_a = [] + latencies_b = [] + errors_a = 0 + errors_b = 0 + + # Warmup phase + for _ in range(self.warmup_n): + try: + if alive_a: + if method == "GET": + http_get(f"{self.url_a}{path}", timeout=30, retries=1) + else: + http_post(f"{self.url_a}{path}", json_data=data, timeout=30, retries=1) + except Exception: + pass + try: + if alive_b: + if method == "GET": + http_get(f"{self.url_b}{path}", timeout=30, retries=1) + else: + http_post(f"{self.url_b}{path}", json_data=data, timeout=30, retries=1) + except Exception: + pass + + # Benchmark phase + for i in range(self.bench_n): + if alive_a: + try: + start = time.perf_counter() + if method == "GET": + resp = http_get(f"{self.url_a}{path}", timeout=60, retries=1) + else: + resp = http_post(f"{self.url_a}{path}", json_data=data, timeout=60, retries=1) + elapsed = time.perf_counter() - start + if resp.status_code == expected_status: + latencies_a.append(elapsed) + else: + errors_a += 1 + except Exception: + errors_a += 1 + + if alive_b: + try: + start = time.perf_counter() + if method == "GET": + resp = http_get(f"{self.url_b}{path}", timeout=60, retries=1) + else: + resp = http_post(f"{self.url_b}{path}", json_data=data, timeout=60, retries=1) + elapsed = time.perf_counter() - 
start + if resp.status_code == expected_status: + latencies_b.append(elapsed) + else: + errors_b += 1 + except Exception: + errors_b += 1 + + stats_a = _latency_stats(latencies_a) + stats_b = _latency_stats(latencies_b) + + # Determine status based on relative performance + if stats_a['mean'] > 0 and stats_b['mean'] > 0: + ratio = stats_b['mean'] / stats_a['mean'] + if ratio <= 1.2: # B is within 20% of A + status = TestStatus.PASS + comparison = f"B is {ratio:.2f}x vs A" + elif ratio <= 2.0: + status = TestStatus.WARN + comparison = f"B is {ratio:.2f}x slower than A" + else: + status = TestStatus.FAIL + comparison = f"B is {ratio:.2f}x slower than A" + + # Also check if B is faster + if ratio < 0.8: + status = TestStatus.PASS + comparison = f"B is {1/ratio:.2f}x faster than A" + else: + status = TestStatus.WARN if (latencies_a or latencies_b) else TestStatus.SKIP + comparison = "Cannot compare (one or both had no successful requests)" + + report.add_result(TestResult( + category="performance", + name=f"Latency: {description}", + status=status, + message=( + f"{comparison} | " + f"A: mean={format_duration(stats_a['mean'])}, " + f"p95={format_duration(stats_a['p95'])}, " + f"p99={format_duration(stats_a['p99'])} " + f"({errors_a} errors) | " + f"B: mean={format_duration(stats_b['mean'])}, " + f"p95={format_duration(stats_b['p95'])}, " + f"p99={format_duration(stats_b['p99'])} " + f"({errors_b} errors)" + ), + instance_a_value=stats_a, + instance_b_value=stats_b, + diff=pct_diff(stats_a['mean'], stats_b['mean']) if stats_a['mean'] and stats_b['mean'] else None, + duration_seconds=time.time() - t0, + details={ + "path": path, + "method": method, + "warmup_requests": self.warmup_n, + "benchmark_requests": self.bench_n, + "errors_a": errors_a, + "errors_b": errors_b, + }, + )) + + # ------------------------------------------------------------------ + # Concurrent load test + # ------------------------------------------------------------------ + + def 
_concurrent_load_test(self, report: ComparisonReport, + alive_a: bool, alive_b: bool): + """Test throughput under concurrent load.""" + endpoints = [ + "/api/config", + "/api/search_tracks?artist=Red+Hot&title=By", + "/api/playlists", + ] + + for path in endpoints: + t0 = time.time() + results_a = self._run_concurrent(self.url_a, path, self.concurrent, + self.bench_n) if alive_a else None + results_b = self._run_concurrent(self.url_b, path, self.concurrent, + self.bench_n) if alive_b else None + + if results_a and results_b: + throughput_a = results_a['successful'] / results_a['total_time'] if results_a['total_time'] > 0 else 0 + throughput_b = results_b['successful'] / results_b['total_time'] if results_b['total_time'] > 0 else 0 + + if throughput_a > 0 and throughput_b > 0: + ratio = throughput_b / throughput_a + if ratio >= 0.8: + status = TestStatus.PASS + elif ratio >= 0.5: + status = TestStatus.WARN + else: + status = TestStatus.FAIL + else: + status = TestStatus.WARN + ratio = 0 + + report.add_result(TestResult( + category="performance", + name=f"Concurrent Load: {path.split('?')[0]}", + status=status, + message=( + f"{self.concurrent} concurrent users, {self.bench_n} requests each | " + f"A: {throughput_a:.1f} req/s, " + f"mean={format_duration(results_a['mean_latency'])}, " + f"{results_a['errors']} errors | " + f"B: {throughput_b:.1f} req/s, " + f"mean={format_duration(results_b['mean_latency'])}, " + f"{results_b['errors']} errors" + ), + instance_a_value={"throughput_rps": throughput_a, **results_a}, + instance_b_value={"throughput_rps": throughput_b, **results_b}, + duration_seconds=time.time() - t0, + )) + else: + report.add_result(TestResult( + category="performance", + name=f"Concurrent Load: {path.split('?')[0]}", + status=TestStatus.SKIP, + message="Cannot run concurrent test (one or both instances unavailable)", + duration_seconds=time.time() - t0, + )) + + def _run_concurrent(self, base_url: str, path: str, + concurrent: int, 
requests_per_worker: int) -> dict: + """Run concurrent requests and measure throughput.""" + latencies = [] + errors = 0 + + def worker(): + nonlocal errors + local_latencies = [] + for _ in range(requests_per_worker): + try: + start = time.perf_counter() + resp = http_get(f"{base_url}{path}", timeout=30, retries=1) + elapsed = time.perf_counter() - start + if resp.status_code == 200: + local_latencies.append(elapsed) + else: + errors += 1 + except Exception: + errors += 1 + return local_latencies + + overall_start = time.perf_counter() + with ThreadPoolExecutor(max_workers=concurrent) as executor: + futures = [executor.submit(worker) for _ in range(concurrent)] + for f in as_completed(futures): + try: + latencies.extend(f.result()) + except Exception: + errors += 1 + total_time = time.perf_counter() - overall_start + + return { + "successful": len(latencies), + "errors": errors, + "total_time": total_time, + "mean_latency": statistics.mean(latencies) if latencies else 0, + "p95_latency": _percentile(latencies, 95) if latencies else 0, + } + + # ------------------------------------------------------------------ + # Database query performance + # ------------------------------------------------------------------ + + def _benchmark_db_queries(self, report: ComparisonReport): + """Benchmark critical database queries on both instances.""" + queries = [ + ("SELECT COUNT(*) FROM score", "Score count"), + ("SELECT COUNT(*) FROM embedding", "Embedding count"), + ("SELECT COUNT(*) FROM playlist", "Playlist count"), + ("SELECT COUNT(DISTINCT playlist_name) FROM playlist", "Distinct playlists"), + ("SELECT item_id, title, author FROM score LIMIT 100", "Score fetch 100"), + ("SELECT s.item_id, s.title FROM score s JOIN embedding e ON s.item_id = e.item_id LIMIT 50", + "Score-embedding join 50"), + ("SELECT AVG(tempo), AVG(energy) FROM score WHERE tempo IS NOT NULL", "Score aggregation"), + ("SELECT key, COUNT(*) FROM score WHERE key IS NOT NULL GROUP BY key", "Key 
distribution"), + ] + + dsn_a = self.config.instance_a.pg_dsn + dsn_b = self.config.instance_b.pg_dsn + + can_a = self._test_db(dsn_a) + can_b = self._test_db(dsn_b) + + if not can_a and not can_b: + report.add_result(TestResult( + category="performance", + name="DB Query Performance", + status=TestStatus.SKIP, + message="Cannot connect to either database", + )) + return + + for sql, desc in queries: + t0 = time.time() + latencies_a = [] + latencies_b = [] + + for _ in range(max(3, self.bench_n // 2)): + if can_a: + try: + start = time.perf_counter() + pg_query(dsn_a, sql) + latencies_a.append(time.perf_counter() - start) + except Exception: + pass + + if can_b: + try: + start = time.perf_counter() + pg_query(dsn_b, sql) + latencies_b.append(time.perf_counter() - start) + except Exception: + pass + + stats_a = _latency_stats(latencies_a) + stats_b = _latency_stats(latencies_b) + + if stats_a['mean'] > 0 and stats_b['mean'] > 0: + ratio = stats_b['mean'] / stats_a['mean'] + if ratio <= 1.5: + status = TestStatus.PASS + elif ratio <= 3.0: + status = TestStatus.WARN + else: + status = TestStatus.FAIL + comparison = f"B/A ratio: {ratio:.2f}x" + else: + status = TestStatus.WARN + comparison = "Insufficient data" + + report.add_result(TestResult( + category="performance", + name=f"DB Query: {desc}", + status=status, + message=( + f"{comparison} | " + f"A: mean={format_duration(stats_a['mean'])}, " + f"p95={format_duration(stats_a['p95'])} | " + f"B: mean={format_duration(stats_b['mean'])}, " + f"p95={format_duration(stats_b['p95'])}" + ), + instance_a_value=stats_a, + instance_b_value=stats_b, + diff=pct_diff(stats_a['mean'], stats_b['mean']) if stats_a['mean'] and stats_b['mean'] else None, + duration_seconds=time.time() - t0, + )) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _check_alive(self, url: str) -> bool: + try: + resp = 
http_get(f"{url}/api/config", timeout=15, retries=2, retry_delay=1) + return resp.status_code == 200 + except Exception: + return False + + def _test_db(self, dsn: str) -> bool: + try: + pg_scalar(dsn, "SELECT 1") + return True + except Exception: + return False diff --git a/testing_suite/comparison_config.example.yaml b/testing_suite/comparison_config.example.yaml new file mode 100644 index 00000000..a73217f3 --- /dev/null +++ b/testing_suite/comparison_config.example.yaml @@ -0,0 +1,91 @@ +# AudioMuse-AI Testing & Comparison Suite - Example Configuration +# +# Copy this file to comparison_config.yaml and customize for your setup. +# +# Usage: +# python -m testing_suite --config testing_suite/comparison_config.yaml +# python -m testing_suite --config testing_suite/comparison_config.yaml --only api,performance +# python -m testing_suite --config testing_suite/comparison_config.yaml --skip docker + +# Instance A - Main branch (baseline) +instance_a: + name: "main" + branch: "main" + # API connection + api_url: "http://192.168.1.100:8000" + api_timeout: 120 + # PostgreSQL connection + pg_host: "192.168.1.100" + pg_port: 5432 + pg_user: "audiomuse" + pg_password: "audiomusepassword" + pg_database: "audiomusedb" + # Redis connection + redis_url: "redis://192.168.1.100:6379/0" + # Docker container names + docker_flask_container: "audiomuse-ai-flask-app" + docker_worker_container: "audiomuse-ai-worker-instance" + docker_postgres_container: "audiomuse-postgres" + docker_redis_container: "audiomuse-redis" + # SSH for remote Docker access (leave empty for local) + ssh_host: "" + ssh_user: "" + ssh_key: "" + ssh_port: 22 + +# Instance B - Feature branch (under test) +instance_b: + name: "feature" + branch: "feature-branch" + api_url: "http://192.168.1.101:8000" + api_timeout: 120 + pg_host: "192.168.1.101" + pg_port: 5432 + pg_user: "audiomuse" + pg_password: "audiomusepassword" + pg_database: "audiomusedb" + redis_url: "redis://192.168.1.101:6379/0" + 
docker_flask_container: "audiomuse-ai-flask-app" + docker_worker_container: "audiomuse-ai-worker-instance" + docker_postgres_container: "audiomuse-postgres" + docker_redis_container: "audiomuse-redis" + ssh_host: "" + ssh_user: "" + ssh_key: "" + ssh_port: 22 + +# Test modules to run (true/false) +run_api_tests: true +run_db_tests: true +run_docker_tests: true +run_performance_tests: true +run_existing_unit_tests: true +run_existing_integration_tests: true + +# Performance test settings +perf_warmup_requests: 3 # Warmup requests per endpoint before measuring +perf_benchmark_requests: 10 # Measured requests per endpoint +perf_concurrent_users: 5 # Concurrent users for load tests + +# API test settings +api_retries: 3 # Retries on connection errors +api_retry_delay: 2.0 # Seconds between retries +api_task_timeout: 1200 # Timeout for long-running tasks (20 min) + +# Database quality thresholds +db_row_count_tolerance_pct: 5.0 # % difference allowed in row counts +db_embedding_dimension_expected: 200 +db_clap_dimension_expected: 512 +db_score_null_threshold_pct: 10.0 # Max % NULLs in critical columns + +# Reporting +output_dir: "testing_suite/reports/output" +report_format: "both" # html, json, or both +verbose: false + +# Test track references (used for functional API tests) +# These should be tracks that exist in both instances' music libraries +test_track_artist_1: "Red Hot Chili Peppers" +test_track_title_1: "By the Way" +test_track_artist_2: "System of a Down" +test_track_title_2: "Attack" diff --git a/testing_suite/config.py b/testing_suite/config.py new file mode 100644 index 00000000..1419c884 --- /dev/null +++ b/testing_suite/config.py @@ -0,0 +1,211 @@ +""" +Configuration for the AudioMuse-AI Testing & Comparison Suite. + +Defines connection parameters for two instances (Instance A / Instance B) +which typically correspond to main branch and feature branch deployments. + +Configuration can be provided via: + 1. 
Environment variables (INSTANCE_A_*, INSTANCE_B_*) + 2. A YAML config file (--config flag) + 3. CLI arguments +""" + +import os +import json +import logging +from dataclasses import dataclass, field, asdict +from typing import Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class InstanceConfig: + """Connection configuration for a single AudioMuse-AI instance.""" + + # Identity + name: str = "instance" + branch: str = "unknown" + + # API connection + api_url: str = "http://localhost:8000" + api_timeout: int = 120 + + # PostgreSQL connection + pg_host: str = "localhost" + pg_port: int = 5432 + pg_user: str = "audiomuse" + pg_password: str = "audiomusepassword" + pg_database: str = "audiomusedb" + + # Redis connection + redis_url: str = "redis://localhost:6379/0" + + # Docker container names (for log collection) + docker_flask_container: str = "audiomuse-ai-flask-app" + docker_worker_container: str = "audiomuse-ai-worker-instance" + docker_postgres_container: str = "audiomuse-postgres" + docker_redis_container: str = "audiomuse-redis" + + # Docker compose file (optional, for status checks) + docker_compose_file: str = "" + + # SSH details if instances are remote + ssh_host: str = "" + ssh_user: str = "" + ssh_key: str = "" + ssh_port: int = 22 + + @property + def pg_dsn(self) -> str: + """Construct PostgreSQL DSN from components.""" + from urllib.parse import quote + user = quote(self.pg_user, safe='') + password = quote(self.pg_password, safe='') + return f"postgresql://{user}:{password}@{self.pg_host}:{self.pg_port}/{self.pg_database}" + + def to_dict(self) -> dict: + return asdict(self) + + +@dataclass +class ComparisonConfig: + """Top-level configuration for the comparison suite.""" + + # Instance configurations + instance_a: InstanceConfig = field(default_factory=lambda: InstanceConfig( + name="main", branch="main" + )) + instance_b: InstanceConfig = field(default_factory=lambda: InstanceConfig( + name="feature", branch="feature" + )) + + # 
Test control flags + run_api_tests: bool = True + run_db_tests: bool = True + run_docker_tests: bool = True + run_performance_tests: bool = True + run_existing_unit_tests: bool = True + run_existing_integration_tests: bool = True + + # Performance test settings + perf_warmup_requests: int = 3 + perf_benchmark_requests: int = 10 + perf_concurrent_users: int = 5 + + # API test settings + api_retries: int = 3 + api_retry_delay: float = 2.0 + api_task_timeout: int = 1200 # 20 minutes for long-running tasks + + # Database comparison thresholds + db_row_count_tolerance_pct: float = 5.0 # % difference allowed in row counts + db_embedding_dimension_expected: int = 200 + db_clap_dimension_expected: int = 512 + db_score_null_threshold_pct: float = 10.0 # Max % of NULL values in critical columns + + # Reporting + output_dir: str = "testing_suite/reports/output" + report_format: str = "html" # html, json, or both + verbose: bool = False + + # Test track references for functional tests + test_track_artist_1: str = "Red Hot Chili Peppers" + test_track_title_1: str = "By the Way" + test_track_artist_2: str = "System of a Down" + test_track_title_2: str = "Attack" + + def to_dict(self) -> dict: + return { + "instance_a": self.instance_a.to_dict(), + "instance_b": self.instance_b.to_dict(), + "run_api_tests": self.run_api_tests, + "run_db_tests": self.run_db_tests, + "run_docker_tests": self.run_docker_tests, + "run_performance_tests": self.run_performance_tests, + "run_existing_unit_tests": self.run_existing_unit_tests, + "run_existing_integration_tests": self.run_existing_integration_tests, + "perf_warmup_requests": self.perf_warmup_requests, + "perf_benchmark_requests": self.perf_benchmark_requests, + "perf_concurrent_users": self.perf_concurrent_users, + "output_dir": self.output_dir, + "report_format": self.report_format, + } + + +def load_config_from_yaml(path: str) -> ComparisonConfig: + """Load comparison config from a YAML file.""" + try: + import yaml + except 
ImportError: + raise ImportError("PyYAML is required for YAML config files: pip install pyyaml") + + with open(path, 'r') as f: + data = yaml.safe_load(f) + + config = ComparisonConfig() + + if 'instance_a' in data: + for k, v in data['instance_a'].items(): + if hasattr(config.instance_a, k): + setattr(config.instance_a, k, v) + + if 'instance_b' in data: + for k, v in data['instance_b'].items(): + if hasattr(config.instance_b, k): + setattr(config.instance_b, k, v) + + # Top-level settings + for k, v in data.items(): + if k not in ('instance_a', 'instance_b') and hasattr(config, k): + setattr(config, k, v) + + return config + + +def load_config_from_env() -> ComparisonConfig: + """Load comparison config from environment variables.""" + config = ComparisonConfig() + + # Instance A + a = config.instance_a + a.name = os.getenv("INSTANCE_A_NAME", a.name) + a.branch = os.getenv("INSTANCE_A_BRANCH", a.branch) + a.api_url = os.getenv("INSTANCE_A_API_URL", a.api_url) + a.pg_host = os.getenv("INSTANCE_A_PG_HOST", a.pg_host) + a.pg_port = int(os.getenv("INSTANCE_A_PG_PORT", str(a.pg_port))) + a.pg_user = os.getenv("INSTANCE_A_PG_USER", a.pg_user) + a.pg_password = os.getenv("INSTANCE_A_PG_PASSWORD", a.pg_password) + a.pg_database = os.getenv("INSTANCE_A_PG_DATABASE", a.pg_database) + a.redis_url = os.getenv("INSTANCE_A_REDIS_URL", a.redis_url) + a.docker_flask_container = os.getenv("INSTANCE_A_FLASK_CONTAINER", a.docker_flask_container) + a.docker_worker_container = os.getenv("INSTANCE_A_WORKER_CONTAINER", a.docker_worker_container) + a.docker_postgres_container = os.getenv("INSTANCE_A_PG_CONTAINER", a.docker_postgres_container) + a.ssh_host = os.getenv("INSTANCE_A_SSH_HOST", a.ssh_host) + a.ssh_user = os.getenv("INSTANCE_A_SSH_USER", a.ssh_user) + a.ssh_key = os.getenv("INSTANCE_A_SSH_KEY", a.ssh_key) + + # Instance B + b = config.instance_b + b.name = os.getenv("INSTANCE_B_NAME", b.name) + b.branch = os.getenv("INSTANCE_B_BRANCH", b.branch) + b.api_url = 
os.getenv("INSTANCE_B_API_URL", b.api_url) + b.pg_host = os.getenv("INSTANCE_B_PG_HOST", b.pg_host) + b.pg_port = int(os.getenv("INSTANCE_B_PG_PORT", str(b.pg_port))) + b.pg_user = os.getenv("INSTANCE_B_PG_USER", b.pg_user) + b.pg_password = os.getenv("INSTANCE_B_PG_PASSWORD", b.pg_password) + b.pg_database = os.getenv("INSTANCE_B_PG_DATABASE", b.pg_database) + b.redis_url = os.getenv("INSTANCE_B_REDIS_URL", b.redis_url) + b.docker_flask_container = os.getenv("INSTANCE_B_FLASK_CONTAINER", b.docker_flask_container) + b.docker_worker_container = os.getenv("INSTANCE_B_WORKER_CONTAINER", b.docker_worker_container) + b.docker_postgres_container = os.getenv("INSTANCE_B_PG_CONTAINER", b.docker_postgres_container) + b.ssh_host = os.getenv("INSTANCE_B_SSH_HOST", b.ssh_host) + b.ssh_user = os.getenv("INSTANCE_B_SSH_USER", b.ssh_user) + b.ssh_key = os.getenv("INSTANCE_B_SSH_KEY", b.ssh_key) + + # Global settings + config.verbose = os.getenv("COMPARISON_VERBOSE", "false").lower() == "true" + config.output_dir = os.getenv("COMPARISON_OUTPUT_DIR", config.output_dir) + config.report_format = os.getenv("COMPARISON_REPORT_FORMAT", config.report_format) + + return config diff --git a/testing_suite/orchestrator.py b/testing_suite/orchestrator.py new file mode 100644 index 00000000..7e6849f9 --- /dev/null +++ b/testing_suite/orchestrator.py @@ -0,0 +1,163 @@ +""" +Main Orchestrator for the AudioMuse-AI Testing & Comparison Suite. + +Coordinates all comparison modules and generates the final report. 
+""" + +import json +import logging +import os +import time +from datetime import datetime + +from testing_suite.config import ComparisonConfig +from testing_suite.utils import ComparisonReport +from testing_suite.comparators.api_comparator import APIComparator +from testing_suite.comparators.db_comparator import DatabaseComparator +from testing_suite.comparators.docker_comparator import DockerComparator +from testing_suite.comparators.performance_comparator import PerformanceComparator +from testing_suite.test_runner.existing_tests import ExistingTestRunner +from testing_suite.reports.html_report import generate_html_report + +logger = logging.getLogger(__name__) + + +class ComparisonOrchestrator: + """Orchestrates all comparison modules and produces the final report.""" + + def __init__(self, config: ComparisonConfig): + self.config = config + self.report = ComparisonReport( + instance_a_name=config.instance_a.name, + instance_b_name=config.instance_b.name, + instance_a_branch=config.instance_a.branch, + instance_b_branch=config.instance_b.branch, + config_snapshot=config.to_dict(), + ) + + def run(self) -> ComparisonReport: + """Run all configured comparison modules and return the report.""" + overall_start = time.time() + + logger.info("=" * 70) + logger.info("AudioMuse-AI Testing & Comparison Suite") + logger.info("=" * 70) + logger.info(f"Instance A: {self.config.instance_a.name} " + f"({self.config.instance_a.branch}) " + f"@ {self.config.instance_a.api_url}") + logger.info(f"Instance B: {self.config.instance_b.name} " + f"({self.config.instance_b.branch}) " + f"@ {self.config.instance_b.api_url}") + logger.info("-" * 70) + + # Run each module based on config flags + modules = [] + + if self.config.run_db_tests: + modules.append(("Database Comparison", DatabaseComparator(self.config))) + + if self.config.run_api_tests: + modules.append(("API Comparison", APIComparator(self.config))) + + if self.config.run_docker_tests: + modules.append(("Docker Comparison", 
DockerComparator(self.config))) + + if self.config.run_performance_tests: + modules.append(("Performance Benchmark", PerformanceComparator(self.config))) + + if self.config.run_existing_unit_tests or self.config.run_existing_integration_tests: + modules.append(("Existing Tests", ExistingTestRunner(self.config))) + + for name, module in modules: + logger.info(f"\n{'='*50}") + logger.info(f"Running: {name}") + logger.info(f"{'='*50}") + try: + module_start = time.time() + module.run_all(self.report) + module_duration = time.time() - module_start + logger.info(f"{name} completed in {module_duration:.1f}s") + except Exception as e: + logger.error(f"{name} failed with error: {e}", exc_info=True) + from testing_suite.utils import TestResult, TestStatus + self.report.add_result(TestResult( + category=name.lower().replace(" ", "_"), + name=f"{name}: Module Error", + status=TestStatus.ERROR, + message=f"Module failed: {str(e)}", + )) + + # Generate reports + self._generate_reports() + + overall_duration = time.time() - overall_start + logger.info(f"\n{'='*70}") + logger.info(f"Testing complete in {overall_duration:.1f}s") + logger.info(f"Overall status: {self.report.overall_status.value}") + logger.info(f"Total: {self.report.total_tests} tests, " + f"{self.report.total_passed} passed, " + f"{self.report.total_failed} failed, " + f"{self.report.total_errors} errors") + logger.info(f"{'='*70}") + + return self.report + + def _generate_reports(self): + """Generate output reports in configured formats.""" + os.makedirs(self.config.output_dir, exist_ok=True) + timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S") + + # JSON report (always generated) + json_path = os.path.join(self.config.output_dir, f"comparison_{timestamp}.json") + with open(json_path, 'w') as f: + json.dump(self.report.to_dict(), f, indent=2, default=str) + logger.info(f"JSON report saved: {json_path}") + + # HTML report + if self.config.report_format in ("html", "both"): + html_path = 
os.path.join(self.config.output_dir, f"comparison_{timestamp}.html") + generate_html_report(self.report, html_path) + logger.info(f"HTML report saved: {html_path}") + + # Also save a latest symlink/copy + for ext in ["json", "html"]: + src = os.path.join(self.config.output_dir, f"comparison_{timestamp}.{ext}") + dst = os.path.join(self.config.output_dir, f"comparison_latest.{ext}") + if os.path.exists(src): + try: + if os.path.exists(dst) or os.path.islink(dst): + os.unlink(dst) + os.symlink(os.path.basename(src), dst) + except OSError: + # Symlinks may not work on all systems; copy instead + import shutil + shutil.copy2(src, dst) + + def print_summary(self): + """Print a concise summary to stdout.""" + print(f"\n{'='*60}") + print(f" COMPARISON REPORT SUMMARY") + print(f"{'='*60}") + print(f" Instance A: {self.report.instance_a_name} ({self.report.instance_a_branch})") + print(f" Instance B: {self.report.instance_b_name} ({self.report.instance_b_branch})") + print(f" Overall: {self.report.overall_status.value}") + print(f" Total: {self.report.total_tests} | " + f"Pass: {self.report.total_passed} | " + f"Fail: {self.report.total_failed} | " + f"Error: {self.report.total_errors}") + print(f"{'='*60}") + + for cat_name, cat in self.report.categories.items(): + indicator = "PASS" if cat.failed == 0 and cat.errors == 0 else "FAIL" + print(f" [{indicator:4s}] {cat_name:25s} " + f"P:{cat.passed:3d} F:{cat.failed:3d} " + f"W:{cat.warned:3d} S:{cat.skipped:3d} E:{cat.errors:3d}") + + # Show failed tests + for r in cat.results: + if r.status.value in ("FAIL", "ERROR"): + print(f" X {r.name}: {r.message[:80]}") + + print(f"{'='*60}") + print(f" Reports: {self.config.output_dir}/comparison_latest.*") + print(f"{'='*60}\n") diff --git a/testing_suite/reports/__init__.py b/testing_suite/reports/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/testing_suite/reports/html_report.py b/testing_suite/reports/html_report.py new file mode 100644 index 
00000000..527268df --- /dev/null +++ b/testing_suite/reports/html_report.py @@ -0,0 +1,359 @@ +""" +HTML Report Generator for AudioMuse-AI Testing Suite. + +Generates a comprehensive, self-contained HTML report with: + - Overall pass/fail summary with status badges + - Per-category breakdowns with expandable details + - Color-coded results (green/red/yellow/gray) + - Performance charts (latency comparisons) + - Filterable and sortable tables + - Instance A vs B side-by-side comparisons +""" + +import json +import os +from datetime import datetime +from typing import Dict + +from testing_suite.utils import ComparisonReport, TestStatus + + +def generate_html_report(report: ComparisonReport, output_path: str) -> str: + """Generate a self-contained HTML report and write it to output_path.""" + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + + report_dict = report.to_dict() + categories = report_dict.get("categories", {}) + + # Build category HTML sections + category_sections = "" + for cat_name, cat_data in categories.items(): + category_sections += _build_category_section(cat_name, cat_data) + + # Build performance chart data (if performance category exists) + perf_chart_data = _build_performance_chart_data(categories.get("performance", {})) + + html = f""" + + + + +AudioMuse-AI Comparison Report + + + +
    +
    +

    AudioMuse-AI Comparison Report

    +
    + Generated: {report.timestamp} UTC
    + Instance A: {report.instance_a_name} (branch: {report.instance_a_branch})
    + Instance B: {report.instance_b_name} (branch: {report.instance_b_branch}) +
    +
    + +
    +
    +
    {report.overall_status.value}
    +
    Overall Status
    +
    +
    +
    {report.total_tests}
    +
    Total Tests
    +
    +
    +
    {report.total_passed}
    +
    Passed
    +
    +
    +
    {report.total_failed}
    +
    Failed
    +
    +
    +
    {report.total_errors}
    +
    Errors
    +
    +
    +
    {sum(c.warned for c in report.categories.values())}
    +
    Warnings
    +
    +
    + + {category_sections} + + {_build_perf_visual(perf_chart_data) if perf_chart_data else ""} + +
    + +
    + AudioMuse-AI Testing & Comparison Suite v1.0.0 +
    + + + +""" + + with open(output_path, 'w') as f: + f.write(html) + + return output_path + + +def _status_class(status) -> str: + """Get CSS class for a status value.""" + if isinstance(status, TestStatus): + status = status.value + return f"status-{status.lower()}" + + +def _badge_html(status) -> str: + """Generate a badge HTML element for a status.""" + if isinstance(status, TestStatus): + status = status.value + return f'{status}' + + +def _build_category_section(cat_name: str, cat_data: dict) -> str: + """Build HTML section for a test category.""" + total = cat_data.get("total", 0) + passed = cat_data.get("passed", 0) + failed = cat_data.get("failed", 0) + warned = cat_data.get("warned", 0) + skipped = cat_data.get("skipped", 0) + errors = cat_data.get("errors", 0) + results = cat_data.get("results", []) + + # Category display name + display_names = { + "api": "API Endpoints", + "database": "Database Quality", + "docker": "Docker & Infrastructure", + "performance": "Performance Benchmarks", + "existing_tests": "Existing Test Suite", + } + display_name = display_names.get(cat_name, cat_name.replace("_", " ").title()) + + # Build table rows + rows = "" + for r in results: + status = r.get("status", "SKIP") + duration = r.get("duration_seconds", 0) + duration_str = f"{duration:.2f}s" if duration else "-" + + # Format values for display + val_a = r.get("instance_a_value", "") + val_b = r.get("instance_b_value", "") + if isinstance(val_a, (dict, list)): + val_a = json.dumps(val_a, indent=1, default=str)[:200] + if isinstance(val_b, (dict, list)): + val_b = json.dumps(val_b, indent=1, default=str)[:200] + + rows += f""" + + {_badge_html(status)} + {r.get('name', '')} + {_escape_html(r.get('message', ''))} + {_escape_html(str(val_a)[:150])} + {_escape_html(str(val_b)[:150])} + {duration_str} + """ + + return f""" +
    +
    +
    + + {display_name} +
    +
    + {passed} passed + {failed} failed + {warned} warn + {skipped} skip + {errors} err + ({total} total) +
    +
    +
    +
    + Filter: + + + + + + +
    + + + + + + + + + + + + + {rows} + +
    StatusTest NameMessageInstance AInstance BDuration
    +
    +
    """ + + +def _build_performance_chart_data(perf_category: dict) -> list: + """Extract performance comparison data for visualization.""" + if not perf_category: + return [] + + chart_data = [] + for result in perf_category.get("results", []): + name = result.get("name", "") + if not name.startswith("Latency:"): + continue + + val_a = result.get("instance_a_value", {}) + val_b = result.get("instance_b_value", {}) + + if isinstance(val_a, dict) and isinstance(val_b, dict): + mean_a = val_a.get("mean", 0) + mean_b = val_b.get("mean", 0) + if mean_a > 0 or mean_b > 0: + chart_data.append({ + "name": name.replace("Latency: ", ""), + "mean_a": mean_a, + "mean_b": mean_b, + "p95_a": val_a.get("p95", 0), + "p95_b": val_b.get("p95", 0), + }) + + return chart_data + + +def _build_perf_visual(chart_data: list) -> str: + """Build a visual performance comparison section.""" + if not chart_data: + return "" + + max_val = max( + max(d["mean_a"], d["mean_b"], d["p95_a"], d["p95_b"]) + for d in chart_data + ) or 1 + + bars = "" + for d in chart_data: + width_a = max(2, int(d["mean_a"] / max_val * 400)) + width_b = max(2, int(d["mean_b"] / max_val * 400)) + + bars += f""" +
    +
    {d['name']}
    +
    +
    A
    +
    + {d['mean_a']*1000:.1f}ms +
    +
    +
    B
    +
    + {d['mean_b']*1000:.1f}ms +
    +
    """ + + return f""" +

    Performance Visual Comparison (Mean Latency)

    +
    + Instance A + Instance B +
    +
    + {bars} +
    """ + + +def _escape_html(text: str) -> str: + """Escape HTML special characters.""" + return (str(text) + .replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'")) diff --git a/testing_suite/requirements.txt b/testing_suite/requirements.txt new file mode 100644 index 00000000..ecec1016 --- /dev/null +++ b/testing_suite/requirements.txt @@ -0,0 +1,7 @@ +# Requirements for the AudioMuse-AI Testing & Comparison Suite +requests>=2.28.0 +psycopg2-binary>=2.9.0 +pyyaml>=6.0 +pytest>=7.0.0 +pytest-json-report>=1.5.0 +pytest-timeout>=2.1.0 diff --git a/testing_suite/run_comparison.py b/testing_suite/run_comparison.py new file mode 100644 index 00000000..b07ea7ce --- /dev/null +++ b/testing_suite/run_comparison.py @@ -0,0 +1,325 @@ +#!/usr/bin/env python3 +""" +AudioMuse-AI Testing & Comparison Suite - CLI Entry Point + +Comprehensive tool to test all features, database quality, API results, +and performance between two AudioMuse-AI instances (e.g., main branch vs feature branch). 
+ +Usage: + # Quick comparison with two API URLs (minimal config): + python -m testing_suite.run_comparison \ + --url-a http://main-instance:8000 \ + --url-b http://feature-instance:8000 + + # Full comparison with database and Docker access: + python -m testing_suite.run_comparison \ + --url-a http://main:8000 --url-b http://feature:8000 \ + --pg-host-a main-db-host --pg-host-b feature-db-host \ + --flask-container-a audiomuse-main-flask --flask-container-b audiomuse-feature-flask + + # From YAML config file: + python -m testing_suite.run_comparison --config comparison_config.yaml + + # Only run specific test categories: + python -m testing_suite.run_comparison \ + --url-a http://main:8000 --url-b http://feature:8000 \ + --only api,performance + + # Skip slow tests: + python -m testing_suite.run_comparison \ + --url-a http://main:8000 --url-b http://feature:8000 \ + --skip docker,existing_tests + + # Discover available tests: + python -m testing_suite.run_comparison --discover + + # Verbose output: + python -m testing_suite.run_comparison \ + --url-a http://main:8000 --url-b http://feature:8000 -v +""" + +import argparse +import logging +import sys +import os + +# Add project root to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from testing_suite.config import ( + ComparisonConfig, InstanceConfig, + load_config_from_yaml, load_config_from_env, +) +from testing_suite.orchestrator import ComparisonOrchestrator +from testing_suite.test_runner.existing_tests import ExistingTestRunner + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="AudioMuse-AI Testing & Comparison Suite", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Config source + parser.add_argument("--config", "-c", type=str, default="", + help="Path to YAML config file") + + # Discovery mode + parser.add_argument("--discover", action="store_true", + help="Discover and list all 
available tests, then exit") + + # Instance A + grp_a = parser.add_argument_group("Instance A (main/baseline)") + grp_a.add_argument("--url-a", type=str, default="", + help="API URL for instance A (e.g., http://localhost:8000)") + grp_a.add_argument("--name-a", type=str, default="main", + help="Name for instance A (default: main)") + grp_a.add_argument("--branch-a", type=str, default="main", + help="Branch name for instance A") + grp_a.add_argument("--pg-host-a", type=str, default="", + help="PostgreSQL host for instance A") + grp_a.add_argument("--pg-port-a", type=int, default=5432, + help="PostgreSQL port for instance A") + grp_a.add_argument("--pg-user-a", type=str, default="audiomuse", + help="PostgreSQL user for instance A") + grp_a.add_argument("--pg-pass-a", type=str, default="audiomusepassword", + help="PostgreSQL password for instance A") + grp_a.add_argument("--pg-db-a", type=str, default="audiomusedb", + help="PostgreSQL database for instance A") + grp_a.add_argument("--redis-a", type=str, default="", + help="Redis URL for instance A") + grp_a.add_argument("--flask-container-a", type=str, default="audiomuse-ai-flask-app", + help="Docker flask container name for A") + grp_a.add_argument("--worker-container-a", type=str, default="audiomuse-ai-worker-instance", + help="Docker worker container name for A") + grp_a.add_argument("--ssh-host-a", type=str, default="", + help="SSH host for remote Docker access (instance A)") + grp_a.add_argument("--ssh-user-a", type=str, default="", + help="SSH user for remote Docker access (instance A)") + grp_a.add_argument("--ssh-key-a", type=str, default="", + help="SSH key file for remote Docker access (instance A)") + + # Instance B + grp_b = parser.add_argument_group("Instance B (feature/test)") + grp_b.add_argument("--url-b", type=str, default="", + help="API URL for instance B (e.g., http://localhost:8001)") + grp_b.add_argument("--name-b", type=str, default="feature", + help="Name for instance B (default: feature)") + 
grp_b.add_argument("--branch-b", type=str, default="feature", + help="Branch name for instance B") + grp_b.add_argument("--pg-host-b", type=str, default="", + help="PostgreSQL host for instance B") + grp_b.add_argument("--pg-port-b", type=int, default=5432, + help="PostgreSQL port for instance B") + grp_b.add_argument("--pg-user-b", type=str, default="audiomuse", + help="PostgreSQL user for instance B") + grp_b.add_argument("--pg-pass-b", type=str, default="audiomusepassword", + help="PostgreSQL password for instance B") + grp_b.add_argument("--pg-db-b", type=str, default="audiomusedb", + help="PostgreSQL database for instance B") + grp_b.add_argument("--redis-b", type=str, default="", + help="Redis URL for instance B") + grp_b.add_argument("--flask-container-b", type=str, default="audiomuse-ai-flask-app", + help="Docker flask container name for B") + grp_b.add_argument("--worker-container-b", type=str, default="audiomuse-ai-worker-instance", + help="Docker worker container name for B") + grp_b.add_argument("--ssh-host-b", type=str, default="", + help="SSH host for remote Docker access (instance B)") + grp_b.add_argument("--ssh-user-b", type=str, default="", + help="SSH user for remote Docker access (instance B)") + grp_b.add_argument("--ssh-key-b", type=str, default="", + help="SSH key file for remote Docker access (instance B)") + + # Test selection + grp_t = parser.add_argument_group("Test Selection") + grp_t.add_argument("--only", type=str, default="", + help="Only run these categories (comma-separated: api,db,docker,performance,existing_tests)") + grp_t.add_argument("--skip", type=str, default="", + help="Skip these categories (comma-separated)") + + # Performance settings + grp_p = parser.add_argument_group("Performance Settings") + grp_p.add_argument("--warmup", type=int, default=3, + help="Warmup requests before benchmarking (default: 3)") + grp_p.add_argument("--bench-requests", type=int, default=10, + help="Benchmark requests per endpoint (default: 10)") 
+ grp_p.add_argument("--concurrent", type=int, default=5, + help="Concurrent users for load test (default: 5)") + + # Output + grp_o = parser.add_argument_group("Output") + grp_o.add_argument("--output-dir", "-o", type=str, default="testing_suite/reports/output", + help="Output directory for reports") + grp_o.add_argument("--format", type=str, default="both", choices=["html", "json", "both"], + help="Report format (default: both)") + grp_o.add_argument("-v", "--verbose", action="store_true", + help="Verbose output") + + return parser + + +def build_config(args) -> ComparisonConfig: + """Build ComparisonConfig from CLI arguments.""" + # Start with YAML file or env if specified + if args.config: + config = load_config_from_yaml(args.config) + else: + config = load_config_from_env() + + # CLI overrides + a = config.instance_a + b = config.instance_b + + if args.url_a: + a.api_url = args.url_a + if args.name_a: + a.name = args.name_a + if args.branch_a: + a.branch = args.branch_a + if args.pg_host_a: + a.pg_host = args.pg_host_a + a.pg_port = args.pg_port_a + if args.pg_user_a: + a.pg_user = args.pg_user_a + if args.pg_pass_a: + a.pg_password = args.pg_pass_a + if args.pg_db_a: + a.pg_database = args.pg_db_a + if args.redis_a: + a.redis_url = args.redis_a + if args.flask_container_a: + a.docker_flask_container = args.flask_container_a + if args.worker_container_a: + a.docker_worker_container = args.worker_container_a + if args.ssh_host_a: + a.ssh_host = args.ssh_host_a + if args.ssh_user_a: + a.ssh_user = args.ssh_user_a + if args.ssh_key_a: + a.ssh_key = args.ssh_key_a + + if args.url_b: + b.api_url = args.url_b + if args.name_b: + b.name = args.name_b + if args.branch_b: + b.branch = args.branch_b + if args.pg_host_b: + b.pg_host = args.pg_host_b + b.pg_port = args.pg_port_b + if args.pg_user_b: + b.pg_user = args.pg_user_b + if args.pg_pass_b: + b.pg_password = args.pg_pass_b + if args.pg_db_b: + b.pg_database = args.pg_db_b + if args.redis_b: + b.redis_url = 
args.redis_b + if args.flask_container_b: + b.docker_flask_container = args.flask_container_b + if args.worker_container_b: + b.docker_worker_container = args.worker_container_b + if args.ssh_host_b: + b.ssh_host = args.ssh_host_b + if args.ssh_user_b: + b.ssh_user = args.ssh_user_b + if args.ssh_key_b: + b.ssh_key = args.ssh_key_b + + # Performance settings + config.perf_warmup_requests = args.warmup + config.perf_benchmark_requests = args.bench_requests + config.perf_concurrent_users = args.concurrent + + # Output settings + config.output_dir = args.output_dir + config.report_format = args.format + config.verbose = args.verbose + + # Test selection + if args.only: + categories = set(args.only.split(",")) + config.run_api_tests = "api" in categories + config.run_db_tests = "db" in categories or "database" in categories + config.run_docker_tests = "docker" in categories + config.run_performance_tests = "performance" in categories or "perf" in categories + config.run_existing_unit_tests = "existing_tests" in categories or "unit" in categories + config.run_existing_integration_tests = "existing_tests" in categories or "integration" in categories + + if args.skip: + skip = set(args.skip.split(",")) + if "api" in skip: + config.run_api_tests = False + if "db" in skip or "database" in skip: + config.run_db_tests = False + if "docker" in skip: + config.run_docker_tests = False + if "performance" in skip or "perf" in skip: + config.run_performance_tests = False + if "existing_tests" in skip or "unit" in skip: + config.run_existing_unit_tests = False + if "existing_tests" in skip or "integration" in skip: + config.run_existing_integration_tests = False + + return config + + +def main(): + parser = build_parser() + args = parser.parse_args() + + # Setup logging + level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + datefmt="%H:%M:%S", + ) + + # Discovery mode + if 
args.discover: + discovery = ExistingTestRunner.discover_tests() + print("\n=== AudioMuse-AI Test Discovery ===\n") + + print(f"Unit Tests ({len(discovery['unit_tests'])} files):") + for t in discovery["unit_tests"]: + status = "OK" if t["exists"] else "MISSING" + print(f" [{status}] {t['file']}") + + print(f"\nIntegration Tests ({len(discovery['integration_tests'])} files):") + for t in discovery["integration_tests"]: + status = "OK" if t["exists"] else "MISSING" + print(f" [{status}] {t['file']}") + + print(f"\nE2E API Tests ({len(discovery['e2e_tests'])} tests):") + for t in discovery["e2e_tests"]: + print(f" [OK] {t['name']} ({t['file']})") + + total = (len(discovery["unit_tests"]) + + len(discovery["integration_tests"]) + + len(discovery["e2e_tests"])) + print(f"\nTotal: {total} test files/entries discovered.\n") + return 0 + + # Validate minimum config + config = build_config(args) + + if not config.instance_a.api_url or not config.instance_b.api_url: + if not args.config: + parser.error("At least --url-a and --url-b are required " + "(or use --config for YAML config)") + + # Run comparison + orchestrator = ComparisonOrchestrator(config) + report = orchestrator.run() + orchestrator.print_summary() + + # Exit code based on results + if report.total_failed > 0 or report.total_errors > 0: + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/testing_suite/test_runner/__init__.py b/testing_suite/test_runner/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/testing_suite/test_runner/existing_tests.py b/testing_suite/test_runner/existing_tests.py new file mode 100644 index 00000000..c8a388fd --- /dev/null +++ b/testing_suite/test_runner/existing_tests.py @@ -0,0 +1,469 @@ +""" +Existing Test Integration Module for AudioMuse-AI Testing Suite. + +Discovers and runs existing unit and integration tests from the codebase +against both instances, collecting and comparing results. 
+ +Integrates: + - tests/unit/ (17 unit test modules via pytest) + - test/test.py (E2E API integration tests) + - test/test_analysis_integration.py (ONNX model integration) + - test/test_clap_analysis_integration.py (CLAP model integration) +""" + +import json +import logging +import os +import subprocess +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from testing_suite.config import ComparisonConfig, InstanceConfig +from testing_suite.utils import ( + ComparisonReport, TestResult, TestStatus, pct_diff +) + +logger = logging.getLogger(__name__) + +PROJECT_ROOT = Path(__file__).resolve().parents[2] + +# Known unit test files +UNIT_TEST_DIR = PROJECT_ROOT / "tests" / "unit" +UNIT_TEST_FILES = [ + "test_ai.py", + "test_analysis.py", + "test_app_analysis.py", + "test_artist_gmm_manager.py", + "test_clap_text_search.py", + "test_clustering.py", + "test_clustering_helper.py", + "test_clustering_postprocessing.py", + "test_commons.py", + "test_mediaserver.py", + "test_memory_cleanup.py", + "test_memory_utils.py", + "test_path_manager.py", + "test_song_alchemy.py", + "test_sonic_fingerprint_manager.py", + "test_string_sanitization.py", + "test_voyager_manager.py", +] + +# Integration test files +INTEGRATION_TEST_DIR = PROJECT_ROOT / "test" +INTEGRATION_TEST_FILES = [ + "test_analysis_integration.py", + "test_clap_analysis_integration.py", +] + +# E2E API test (requires a running instance) +E2E_TEST_FILE = PROJECT_ROOT / "test" / "test.py" + +# Individual E2E test names (from test/test.py) +E2E_TEST_NAMES = [ + "test_analysis_smoke_flow", + "test_instant_playlist_functionality", + "test_sonic_fingerprint_and_playlist", + "test_song_alchemy_and_playlist", + "test_map_visualization", + "test_annoy_similarity_and_playlist", + "test_song_path_and_playlist", + "test_clustering_smoke_flow", +] + + +def _parse_pytest_json(json_path: str) -> dict: + """Parse pytest JSON report.""" + try: + with open(json_path, 'r') as f: + return 
json.load(f) + except Exception as e: + logger.warning(f"Could not parse pytest JSON report: {e}") + return {} + + +def _run_pytest(test_path: str, extra_args: list = None, + env_override: dict = None, timeout: int = 600, + json_report: bool = True) -> Tuple[dict, str, int]: + """ + Run pytest and capture results. + Returns (parsed_json_result, stdout, returncode). + """ + cmd = ["python", "-m", "pytest", "-v", "--tb=short"] + + json_path = None + if json_report: + json_path = f"/tmp/pytest_report_{int(time.time() * 1000)}.json" + cmd += [f"--json-report", f"--json-report-file={json_path}"] + + cmd.append(str(test_path)) + + if extra_args: + cmd.extend(extra_args) + + env = os.environ.copy() + if env_override: + env.update(env_override) + + # Ensure project root is in PYTHONPATH + python_path = env.get("PYTHONPATH", "") + env["PYTHONPATH"] = f"{PROJECT_ROOT}:{python_path}" if python_path else str(PROJECT_ROOT) + + try: + proc = subprocess.run( + cmd, capture_output=True, text=True, timeout=timeout, + cwd=str(PROJECT_ROOT), env=env, + ) + stdout = proc.stdout + proc.stderr + returncode = proc.returncode + + # Parse JSON report if available + result = {} + if json_path and os.path.exists(json_path): + result = _parse_pytest_json(json_path) + os.unlink(json_path) + + return result, stdout, returncode + + except subprocess.TimeoutExpired: + return {}, f"pytest timed out after {timeout}s", -1 + except Exception as e: + return {}, str(e), -2 + + +def _parse_stdout_results(stdout: str) -> dict: + """ + Parse pytest stdout for test results when JSON report is not available. + Returns dict with passed, failed, error, skipped counts and test names. 
+ """ + import re + + results = { + "passed": 0, + "failed": 0, + "errors": 0, + "skipped": 0, + "tests": [], + } + + # Parse individual test results: PASSED, FAILED, ERROR, SKIPPED + for line in stdout.split('\n'): + line = line.strip() + if " PASSED" in line: + results["passed"] += 1 + results["tests"].append({"name": line.split(" PASSED")[0].strip(), "status": "passed"}) + elif " FAILED" in line: + results["failed"] += 1 + results["tests"].append({"name": line.split(" FAILED")[0].strip(), "status": "failed"}) + elif " ERROR" in line: + results["errors"] += 1 + results["tests"].append({"name": line.split(" ERROR")[0].strip(), "status": "error"}) + elif " SKIPPED" in line: + results["skipped"] += 1 + results["tests"].append({"name": line.split(" SKIPPED")[0].strip(), "status": "skipped"}) + + # Try to parse summary line like "5 passed, 1 failed, 2 skipped" + summary_match = re.search( + r'(\d+)\s+passed.*?(?:(\d+)\s+failed)?.*?(?:(\d+)\s+skipped)?.*?(?:(\d+)\s+error)?', + stdout + ) + if summary_match: + if summary_match.group(1): + results["passed"] = max(results["passed"], int(summary_match.group(1))) + if summary_match.group(2): + results["failed"] = max(results["failed"], int(summary_match.group(2))) + if summary_match.group(3): + results["skipped"] = max(results["skipped"], int(summary_match.group(3))) + if summary_match.group(4): + results["errors"] = max(results["errors"], int(summary_match.group(4))) + + return results + + +class ExistingTestRunner: + """Runs existing tests and integrates results into the comparison report.""" + + def __init__(self, config: ComparisonConfig): + self.config = config + + def run_all(self, report: ComparisonReport): + """Run all existing test suites.""" + logger.info("Starting existing test integration...") + + if self.config.run_existing_unit_tests: + self._run_unit_tests(report) + + if self.config.run_existing_integration_tests: + self._run_integration_tests(report) + self._run_e2e_tests(report) + + logger.info("Existing 
test integration complete.") + + # ------------------------------------------------------------------ + # Unit tests + # ------------------------------------------------------------------ + + def _run_unit_tests(self, report: ComparisonReport): + """Run unit tests from tests/unit/ directory.""" + t0 = time.time() + + if not UNIT_TEST_DIR.exists(): + report.add_result(TestResult( + category="existing_tests", + name="Unit Tests: directory check", + status=TestStatus.ERROR, + message=f"Unit test directory not found: {UNIT_TEST_DIR}", + duration_seconds=time.time() - t0, + )) + return + + # Run entire unit test suite + logger.info("Running unit test suite...") + result, stdout, rc = _run_pytest( + str(UNIT_TEST_DIR), + extra_args=["-x", "--timeout=120"], + timeout=600, + json_report=True, + ) + + # Parse results + if result and "summary" in result: + summary = result["summary"] + passed = summary.get("passed", 0) + failed = summary.get("failed", 0) + errors = summary.get("error", 0) + skipped = summary.get("skipped", 0) + total = summary.get("total", passed + failed + errors + skipped) + else: + # Fallback to stdout parsing + parsed = _parse_stdout_results(stdout) + passed = parsed["passed"] + failed = parsed["failed"] + errors = parsed["errors"] + skipped = parsed["skipped"] + total = passed + failed + errors + skipped + + if failed == 0 and errors == 0: + status = TestStatus.PASS + elif errors > 0: + status = TestStatus.ERROR + else: + status = TestStatus.FAIL + + report.add_result(TestResult( + category="existing_tests", + name="Unit Tests: overall", + status=status, + message=( + f"Total={total}, Passed={passed}, Failed={failed}, " + f"Errors={errors}, Skipped={skipped} | " + f"Return code: {rc}" + ), + instance_a_value={ + "total": total, "passed": passed, "failed": failed, + "errors": errors, "skipped": skipped, + }, + duration_seconds=time.time() - t0, + details={"returncode": rc, "stdout_tail": stdout[-2000:] if stdout else ""}, + )) + + # Report individual 
test file results + for test_file in UNIT_TEST_FILES: + test_path = UNIT_TEST_DIR / test_file + if not test_path.exists(): + report.add_result(TestResult( + category="existing_tests", + name=f"Unit: {test_file}", + status=TestStatus.SKIP, + message=f"File not found: {test_path}", + )) + continue + + tf0 = time.time() + file_result, file_stdout, file_rc = _run_pytest( + str(test_path), + extra_args=["--timeout=60"], + timeout=120, + json_report=False, + ) + + parsed = _parse_stdout_results(file_stdout) + + if file_rc == 0: + file_status = TestStatus.PASS + elif file_rc == 1: + file_status = TestStatus.FAIL + elif file_rc == 5: + file_status = TestStatus.SKIP # No tests collected + else: + file_status = TestStatus.ERROR + + report.add_result(TestResult( + category="existing_tests", + name=f"Unit: {test_file}", + status=file_status, + message=( + f"Passed={parsed['passed']}, Failed={parsed['failed']}, " + f"Errors={parsed['errors']}, Skipped={parsed['skipped']}" + ), + instance_a_value=parsed, + duration_seconds=time.time() - tf0, + details={"returncode": file_rc}, + )) + + # ------------------------------------------------------------------ + # Integration tests + # ------------------------------------------------------------------ + + def _run_integration_tests(self, report: ComparisonReport): + """Run integration tests from test/ directory.""" + for test_file in INTEGRATION_TEST_FILES: + test_path = INTEGRATION_TEST_DIR / test_file + if not test_path.exists(): + report.add_result(TestResult( + category="existing_tests", + name=f"Integration: {test_file}", + status=TestStatus.SKIP, + message=f"File not found: {test_path}", + )) + continue + + t0 = time.time() + result, stdout, rc = _run_pytest( + str(test_path), + extra_args=["-m", "integration", "--timeout=300"], + timeout=600, + json_report=False, + ) + + parsed = _parse_stdout_results(stdout) + + if rc == 0: + status = TestStatus.PASS + elif rc == 5: + status = TestStatus.SKIP + elif rc == 1: + status = 
TestStatus.FAIL + else: + status = TestStatus.ERROR + + report.add_result(TestResult( + category="existing_tests", + name=f"Integration: {test_file}", + status=status, + message=( + f"Passed={parsed['passed']}, Failed={parsed['failed']}, " + f"Errors={parsed['errors']}, Skipped={parsed['skipped']}" + ), + instance_a_value=parsed, + duration_seconds=time.time() - t0, + details={"returncode": rc, "stdout_tail": stdout[-1000:] if stdout else ""}, + )) + + # ------------------------------------------------------------------ + # E2E API tests (against both instances) + # ------------------------------------------------------------------ + + def _run_e2e_tests(self, report: ComparisonReport): + """Run E2E API tests from test/test.py against both instances.""" + if not E2E_TEST_FILE.exists(): + report.add_result(TestResult( + category="existing_tests", + name="E2E Tests: file check", + status=TestStatus.SKIP, + message=f"E2E test file not found: {E2E_TEST_FILE}", + )) + return + + instances = [] + if self.config.instance_a.api_url: + instances.append(("A", self.config.instance_a)) + if self.config.instance_b.api_url: + instances.append(("B", self.config.instance_b)) + + for label, instance in instances: + # Run non-destructive E2E tests (skip analysis and clustering which modify state) + safe_tests = [ + "test_map_visualization", + "test_annoy_similarity_and_playlist", + ] + + for test_name in safe_tests: + t0 = time.time() + result, stdout, rc = _run_pytest( + str(E2E_TEST_FILE), + extra_args=["-k", test_name, "--timeout=300"], + env_override={"BASE_URL": instance.api_url}, + timeout=600, + json_report=False, + ) + + parsed = _parse_stdout_results(stdout) + + if rc == 0: + status = TestStatus.PASS + elif rc == 5: + status = TestStatus.SKIP + elif rc == 1: + status = TestStatus.FAIL + else: + status = TestStatus.ERROR + + report.add_result(TestResult( + category="existing_tests", + name=f"E2E ({label}): {test_name}", + status=status, + message=( + f"Instance {label} 
({instance.api_url}): " + f"Passed={parsed['passed']}, Failed={parsed['failed']}" + ), + instance_a_value=parsed if label == "A" else None, + instance_b_value=parsed if label == "B" else None, + duration_seconds=time.time() - t0, + details={"returncode": rc, "instance": label, + "api_url": instance.api_url, + "stdout_tail": stdout[-500:] if stdout else ""}, + )) + + # ------------------------------------------------------------------ + # Discovery: list all available tests + # ------------------------------------------------------------------ + + @staticmethod + def discover_tests() -> dict: + """Discover all available tests and return a structured summary.""" + discovery = { + "unit_tests": [], + "integration_tests": [], + "e2e_tests": [], + } + + # Unit tests + if UNIT_TEST_DIR.exists(): + for f in sorted(UNIT_TEST_DIR.glob("test_*.py")): + discovery["unit_tests"].append({ + "file": str(f.relative_to(PROJECT_ROOT)), + "name": f.stem, + "exists": True, + }) + + # Integration tests + for f in INTEGRATION_TEST_FILES: + path = INTEGRATION_TEST_DIR / f + discovery["integration_tests"].append({ + "file": str(path.relative_to(PROJECT_ROOT)), + "name": Path(f).stem, + "exists": path.exists(), + }) + + # E2E tests + if E2E_TEST_FILE.exists(): + for name in E2E_TEST_NAMES: + discovery["e2e_tests"].append({ + "file": str(E2E_TEST_FILE.relative_to(PROJECT_ROOT)), + "name": name, + "exists": True, + }) + + return discovery diff --git a/testing_suite/utils.py b/testing_suite/utils.py new file mode 100644 index 00000000..e296f471 --- /dev/null +++ b/testing_suite/utils.py @@ -0,0 +1,434 @@ +""" +Shared utilities for the AudioMuse-AI Testing & Comparison Suite. + +Provides HTTP helpers, database connectors, Docker log fetchers, +timing utilities, and result aggregation primitives. 
+""" + +import json +import logging +import subprocess +import time +from dataclasses import dataclass, field, asdict +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +import requests + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Result types +# --------------------------------------------------------------------------- + +class TestStatus(str, Enum): + PASS = "PASS" + FAIL = "FAIL" + WARN = "WARN" + SKIP = "SKIP" + ERROR = "ERROR" + + +@dataclass +class TestResult: + """A single test result entry.""" + category: str # e.g. "api", "database", "docker", "performance" + name: str # descriptive test name + status: TestStatus + message: str = "" + instance_a_value: Any = None + instance_b_value: Any = None + diff: Any = None # computed difference + duration_seconds: float = 0.0 + details: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict: + d = asdict(self) + d['status'] = self.status.value + # Ensure JSON-serializable + for k in ('instance_a_value', 'instance_b_value', 'diff', 'details'): + try: + json.dumps(d[k]) + except (TypeError, ValueError): + d[k] = str(d[k]) + return d + + +@dataclass +class CategorySummary: + """Summary for a test category.""" + category: str + total: int = 0 + passed: int = 0 + failed: int = 0 + warned: int = 0 + skipped: int = 0 + errors: int = 0 + results: List[TestResult] = field(default_factory=list) + + def add(self, result: TestResult): + self.results.append(result) + self.total += 1 + if result.status == TestStatus.PASS: + self.passed += 1 + elif result.status == TestStatus.FAIL: + self.failed += 1 + elif result.status == TestStatus.WARN: + self.warned += 1 + elif result.status == TestStatus.SKIP: + self.skipped += 1 + elif result.status == TestStatus.ERROR: + self.errors += 1 + + +@dataclass +class ComparisonReport: + """Full comparison report across all categories.""" + 
timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat()) + instance_a_name: str = "" + instance_b_name: str = "" + instance_a_branch: str = "" + instance_b_branch: str = "" + categories: Dict[str, CategorySummary] = field(default_factory=dict) + config_snapshot: Dict[str, Any] = field(default_factory=dict) + + def add_result(self, result: TestResult): + cat = result.category + if cat not in self.categories: + self.categories[cat] = CategorySummary(category=cat) + self.categories[cat].add(result) + + @property + def total_tests(self) -> int: + return sum(c.total for c in self.categories.values()) + + @property + def total_passed(self) -> int: + return sum(c.passed for c in self.categories.values()) + + @property + def total_failed(self) -> int: + return sum(c.failed for c in self.categories.values()) + + @property + def total_errors(self) -> int: + return sum(c.errors for c in self.categories.values()) + + @property + def overall_status(self) -> TestStatus: + if self.total_failed > 0 or self.total_errors > 0: + return TestStatus.FAIL + return TestStatus.PASS + + def to_dict(self) -> dict: + return { + "timestamp": self.timestamp, + "instance_a": {"name": self.instance_a_name, "branch": self.instance_a_branch}, + "instance_b": {"name": self.instance_b_name, "branch": self.instance_b_branch}, + "overall_status": self.overall_status.value, + "summary": { + "total": self.total_tests, + "passed": self.total_passed, + "failed": self.total_failed, + "errors": self.total_errors, + }, + "categories": { + name: { + "total": cat.total, + "passed": cat.passed, + "failed": cat.failed, + "warned": cat.warned, + "skipped": cat.skipped, + "errors": cat.errors, + "results": [r.to_dict() for r in cat.results], + } + for name, cat in self.categories.items() + }, + "config": self.config_snapshot, + } + + +# --------------------------------------------------------------------------- +# HTTP Helpers +# 
# ---------------------------------------------------------------------------

def _http_request(method: str, url: str, retries: int, retry_delay: float,
                  **kwargs) -> requests.Response:
    """Issue one HTTP request with retry-on-error semantics.

    Shared engine behind http_get/http_post: attempts the request up to
    `retries` times, sleeping `retry_delay` seconds between attempts, and
    re-raises the last requests-level exception when every attempt fails.
    """
    last_exc = None
    for attempt in range(1, retries + 1):
        try:
            return requests.request(method, url, **kwargs)
        except requests.RequestException as e:
            last_exc = e
            if attempt < retries:
                logger.debug(f"Retry {attempt}/{retries} for {method} {url}: {e}")
                time.sleep(retry_delay)
    raise last_exc


def http_get(url: str, params: dict = None, timeout: int = 120,
             retries: int = 3, retry_delay: float = 2.0) -> requests.Response:
    """HTTP GET with retries on connection errors."""
    return _http_request("GET", url, retries, retry_delay,
                         params=params, timeout=timeout)


def http_post(url: str, json_data: dict = None, timeout: int = 120,
              retries: int = 3, retry_delay: float = 2.0) -> requests.Response:
    """HTTP POST (JSON body) with retries on connection errors."""
    return _http_request("POST", url, retries, retry_delay,
                         json=json_data, timeout=timeout)


def timed_request(method: str, url: str, **kwargs) -> Tuple[requests.Response, float]:
    """Execute an HTTP request and return (response, elapsed_seconds)."""
    start = time.perf_counter()
    if method.upper() == "GET":
        resp = http_get(url, **kwargs)
    else:
        resp = http_post(url, **kwargs)
    elapsed = time.perf_counter() - start
    return resp, elapsed


def wait_for_task_success(base_url: str, task_id: str, timeout: int = 1200,
                          retries: int = 3, retry_delay: float = 2.0) -> dict:
    """Poll active_tasks until the task completes, then return its final record.

    Polls `/api/active_tasks` until `task_id` is no longer the active task,
    then fetches `/api/last_task`.  Returns the last_task JSON dict (the
    caller inspects its status/state field to decide success), or a synthetic
    {"status": "TIMEOUT", ...} dict if the task does not finish within
    `timeout` seconds.
    """
    start = time.time()
    while time.time() - start < timeout:
        act_resp = http_get(f'{base_url}/api/active_tasks', retries=retries,
                            retry_delay=retry_delay)
        act_resp.raise_for_status()
        active = act_resp.json()

        # Our task is still the active one: keep waiting.
        if active and active.get('task_id') == task_id:
            time.sleep(2)
            continue

        last_resp = http_get(f'{base_url}/api/last_task', retries=retries,
                             retry_delay=retry_delay)
        last_resp.raise_for_status()
        final = last_resp.json()

        if final.get('task_id') == task_id:
            # Surface the terminal state for debugging; the caller decides
            # whether it counts as success.
            final_state = (final.get('status') or final.get('state') or '').upper()
            logger.debug(f"Task {task_id} finished with state {final_state}")
            return final
        # Task might have been superseded; keep polling briefly
        time.sleep(2)

    return {"status": "TIMEOUT", "task_id": task_id}


# ---------------------------------------------------------------------------
# Database Helpers
# ---------------------------------------------------------------------------

def get_pg_connection(dsn: str):
    """Create a psycopg2 connection from a DSN.

    Uses a 30s connect timeout and a 120s server-side statement timeout.
    """
    import psycopg2
    return psycopg2.connect(dsn, connect_timeout=30,
                            options='-c statement_timeout=120000')


def pg_query(dsn: str, sql: str, params: tuple = None) -> List[tuple]:
    """Execute a read-only query and return all rows as tuples."""
    conn = get_pg_connection(dsn)
    try:
        with conn.cursor() as cur:
            cur.execute(sql, params)
            return cur.fetchall()
    finally:
        conn.close()


def pg_query_dict(dsn: str, sql: str, params: tuple = None) -> List[dict]:
    """Execute a query and return rows as dicts keyed by column name."""
    from psycopg2.extras import RealDictCursor
    conn = get_pg_connection(dsn)
    try:
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute(sql, params)
            return [dict(row) for row in cur.fetchall()]
    finally:
        conn.close()


def pg_scalar(dsn: str, sql: str, params: tuple = None):
    """Execute a query and return the first column of the first row (or None)."""
    rows = pg_query(dsn, sql, params)
    if rows and rows[0]:
        return rows[0][0]
    return None


# ---------------------------------------------------------------------------
# Docker Helpers
# ---------------------------------------------------------------------------
+def docker_exec(container: str, command: str, ssh_host: str = "", + ssh_user: str = "", ssh_key: str = "", + ssh_port: int = 22, timeout: int = 30) -> Tuple[str, str, int]: + """ + Run a command inside a Docker container (locally or via SSH). + Returns (stdout, stderr, returncode). + """ + if ssh_host: + ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", + "-p", str(ssh_port)] + if ssh_key: + ssh_cmd += ["-i", ssh_key] + ssh_cmd.append(f"{ssh_user}@{ssh_host}" if ssh_user else ssh_host) + ssh_cmd.append(f"docker exec {container} {command}") + full_cmd = ssh_cmd + else: + full_cmd = ["docker", "exec", container] + command.split() + + try: + proc = subprocess.run(full_cmd, capture_output=True, text=True, timeout=timeout) + return proc.stdout, proc.stderr, proc.returncode + except subprocess.TimeoutExpired: + return "", "Command timed out", -1 + except FileNotFoundError: + return "", "docker or ssh command not found", -2 + + +def docker_logs(container: str, tail: int = 500, since: str = "", + ssh_host: str = "", ssh_user: str = "", ssh_key: str = "", + ssh_port: int = 22, timeout: int = 30) -> Tuple[str, str, int]: + """ + Fetch Docker container logs (locally or via SSH). + Returns (stdout, stderr, returncode). 
+ """ + cmd_parts = ["docker", "logs", f"--tail={tail}"] + if since: + cmd_parts += [f"--since={since}"] + cmd_parts.append(container) + + if ssh_host: + ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", + "-p", str(ssh_port)] + if ssh_key: + ssh_cmd += ["-i", ssh_key] + ssh_cmd.append(f"{ssh_user}@{ssh_host}" if ssh_user else ssh_host) + ssh_cmd.append(" ".join(cmd_parts)) + full_cmd = ssh_cmd + else: + full_cmd = cmd_parts + + try: + proc = subprocess.run(full_cmd, capture_output=True, text=True, timeout=timeout) + return proc.stdout, proc.stderr, proc.returncode + except subprocess.TimeoutExpired: + return "", "Logs fetch timed out", -1 + except FileNotFoundError: + return "", "docker or ssh command not found", -2 + + +def docker_inspect(container: str, ssh_host: str = "", ssh_user: str = "", + ssh_key: str = "", ssh_port: int = 22, + timeout: int = 15) -> Optional[dict]: + """ + Run docker inspect on a container and return the parsed JSON. + Returns None on failure. + """ + cmd_parts = ["docker", "inspect", container] + + if ssh_host: + ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", + "-p", str(ssh_port)] + if ssh_key: + ssh_cmd += ["-i", ssh_key] + ssh_cmd.append(f"{ssh_user}@{ssh_host}" if ssh_user else ssh_host) + ssh_cmd.append(" ".join(cmd_parts)) + full_cmd = ssh_cmd + else: + full_cmd = cmd_parts + + try: + proc = subprocess.run(full_cmd, capture_output=True, text=True, timeout=timeout) + if proc.returncode == 0: + data = json.loads(proc.stdout) + return data[0] if isinstance(data, list) and data else data + except Exception as e: + logger.debug(f"docker inspect failed for {container}: {e}") + return None + + +# --------------------------------------------------------------------------- +# Comparison Helpers +# --------------------------------------------------------------------------- + +def compare_values(a, b, tolerance_pct: float = 0.0) -> Tuple[bool, str]: + """ + Compare two values. For numeric types, allow a percentage tolerance. 
+ Returns (is_equal, description). + """ + if a is None and b is None: + return True, "Both None" + if a is None or b is None: + return False, f"One is None: A={a}, B={b}" + + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + if a == 0 and b == 0: + return True, "Both zero" + if a == 0 or b == 0: + return False, f"A={a}, B={b}" + pct_diff = abs(a - b) / max(abs(a), abs(b)) * 100 + if pct_diff <= tolerance_pct: + return True, f"Within tolerance ({pct_diff:.2f}% <= {tolerance_pct}%)" + return False, f"Difference {pct_diff:.2f}% exceeds tolerance {tolerance_pct}%" + + if isinstance(a, str) and isinstance(b, str): + if a == b: + return True, "Strings match" + return False, f"Strings differ: '{a[:100]}' vs '{b[:100]}'" + + if isinstance(a, dict) and isinstance(b, dict): + keys_a = set(a.keys()) + keys_b = set(b.keys()) + if keys_a != keys_b: + missing_in_b = keys_a - keys_b + missing_in_a = keys_b - keys_a + return False, f"Key mismatch: missing_in_B={missing_in_b}, missing_in_A={missing_in_a}" + return True, "Dict keys match" + + if isinstance(a, list) and isinstance(b, list): + if len(a) == len(b): + return True, f"Lists same length ({len(a)})" + return False, f"List length differs: {len(a)} vs {len(b)}" + + # Fallback + if a == b: + return True, "Values equal" + return False, f"Values differ: {a} vs {b}" + + +def pct_diff(a: float, b: float) -> float: + """Calculate percentage difference between two values.""" + if a == 0 and b == 0: + return 0.0 + if a == 0 or b == 0: + return 100.0 + return abs(a - b) / max(abs(a), abs(b)) * 100 + + +def format_duration(seconds: float) -> str: + """Format seconds into a human-readable string.""" + if seconds < 1: + return f"{seconds*1000:.1f}ms" + if seconds < 60: + return f"{seconds:.2f}s" + minutes = int(seconds // 60) + secs = seconds % 60 + return f"{minutes}m {secs:.1f}s" From f42fe8f0614473f78b32b79b3fba0b00b2d90871 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 5 Feb 2026 09:27:30 +0000 Subject: [PATCH 
15/33] Add .gitignore for testing suite report output directory https://claude.ai/code/session_0122SF3fSXM3e2dNqaJB5NDn --- testing_suite/reports/output/.gitignore | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 testing_suite/reports/output/.gitignore diff --git a/testing_suite/reports/output/.gitignore b/testing_suite/reports/output/.gitignore new file mode 100644 index 00000000..58f3ef52 --- /dev/null +++ b/testing_suite/reports/output/.gitignore @@ -0,0 +1,3 @@ +# Report output files are generated at runtime - do not commit +* +!.gitignore From bfaa3f0907d57cc0fd61a2385fa5febb84adb060 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 5 Feb 2026 11:01:02 +0000 Subject: [PATCH 16/33] Add comprehensive README guide for the testing and comparison suite Covers architecture, prerequisites, quick start, all three config methods (CLI, YAML, env vars), detailed descriptions of all 5 test categories, deployment scenarios (local, remote SSH, API-only), report formats, selective testing, test discovery, result interpretation, and troubleshooting. https://claude.ai/code/session_0122SF3fSXM3e2dNqaJB5NDn --- testing_suite/README.md | 526 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 526 insertions(+) create mode 100644 testing_suite/README.md diff --git a/testing_suite/README.md b/testing_suite/README.md new file mode 100644 index 00000000..97296e9c --- /dev/null +++ b/testing_suite/README.md @@ -0,0 +1,526 @@ +# AudioMuse-AI Testing & Comparison Suite + +A comprehensive tool for testing all features, database quality, API results, and performance of AudioMuse-AI — comparing two live instances side-by-side (e.g., **main branch** vs **feature branch**). 
+ +## Table of Contents + +- [Overview](#overview) +- [Architecture](#architecture) +- [Prerequisites](#prerequisites) +- [Quick Start](#quick-start) +- [Configuration](#configuration) + - [CLI Arguments](#cli-arguments) + - [YAML Config File](#yaml-config-file) + - [Environment Variables](#environment-variables) +- [Test Categories](#test-categories) + - [API Comparison](#1-api-comparison-30-endpoints) + - [Database Comparison](#2-database-comparison-17-tables) + - [Docker & Infrastructure](#3-docker--infrastructure) + - [Performance Benchmarks](#4-performance-benchmarks) + - [Existing Test Suite](#5-existing-test-suite-27-tests) +- [Deployment Scenarios](#deployment-scenarios) + - [Same Machine (Different Ports)](#scenario-1-same-machine-different-ports) + - [Two Remote Machines](#scenario-2-two-remote-machines-via-ssh) + - [API-Only Comparison](#scenario-3-api-only-no-db-or-docker) +- [Reports](#reports) +- [Selective Testing](#selective-testing) +- [Test Discovery](#test-discovery) +- [Interpreting Results](#interpreting-results) +- [Troubleshooting](#troubleshooting) + +--- + +## Overview + +The testing suite connects to **two AudioMuse-AI instances** simultaneously via: + +| Connection | What it tests | +|----------------|-----------------------------------------------------------------| +| **API (HTTP)** | All 30+ REST endpoints — response codes, shapes, values, errors | +| **PostgreSQL** | Schema integrity, data quality, embedding health, distributions | +| **Docker** | Container health, resource usage, log error analysis | +| **Performance**| Latency benchmarks (p50/p95/p99), concurrent load, DB queries | + +It also discovers and runs the **27 existing tests** (unit, integration, E2E) already in the codebase. + +The final output is a **self-contained HTML report** (dark theme, filterable, with visual performance charts) plus a **JSON report** for programmatic consumption. 
+ +--- + +## Architecture + +``` +testing_suite/ +├── run_comparison.py # CLI entry point +├── __main__.py # python -m testing_suite +├── config.py # Configuration (CLI / YAML / env vars) +├── utils.py # HTTP helpers, DB connectors, Docker log fetchers +├── orchestrator.py # Coordinates all modules, generates reports +├── comparators/ +│ ├── api_comparator.py # 30+ API endpoint tests +│ ├── db_comparator.py # Schema, data quality, embeddings, integrity +│ ├── docker_comparator.py # Container health, logs, resource usage +│ └── performance_comparator.py # Latency, throughput, DB query benchmarks +├── test_runner/ +│ └── existing_tests.py # Discovers & runs 27 existing test files +├── reports/ +│ ├── html_report.py # Self-contained HTML report generator +│ └── output/ # Generated reports (gitignored) +├── comparison_config.example.yaml +└── requirements.txt +``` + +--- + +## Prerequisites + +1. **Python 3.10+** on the machine running the suite +2. **Both AudioMuse-AI instances running** (Flask app + Worker + PostgreSQL + Redis) +3. **Network access** from the test runner to both instances (API ports, DB ports) + +Install dependencies: + +```bash +pip install -r testing_suite/requirements.txt +``` + +The suite requires: `requests`, `psycopg2-binary`, `pyyaml`, `pytest`, `pytest-json-report`, `pytest-timeout`. + +> **Note:** Docker comparison features require the `docker` CLI accessible from the test runner (either locally or via SSH to the remote hosts). + +--- + +## Quick Start + +### Minimal (API-only comparison) + +```bash +python -m testing_suite \ + --url-a http://192.168.1.100:8000 \ + --url-b http://192.168.1.101:8000 +``` + +This tests all API endpoints and runs existing unit tests. Database and Docker tests will be skipped if hosts aren't specified. 
+ +### Full comparison + +```bash +python -m testing_suite \ + --url-a http://192.168.1.100:8000 \ + --url-b http://192.168.1.101:8000 \ + --pg-host-a 192.168.1.100 \ + --pg-host-b 192.168.1.101 \ + --flask-container-a audiomuse-main-flask \ + --flask-container-b audiomuse-feature-flask \ + --name-a main --branch-a main \ + --name-b feature --branch-b fix/my-feature +``` + +### From config file + +```bash +cp testing_suite/comparison_config.example.yaml my_config.yaml +# Edit my_config.yaml with your instance details +python -m testing_suite --config my_config.yaml +``` + +--- + +## Configuration + +There are three ways to configure the suite (in order of precedence): + +### CLI Arguments + +All settings can be passed as command-line flags. Each instance has a matching set of flags suffixed with `-a` or `-b`: + +``` +Instance Connection: + --url-a / --url-b API base URL (e.g., http://host:8000) + --name-a / --name-b Display name (default: main / feature) + --branch-a / --branch-b Git branch name for reporting + +PostgreSQL: + --pg-host-a / --pg-host-b Database host + --pg-port-a / --pg-port-b Database port (default: 5432) + --pg-user-a / --pg-user-b Database user (default: audiomuse) + --pg-pass-a / --pg-pass-b Database password (default: audiomusepassword) + --pg-db-a / --pg-db-b Database name (default: audiomusedb) + +Redis: + --redis-a / --redis-b Redis URL (default: redis://localhost:6379/0) + +Docker: + --flask-container-a / -b Flask app container name + --worker-container-a / -b RQ worker container name + --ssh-host-a / -b SSH host for remote Docker access + --ssh-user-a / -b SSH username + --ssh-key-a / -b SSH private key path + +Test Control: + --only CATEGORIES Only run listed categories (comma-separated) + --skip CATEGORIES Skip listed categories + --warmup N Warmup requests before benchmarking (default: 3) + --bench-requests N Benchmark iterations per endpoint (default: 10) + --concurrent N Concurrent users for load test (default: 5) + +Output: + 
--output-dir PATH Report output directory + --format {html,json,both} Report format (default: both) + -v / --verbose Debug-level logging +``` + +### YAML Config File + +Copy and edit the example: + +```bash +cp testing_suite/comparison_config.example.yaml comparison_config.yaml +``` + +The YAML file supports all the same settings. See `comparison_config.example.yaml` for the full annotated template with all available options including quality thresholds, test track references, and performance parameters. + +### Environment Variables + +Every setting can be set via environment variables with the `INSTANCE_A_` / `INSTANCE_B_` prefix: + +```bash +export INSTANCE_A_API_URL=http://192.168.1.100:8000 +export INSTANCE_A_PG_HOST=192.168.1.100 +export INSTANCE_B_API_URL=http://192.168.1.101:8000 +export INSTANCE_B_PG_HOST=192.168.1.101 +export COMPARISON_VERBOSE=true +export COMPARISON_OUTPUT_DIR=./reports + +python -m testing_suite +``` + +--- + +## Test Categories + +### 1. API Comparison (30+ endpoints) + +Tests every AudioMuse-AI API endpoint on both instances and compares: + +| What's tested | Details | +|---------------|---------| +| **Status codes** | Both instances return the same HTTP status | +| **Response shape** | Same JSON keys, same structure | +| **List lengths** | Playlists, search results, etc. have comparable sizes | +| **Required fields** | Track objects have `item_id`, `title`, etc. 
| +| **Error handling** | Both handle invalid inputs the same way | +| **Functional tests** | Track search, similarity, alchemy, path finding with real data | + +**Endpoints covered:** +- `/api/config`, `/api/playlists`, `/api/active_tasks`, `/api/last_task` +- `/api/search_tracks`, `/api/similar_tracks`, `/api/max_distance` +- `/api/map`, `/api/map_cache_status` +- `/api/clap/stats`, `/api/clap/warmup/status`, `/api/clap/top_queries` +- `/api/alchemy`, `/api/find_path`, `/api/sonic_fingerprint/generate` +- `/api/artist_projections`, `/api/search_artists` +- `/api/setup/status`, `/api/setup/providers`, `/api/setup/settings`, `/api/setup/server-info` +- `/api/setup/providers/types`, `/api/providers/enabled` +- `/api/cron`, `/api/waveform`, `/api/collection/last_task` +- `/external/search`, `/chat/api/config_defaults` +- Error cases: nonexistent tasks, missing parameters + +### 2. Database Comparison (17 tables) + +Connects directly to both PostgreSQL instances and validates: + +| Category | Tests | +|----------|-------| +| **Schema** | All 17 expected tables exist with correct columns | +| **Row counts** | Compared with configurable tolerance (default 5%) | +| **Data quality** | NULL rates in critical score columns (item_id, title, author, tempo, key, scale, mood_vector) | +| **Duplicates** | No duplicate item_ids in score table | +| **Mood vector format** | Validates mood_vector string format | +| **Embedding integrity** | Coverage (% of scores with embeddings), NULL checks, dimension consistency | +| **Referential integrity** | No orphaned rows in embedding→score, provider_track→provider | +| **Score distributions** | Statistical comparison of tempo, energy (min/max/avg/stddev) | +| **Key distribution** | Musical key value comparison between instances | +| **Playlist quality** | Distinct count, avg tracks per playlist, NULL item_ids | +| **Index data** | Voyager HNSW, Artist GMM, Map projection, Artist projection presence | +| **Task health** | Failed task count, 
stuck tasks (>2hr), success rate comparison | +| **Provider config** | Same provider types and settings | +| **App settings** | Same configuration keys | + +### 3. Docker & Infrastructure + +Analyzes container health and logs via the Docker CLI (local or SSH): + +| Category | Tests | +|----------|-------| +| **Container status** | Running/stopped for Flask, Worker, PostgreSQL, Redis | +| **Restart counts** | Flags high restart counts (>5 = FAIL) | +| **Health checks** | Docker health check status comparison | +| **Memory usage** | MB comparison with % difference threshold | +| **CPU usage** | Percentage comparison | +| **Error patterns** | 11 patterns: tracebacks, OOM, connection errors, timeouts, permission, disk, crashes, worker deaths, DB errors, Redis errors | +| **Warning patterns** | 5 patterns: deprecation, warnings, retries, slow ops, memory pressure | +| **Python tracebacks** | Exact count comparison (>10 = FAIL) | +| **Redis connectivity** | Ping test from inside Flask container | +| **PostgreSQL connectivity** | SELECT 1 test from inside Flask container | + +### 4. Performance Benchmarks + +Measures and compares response times with statistical rigor: + +| Category | Details | +|----------|---------| +| **Endpoint latency** | 16 endpoints benchmarked with warmup phase, measuring p50/p95/p99/mean/max/stddev | +| **Concurrent load** | Configurable concurrent users hitting key endpoints simultaneously, measuring throughput (req/s) | +| **DB query performance** | 8 critical queries benchmarked: counts, joins, aggregations, group-bys | + +**Thresholds:** +- **PASS**: Instance B within 20% of A (or faster) +- **WARN**: Instance B 20-100% slower +- **FAIL**: Instance B >2x slower + +### 5. 
Existing Test Suite (27 tests) + +Discovers and runs all tests already in the codebase: + +| Category | Files | Tests | +|----------|-------|-------| +| **Unit tests** | 17 files in `tests/unit/` | test_ai, test_analysis, test_app_analysis, test_artist_gmm_manager, test_clap_text_search, test_clustering, test_clustering_helper, test_clustering_postprocessing, test_commons, test_mediaserver, test_memory_cleanup, test_memory_utils, test_path_manager, test_song_alchemy, test_sonic_fingerprint_manager, test_string_sanitization, test_voyager_manager | +| **Integration tests** | 2 files in `test/` | test_analysis_integration, test_clap_analysis_integration | +| **E2E API tests** | 8 tests in `test/test.py` | analysis smoke, instant playlist, sonic fingerprint, song alchemy, map visualization, similarity, song path, clustering smoke | + +Unit tests run once (they mock dependencies). E2E tests run against both instances with the `BASE_URL` pointed at each. + +--- + +## Deployment Scenarios + +### Scenario 1: Same Machine, Different Ports + +Two Docker Compose stacks running on ports 8000 and 8001: + +```bash +python -m testing_suite \ + --url-a http://localhost:8000 --url-b http://localhost:8001 \ + --pg-host-a localhost --pg-port-a 5432 \ + --pg-host-b localhost --pg-port-b 5433 \ + --flask-container-a audiomuse-main-flask \ + --flask-container-b audiomuse-feature-flask \ + --worker-container-a audiomuse-main-worker \ + --worker-container-b audiomuse-feature-worker +``` + +### Scenario 2: Two Remote Machines (via SSH) + +Instance A on server1, Instance B on server2: + +```bash +python -m testing_suite \ + --url-a http://server1:8000 --url-b http://server2:8000 \ + --pg-host-a server1 --pg-host-b server2 \ + --ssh-host-a server1 --ssh-user-a deploy --ssh-key-a ~/.ssh/id_rsa \ + --ssh-host-b server2 --ssh-user-b deploy --ssh-key-b ~/.ssh/id_rsa +``` + +The suite will SSH into each server to run `docker inspect`, `docker logs`, and `docker stats`. 
+ +### Scenario 3: API-Only (No DB or Docker) + +If you only have HTTP access to both instances: + +```bash +python -m testing_suite \ + --url-a http://main.example.com --url-b http://feature.example.com \ + --only api,performance +``` + +--- + +## Reports + +Every run produces two reports in the output directory (`testing_suite/reports/output/` by default): + +### HTML Report + +A self-contained, dark-themed HTML file with: +- Overall pass/fail status badge +- Summary cards (total, passed, failed, errors, warnings) +- Per-category expandable sections +- Filterable result tables (filter by Pass/Fail/Warn/Error/Skip) +- Side-by-side Instance A vs Instance B values +- Visual performance bar charts comparing latency + +Open in any browser: `testing_suite/reports/output/comparison_latest.html` + +### JSON Report + +Machine-readable format with full test details: + +```json +{ + "timestamp": "2025-01-15T10:30:00.000000", + "instance_a": {"name": "main", "branch": "main"}, + "instance_b": {"name": "feature", "branch": "feature"}, + "overall_status": "PASS", + "summary": {"total": 150, "passed": 142, "failed": 3, "errors": 0}, + "categories": { + "api": {"total": 50, "passed": 48, "failed": 2, ...}, + "database": {"total": 40, "passed": 38, ...}, + ... + } +} +``` + +--- + +## Selective Testing + +### Run only specific categories + +```bash +# API and database only +python -m testing_suite --url-a ... --url-b ... --only api,db + +# Performance only +python -m testing_suite --url-a ... --url-b ... --only performance + +# Existing unit tests only +python -m testing_suite --url-a ... --url-b ... --only unit +``` + +**Category names:** `api`, `db` (or `database`), `docker`, `performance` (or `perf`), `existing_tests`, `unit`, `integration` + +### Skip specific categories + +```bash +# Skip Docker and existing tests (faster) +python -m testing_suite --url-a ... --url-b ... --skip docker,existing_tests + +# Skip performance benchmarks +python -m testing_suite --url-a ... 
--url-b ... --skip perf +``` + +### Tune performance test parameters + +```bash +# Light benchmarking (fast) +python -m testing_suite --url-a ... --url-b ... --warmup 1 --bench-requests 3 --concurrent 2 + +# Heavy benchmarking (thorough) +python -m testing_suite --url-a ... --url-b ... --warmup 10 --bench-requests 50 --concurrent 20 +``` + +--- + +## Test Discovery + +List all available tests without running anything: + +```bash +python -m testing_suite --discover +``` + +Output: + +``` +=== AudioMuse-AI Test Discovery === + +Unit Tests (17 files): + [OK] tests/unit/test_ai.py + [OK] tests/unit/test_analysis.py + ... + +Integration Tests (2 files): + [OK] test/test_analysis_integration.py + [OK] test/test_clap_analysis_integration.py + +E2E API Tests (8 tests): + [OK] test_analysis_smoke_flow (test/test.py) + ... + +Total: 27 test files/entries discovered. +``` + +--- + +## Interpreting Results + +### Status Codes + +| Status | Meaning | +|--------|---------| +| **PASS** | Both instances match or values are within acceptable thresholds | +| **FAIL** | Significant difference detected, or a quality check failed | +| **WARN** | Minor difference or non-critical issue detected | +| **SKIP** | Test could not run (missing table, unreachable endpoint, etc.) | +| **ERROR** | Test itself errored (connection failure, timeout, exception) | + +### Exit Codes + +The CLI returns: +- `0` — All tests passed (or only warnings/skips) +- `1` — One or more tests failed or errored + +This makes it suitable for CI/CD pipelines: + +```bash +python -m testing_suite --config config.yaml || echo "Comparison found regressions!" 
+```
+
+### Performance Comparison Logic
+
+- **B/A ratio ≤ 1.2** → PASS (B is within 20% of A)
+- **B/A ratio ≤ 2.0** → WARN (B is up to 2x slower)
+- **B/A ratio > 2.0** → FAIL (B is more than 2x slower)
+- If B is faster than A, that's always a PASS
+
+---
+
+## Troubleshooting
+
+### "Cannot connect to either database instance"
+
+- Verify PostgreSQL is accessible from the test runner machine
+- Check `--pg-host-a/b`, `--pg-port-a/b`, `--pg-user-a/b`, `--pg-pass-a/b`
+- Ensure `pg_hba.conf` allows connections from the test runner IP
+- Try: `psql -h <host> -p <port> -U <user> -d <dbname> -c "SELECT 1"`
+
+### "Neither instance is reachable"
+
+- Verify the API URLs are correct and the Flask servers are running
+- Check firewall rules allow traffic on port 8000
+- Try: `curl http://<host>:8000/api/config`
+
+### "Cannot inspect containers (Docker not available)"
+
+- Docker CLI must be installed on the test runner (or accessible via SSH)
+- Container names must match what's running (`docker ps --format '{{.Names}}'`)
+- For remote Docker access, SSH must be configured: `--ssh-host-a/b`, `--ssh-user-a/b`
+
+### "pytest-json-report not found"
+
+```bash
+pip install pytest-json-report
+```
+
+Or install all dependencies:
+
+```bash
+pip install -r testing_suite/requirements.txt
+```
+
+### Customizing test tracks
+
+The API functional tests (search, similarity, alchemy, path) use reference tracks.
Set them to tracks that exist in your library: + +```bash +python -m testing_suite --config my_config.yaml +``` + +In your YAML config: + +```yaml +test_track_artist_1: "Artist In Your Library" +test_track_title_1: "Song Title" +test_track_artist_2: "Another Artist" +test_track_title_2: "Another Song" +``` From 1446375a40a86c6bfb967bc61d9336a4f3b87f1a Mon Sep 17 00:00:00 2001 From: Rendy Date: Thu, 5 Feb 2026 19:47:09 +0100 Subject: [PATCH 17/33] bug fixes and hardware selection on setup --- app_setup.py | 40 ++++++++ tasks/analysis.py | 34 +++++++ tasks/mediaserver.py | 2 +- templates/setup.html | 220 +++++++++++++++++++++++++++++++++++++------ 4 files changed, 267 insertions(+), 29 deletions(-) diff --git a/app_setup.py b/app_setup.py index 813ed3a0..263cd5e6 100644 --- a/app_setup.py +++ b/app_setup.py @@ -71,6 +71,8 @@ def get_all_settings(): settings = {} for row in rows: key, value, category, description = row + # Handle None category - use 'general' as default + category = category or 'general' if category not in settings: settings[category] = {} settings[category][key] = { @@ -483,6 +485,16 @@ def create_provider(): return jsonify({'error': 'Validation failed', 'details': validation_errors}), 400 try: + # Check if provider of this type already exists - upsert to prevent duplicates + existing_providers = get_all_providers() + existing = next((p for p in existing_providers if p['provider_type'] == provider_type), None) + + if existing: + # Update existing provider instead of creating duplicate + update_provider(existing['id'], name=name, config_data=config_data, enabled=enabled, priority=priority) + logger.info(f"Updated existing provider {existing['id']} ({provider_type}) instead of creating duplicate") + return jsonify({'id': existing['id'], 'message': 'Provider updated', 'was_update': True}), 200 + provider_id = add_provider(provider_type, name, config_data, enabled, priority) return jsonify({'id': provider_id, 'message': 'Provider created'}), 201 except 
Exception as e: @@ -836,6 +848,7 @@ def get_server_info(): """ import socket import os + import subprocess # Try to get the server's IP address try: @@ -856,6 +869,31 @@ def get_server_info(): except Exception: host_ip = 'localhost' + # Detect GPU availability + gpu_available = False + gpu_name = None + + # Method 1: Check if onnxruntime-gpu CUDA provider is available + try: + import onnxruntime as ort + providers = ort.get_available_providers() + if 'CUDAExecutionProvider' in providers: + gpu_available = True + except Exception: + pass + + # Method 2: Try nvidia-smi for GPU name (if available) + if gpu_available: + try: + result = subprocess.run( + ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader,nounits'], + capture_output=True, text=True, timeout=5 + ) + if result.returncode == 0 and result.stdout.strip(): + gpu_name = result.stdout.strip().split('\n')[0] # First GPU + except Exception: + pass + return jsonify({ 'host': host_ip, 'hostname': socket.gethostname() if hasattr(socket, 'gethostname') else 'unknown', @@ -863,6 +901,8 @@ def get_server_info(): 'postgres_port': os.environ.get('POSTGRES_PORT', '5432'), 'postgres_host': os.environ.get('POSTGRES_HOST', 'postgres'), 'redis_url': os.environ.get('REDIS_URL', 'redis://redis:6379/0'), + 'gpu_available': gpu_available, + 'gpu_name': gpu_name, }) diff --git a/tasks/analysis.py b/tasks/analysis.py index dfe4a737..5cfc7b93 100644 --- a/tasks/analysis.py +++ b/tasks/analysis.py @@ -1343,6 +1343,40 @@ def monitor_and_clear_jobs(): log_and_update_main(status_message, progress, checked_album_ids=list(checked_album_ids)) time.sleep(5) + # Wait for any album analysis jobs still running on the queue from a previous run. + # This handles the case where the main task resumes, finds all albums already checked, + # but their album tasks are still executing from the previous run. 
+ from rq import Queue + default_queue = Queue('default', connection=redis_conn) + wait_count = 0 + while True: + # Count album analysis jobs still running or queued + pending_album_jobs = 0 + for job in default_queue.jobs: + if hasattr(job, 'func_name') and 'analyze_album_task' in str(job.func_name): + pending_album_jobs += 1 + + # Also check started job registry for running jobs + started_registry = default_queue.started_job_registry + for job_id in started_registry.get_job_ids(): + try: + from rq.job import Job + job = Job.fetch(job_id, connection=redis_conn) + if hasattr(job, 'func_name') and 'analyze_album_task' in str(job.func_name): + pending_album_jobs += 1 + except Exception: + pass + + if pending_album_jobs == 0: + break + + wait_count += 1 + if wait_count == 1: + log_and_update_main(f"Waiting for {pending_album_jobs} album analysis job(s) from previous run to complete...", 90) + elif wait_count % 6 == 0: # Log every 30 seconds + log_and_update_main(f"Still waiting for {pending_album_jobs} album analysis job(s)...", 90) + time.sleep(5) + log_and_update_main("Performing final index rebuild...", 95) # Build Voyager index (song embeddings) build_and_store_voyager_index(get_db()) diff --git a/tasks/mediaserver.py b/tasks/mediaserver.py index 1c8e06a7..2ada2e63 100644 --- a/tasks/mediaserver.py +++ b/tasks/mediaserver.py @@ -1108,7 +1108,7 @@ def get_enabled_providers_for_playlists(): Returns: List of dicts with 'id', 'name', 'type' for each enabled provider """ - from app_helper import get_providers + from app_setup import get_providers providers = get_providers(enabled_only=True) return [ diff --git a/templates/setup.html b/templates/setup.html index 624b2711..99f07972 100644 --- a/templates/setup.html +++ b/templates/setup.html @@ -687,6 +687,71 @@ .worker-connection-info .config-field { margin-bottom: 1.25rem; } + + /* Hardware Info Display */ + .hardware-info-box { + display: flex; + align-items: flex-start; + gap: 1rem; + padding: 1.25rem; + background: 
var(--bg-card); + border-radius: 8px; + border: 2px solid var(--border-color); + } + + .hardware-info-box.gpu-detected { + border-color: #28a745; + background: rgba(40, 167, 69, 0.1); + } + + .hardware-info-box .hw-icon { + font-size: 2.5rem; + line-height: 1; + } + + .hardware-info-box .hw-details { + flex: 1; + } + + .hardware-info-box .hw-title { + font-weight: bold; + font-size: 1.1rem; + margin-bottom: 0.25rem; + } + + .hardware-info-box .hw-subtitle { + color: var(--text-muted); + font-size: 0.9rem; + margin-bottom: 0.75rem; + } + + .hardware-info-box .hw-description { + font-size: 0.9rem; + line-height: 1.5; + } + + .hardware-benefits { + margin-top: 1rem; + padding: 1rem; + background: rgba(37, 99, 235, 0.1); + border-radius: 8px; + border-left: 3px solid var(--color-primary); + } + + .hardware-benefits h4 { + margin: 0 0 0.5rem 0; + font-size: 0.95rem; + } + + .hardware-benefits ul { + margin: 0; + padding-left: 1.25rem; + font-size: 0.9rem; + } + + .hardware-benefits li { + margin-bottom: 0.25rem; + } {% endblock %} @@ -751,11 +816,21 @@

    Select Deployment Mode

    -
    -

    Select Hardware Configuration

    + +
    +

    Current Hardware Configuration

    +
    + +

    Detecting hardware...

    +
    +
    + + + @@ -727,6 +793,7 @@ {% block bodyAdditions %} + +