diff --git a/.gitignore b/.gitignore index a3815e45..a9108213 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,15 @@ env/ .pytest_cache/ htmlcov/ nul +testing_suite/reports/ + +# Testing suite configs (contain API keys / DB passwords) +testing_suite/instant_playlist_test_config.yaml +testing_suite/ai_naming_test_config.yaml +testing_suite/comparison_config.yaml + +# Deployment secrets +deployment/main.env # Large model files in query folder /query/*.pt diff --git a/README.md b/README.md index 164e5463..3fe913f4 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,13 @@ For the architecture design of AudioMuse-AI, take a look to the [ARCHITECTURE](d EMBY_TOKEN=your-api-token ``` + **For Local Files (No Media Server):** + ```env + MEDIASERVER_TYPE=localfiles + LOCALFILES_MUSIC_DIRECTORY=/path/to/your/music + LOCALFILES_PLAYLIST_DIR=/path/to/your/music/playlists + ``` + 3. **Start the services:** ```bash docker compose -f deployment/docker-compose.yaml up -d @@ -129,6 +136,53 @@ For the architecture design of AudioMuse-AI, take a look to the [ARCHITECTURE](d docker compose -f deployment/docker-compose.yaml down ``` +## Multi-Provider Support + +AudioMuse-AI supports connecting to multiple media servers simultaneously, allowing you to: +- Share analysis data across providers (analyze once, use everywhere) +- Create playlists on multiple servers at once +- Use a GUI Setup Wizard for easy configuration + +### GUI Setup Wizard + +Access the Setup Wizard at `http://localhost:8000/setup` to: +1. Add and configure multiple providers +2. Test connections before saving +3. Auto-detect music path prefixes for cross-provider matching +4. Set a primary provider for playlist creation + +### Local Files Provider + +The Local Files provider scans your music directory directly without requiring a media server: +- Supports MP3, FLAC, OGG, M4A, WAV, WMA, AAC, and OPUS formats +- Extracts metadata from embedded tags (ID3, Vorbis comments, etc.) 
+- Creates M3U playlists in a configurable directory +- Extracts ratings from POPM, TXXX:RATING, and Vorbis RATING tags + +**Configuration:** +```env +MEDIASERVER_TYPE=localfiles +LOCALFILES_MUSIC_DIRECTORY=/music # Path to your music library +LOCALFILES_PLAYLIST_DIR=/music/playlists # Where to save generated playlists +LOCALFILES_FORMATS=.mp3,.flac,.ogg,.m4a,.wav # Supported audio formats +LOCALFILES_SCAN_SUBDIRS=true # Scan subdirectories +``` + +### Cross-Provider Track Matching + +When using multiple providers, tracks are matched across servers using normalized file paths. This allows: +- Analysis data to be reused across providers +- Playlists to be created on any provider using tracks from another +- Automatic ID remapping when creating cross-provider playlists + +### Extended Metadata Fields + +AudioMuse-AI now stores additional metadata for each track: +- **album_artist**: The album artist (useful for compilations) +- **year**: Release year extracted from various tag formats +- **rating**: User rating on 0-5 scale (from tags or media server) +- **file_path**: Normalized file path for cross-provider linking + > NOTE: by default AudioMuse-AI is deployed WITHOUT authentication layer and its suited only for LOCAL deployment. If you want to configure it have a look to the [AUTHENTICATION](docs/AUTH.md) docs. If you enable the Authentication Layer, you need to be sure that any plugin used support and use the AudioMuse-AI API TOKEN ## **Hardware Requirements** diff --git a/TEST_CHECKLIST.md b/TEST_CHECKLIST.md new file mode 100644 index 00000000..08fc9d72 --- /dev/null +++ b/TEST_CHECKLIST.md @@ -0,0 +1,995 @@ +# AudioMuse-AI v0.9.0 - Comprehensive Test Checklist + +## Branch: `multi-provider-v2` vs `main` + +**Scope**: 96 changed files, +22,091/-1,489 lines, 47 commits across 7 feature areas. + +--- + +## Table of Contents + +1. [How to Use the Test Suite](#1-how-to-use-the-test-suite) +2. 
[Automated vs Manual Testing Summary](#2-automated-vs-manual-testing-summary) +3. [Multi-Provider Architecture](#3-multi-provider-architecture) +4. [GUI Setup Wizard](#4-gui-setup-wizard) +5. [Environment / Config Setup](#5-environment--config-setup) +6. [API Endpoints](#6-api-endpoints) +7. [App Interactions (UI/UX)](#7-app-interactions-uiux) +8. [Instant Playlist & AI Changes](#8-instant-playlist--ai-changes) +9. [MCP Tools](#9-mcp-tools) +10. [Provider-Specific Testing](#10-provider-specific-testing) +11. [Database & Schema Changes](#11-database--schema-changes) +12. [Dark Mode](#12-dark-mode) +13. [Analysis Pipeline](#13-analysis-pipeline) +14. [Playlist Ordering](#14-playlist-ordering) +15. [Deployment & Docker](#15-deployment--docker) +16. [Regression Tests](#16-regression-tests) +17. [Security](#17-security) + +--- + +## 1. How to Use the Test Suite + +### Test Directory Structure + +``` +AudioMuse-AI/ +├── tests/ +│ ├── conftest.py # Shared fixtures (importlib bypass, DB mocks, config restore) +│ └── unit/ # Unit tests (no external services needed) +│ ├── test_analysis.py # Audio analysis (50+ tests) +│ ├── test_ai.py # AI provider routing (30+ tests) +│ ├── test_ai_mcp_client.py # NEW - AI MCP client (60+ tests) +│ ├── test_clustering.py # Clustering helpers (60+ tests) +│ ├── test_clustering_helper.py +│ ├── test_clustering_postprocessing.py +│ ├── test_mediaserver.py # Jellyfin provider (15+ tests) +│ ├── test_voyager_manager.py # Similarity search (20+ tests) +│ ├── test_commons.py # Score vectors (10+ tests) +│ ├── test_app_analysis.py +│ ├── test_clap_text_search.py +│ ├── test_artist_gmm_manager.py +│ ├── test_memory_cleanup.py +│ ├── test_memory_utils.py +│ ├── test_path_manager.py +│ ├── test_song_alchemy.py +│ ├── test_sonic_fingerprint_manager.py +│ ├── test_string_sanitization.py +│ ├── test_mcp_server.py # NEW - MCP tools +│ ├── test_playlist_ordering.py # NEW - Playlist ordering +│ ├── test_app_setup.py # NEW - Setup wizard & providers +│ ├── 
test_app_chat.py # NEW - Instant playlist pipeline +│ └── test_mediaserver_localfiles.py # NEW - LocalFiles provider +├── test/ # Integration tests (require running services) +│ ├── test.py # End-to-end smoke tests +│ ├── test_analysis_integration.py +│ ├── test_clap_analysis_integration.py +│ ├── test_gpu_status.py +│ ├── verify_onnx_embeddings.py +│ └── provider_testing_stack/ # Multi-provider Docker test stack +│ ├── docker-compose-test-audiomuse.yaml +│ ├── docker-compose-test-providers.yaml +│ └── TEST_GUIDE.md +├── testing_suite/ # Comparison & benchmarking +│ ├── __main__.py # CLI entry point +│ ├── config.py # Suite configuration +│ ├── orchestrator.py # Test orchestration +│ ├── utils.py # Shared utilities +│ ├── test_instant_playlist.py # Instant playlist scenarios +│ ├── test_ai_naming.py # AI naming quality +│ ├── comparators/ # Cross-instance comparison +│ │ ├── api_comparator.py +│ │ ├── db_comparator.py +│ │ ├── docker_comparator.py +│ │ └── performance_comparator.py +│ ├── test_runner/ # Existing test runner +│ │ └── existing_tests.py +│ ├── run_comparison.py # Entry point +│ └── reports/html_report.py +└── pytest.ini # Test configuration +``` + +### Running Tests + +```bash +# ============================================================ +# UNIT TESTS (no external services, 2-5 minutes) +# ============================================================ + +# Run ALL unit tests +pytest tests/unit/ -v + +# Run specific test file +pytest tests/unit/test_mcp_server.py -v + +# Run specific test class +pytest tests/unit/test_mcp_server.py::TestSearchDatabase -v + +# Run specific test method +pytest tests/unit/test_mcp_server.py::TestSearchDatabase::test_genre_regex_prevents_substring_match -v + +# Skip slow tests +pytest tests/unit/ -v -m "not slow" + +# Run only new tests (for this branch) +pytest tests/unit/test_mcp_server.py tests/unit/test_playlist_ordering.py tests/unit/test_app_setup.py tests/unit/test_app_chat.py 
tests/unit/test_mediaserver_localfiles.py tests/unit/test_ai_mcp_client.py -v + +# ============================================================ +# INTEGRATION TESTS (require running services, 20+ minutes) +# ============================================================ + +# Requires: Flask server, PostgreSQL, Redis, ONNX models +pytest test/ -v -s --timeout=1200 + +# ============================================================ +# COMPARISON SUITE (require Docker + AI services) +# ============================================================ + +# Run full comparison between two instances +python testing_suite/run_comparison.py + +# Run instant playlist benchmarks +python testing_suite/test_instant_playlist.py --runs 5 + +# Run AI naming benchmarks +python testing_suite/test_ai_naming.py --runs 5 + +# ── Benchmark Configuration ────────────────────────────── +# Both benchmarks are self-contained — they do NOT use the +# main app's AI keys (Gemini, OpenAI, Mistral). Instead they +# route cloud models through OpenRouter. +# +# Config files (gitignored — copy from .example.yaml first): +# cp testing_suite/instant_playlist_test_config.example.yaml \ +# testing_suite/instant_playlist_test_config.yaml +# cp testing_suite/ai_naming_test_config.example.yaml \ +# testing_suite/ai_naming_test_config.yaml +# +# Provider setup (in each YAML under "defaults"): +# Ollama (local) – no API key needed, just run Ollama +# OpenRouter – set defaults.openrouter.api_key +# +# To disable a model, set enabled: false in the YAML. +# If no local config exists the scripts fall back to the +# .example.yaml automatically. 
+ +# ============================================================ +# MULTI-PROVIDER TEST STACK (Docker-based) +# ============================================================ + +# Start test providers (Jellyfin, Navidrome) +cd test/provider_testing_stack +docker compose -f docker-compose-test-providers.yaml up -d + +# Start AudioMuse test instance +docker compose -f docker-compose-test-audiomuse.yaml up -d + +# See TEST_GUIDE.md for detailed instructions +``` + +### Test Markers + +```bash +pytest -m unit -v # Unit tests only +pytest -m integration -v # Integration tests only +pytest -m "not slow" -v # Skip slow tests +``` + +### URL Prefix Note + +Some unit tests register blueprints **without** `url_prefix` for simplicity (e.g., `test_app_chat.py` tests `/api/config_defaults`). In production, these routes have prefixes: +- `chat_bp` → `/chat/...` (e.g., `/chat/api/config_defaults`) +- `external_bp` → `/external/...` (e.g., `/external/get_score`) + +The endpoint paths in this checklist reflect **production** URLs. + +### Test Dependencies + +```bash +# Unit tests +pip install pytest>=7.0.0 + +# Integration tests +pip install -r test/requirements.txt + +# Comparison suite +pip install -r testing_suite/requirements.txt +``` + +--- + +## 2. 
Automated vs Manual Testing Summary + +### Can Be Automated (Unit Tests) + +| Area | Tests | Status | +|------|-------|--------| +| MCP tool logic (genre regex, brainstorm matching, relevance scoring) | 40+ | **NEW** | +| AI MCP client (system prompt, tool defs, provider dispatch, energy conversion) | 60+ | **NEW** | +| Playlist ordering (greedy NN, Circle of Fifths, energy arc) | 25+ | **NEW** | +| Setup wizard (provider CRUD, settings, validation) | 30+ | **NEW** | +| Instant playlist pipeline (iteration loop, diversity, sampling) | 35+ | **NEW** | +| LocalFiles provider (hashing, metadata, M3U) | 25+ | **NEW** | +| Energy normalization (0-1 to raw conversion) | 10+ | **NEW** | +| Config validation (defaults, env parsing) | 10+ | **NEW** | +| Existing core tests (analysis, clustering, voyager, AI) | 200+ | Existing | + +### Can Be Automated (Integration Tests) + +| Area | Tests | Status | +|------|-------|--------| +| API endpoint responses (status codes, JSON shape) | 50+ | Partially exists | +| Provider connection testing | 5+ | Via test stack | +| Cross-provider ID remapping | 5+ | Via test stack | +| Database schema migration | 5+ | Via test stack | + +### Requires Manual Testing + +| Area | Why Manual | Steps | +|------|-----------|-------| +| Setup Wizard UI flow | Multi-step interactive wizard | See [Section 4](#4-gui-setup-wizard) | +| Dark mode visual correctness | Visual inspection of 18 templates | See [Section 12](#12-dark-mode) | +| Sidebar navigation | Interactive menu behavior | See [Section 7](#7-app-interactions-uiux) | +| Chart.js dark mode colors | Canvas-rendered, no DOM assertion | See [Section 12](#12-dark-mode) | +| Provider-specific playlist creation | Requires real media servers | See [Section 10](#10-provider-specific-testing) | +| AI quality assessment | Subjective playlist quality | See [Section 8](#8-instant-playlist--ai-changes) | +| Docker deployment | Full stack spin-up | See [Section 15](#15-deployment--docker) | +| Instant 
playlist UX | Streaming response, progress display | See [Section 8](#8-instant-playlist--ai-changes) | + +--- + +## 3. Multi-Provider Architecture + +### 3.1 Fresh Install (No Existing Data) + +| # | Test Case | Type | Steps | Expected | +|---|-----------|------|-------|----------| +| 3.1.1 | First-run redirect | Auto | Start app with empty DB | Redirects to `/setup` | +| 3.1.2 | Provider table creation | Auto | Check DB after `init_db()` | `provider` table exists with correct schema | +| 3.1.3 | Settings table creation | Auto | Check DB after `init_db()` | `app_settings` table exists | +| 3.1.4 | Add Jellyfin provider | Manual | Setup wizard: select Jellyfin, enter URL/token/user | Provider saved, connection test passes | +| 3.1.5 | Add Navidrome provider | Manual | Setup wizard: select Navidrome, enter URL/user/pass | Provider saved, connection test passes | +| 3.1.6 | Add Lyrion provider | Manual | Setup wizard: select Lyrion, enter URL | Provider saved, connection test passes | +| 3.1.7 | Add Emby provider | Manual | Setup wizard: select Emby, enter URL/token/user | Provider saved, connection test passes | +| 3.1.8 | Add LocalFiles provider | Manual | Setup wizard: select LocalFiles, enter music dir | Provider saved, directory scan succeeds | +| 3.1.9 | Multiple providers | Manual | Add 2+ providers of different types | All listed, all enabled | +| 3.1.10 | Provider priority ordering | Auto | Add providers with different priorities | Returned in priority order | +| 3.1.11 | Duplicate provider rejection | Auto | Add same type+name twice | Returns error, no duplicate | +| 3.1.12 | music_path_prefix detection | Manual | Add provider, click auto-detect | Correct prefix detected from sample tracks | + +### 3.2 Migration (Existing Single-Provider Data) + +| # | Test Case | Type | Steps | Expected | +|---|-----------|------|-------|----------| +| 3.2.1 | Existing env vars preserved | Auto | Start with existing `.env` (JELLYFIN_*) | Config values still work | +| 3.2.2 
| Setup wizard shows on upgrade | Manual | Upgrade from main, start app | Redirects to `/setup` once | +| 3.2.3 | Existing score data intact | Auto | Check `score` table after migration | All existing rows preserved | +| 3.2.4 | New columns added | Auto | Check `score` table schema | `album_artist`, `year`, `rating`, `file_path` columns exist | +| 3.2.5 | New columns nullable | Auto | Check existing rows | New columns are NULL for old data | +| 3.2.6 | Re-analysis populates new fields | Manual | Run analysis on existing library | New fields populated | +| 3.2.7 | Cross-provider file_path linking | Auto | Analyze same track via 2 providers | `find_existing_analysis_by_file_path()` finds match | +| 3.2.8 | Analysis reuse via file_path | Auto | Mock existing analysis, add new provider | `link_provider_to_existing_track()` links instead of re-analyzing | + +### 3.3 Provider CRUD API + +| # | Test Case | Type | Endpoint | Expected | +|---|-----------|------|----------|----------| +| 3.3.1 | List providers (empty) | Auto | `GET /api/setup/providers` | `[]` | +| 3.3.2 | Add provider | Auto | `POST /api/setup/providers` | 201, provider returned | +| 3.3.3 | Get provider by ID | Auto | `GET /api/setup/providers/` | Provider details returned | +| 3.3.4 | Update provider | Auto | `PUT /api/setup/providers/` | Updated fields reflected | +| 3.3.5 | Delete provider | Auto | `DELETE /api/setup/providers/` | 200, provider removed | +| 3.3.6 | Test connection (by ID) | Auto | `POST /api/setup/providers//test` | `{"success": true}` | +| 3.3.7 | Test connection (inline) | Auto | `POST /api/setup/providers/test` with config | `{"success": true}` or `{"success": false, "error": "..."}` | +| 3.3.8 | Get libraries | Manual | `POST /api/setup/providers/libraries` | Library list returned | +| 3.3.9 | Rescan paths | Manual | `POST /api/setup/providers//rescan-paths` | Track list with file paths | +| 3.3.10 | Get enabled providers | Auto | `GET /api/providers/enabled` | Only enabled 
providers | +| 3.3.11 | Invalid provider type | Auto | `POST /api/setup/providers` with bad type | 400 error | +| 3.3.12 | Missing required fields | Auto | `POST /api/setup/providers` incomplete | 400 error | +| 3.3.13 | Get provider types | Auto | `GET /api/setup/providers/types` | List of supported provider types | +| 3.3.14 | Multi-provider config | Auto | `POST /api/setup/multi-provider` | Multi-provider setup applied | +| 3.3.15 | Set primary provider | Auto | `PUT /api/setup/primary-provider` | Primary provider updated | +| 3.3.16 | Server info | Auto | `GET /api/setup/server-info` | Server configuration returned | +| 3.3.17 | Browse directories | Manual | `GET /api/setup/browse-directories` | Directory listing returned | +| 3.3.18 | Complete setup | Auto | `POST /api/setup/complete` | Setup marked as complete | + +### 3.4 Multi-Provider Playlist Creation + +| # | Test Case | Type | Steps | Expected | +|---|-----------|------|-------|----------| +| 3.4.1 | Single provider playlist | Auto | `create_playlist_from_ids(ids, provider_ids=1)` | Playlist on provider 1 only | +| 3.4.2 | All providers playlist | Auto | `create_playlist_from_ids(ids, provider_ids='all')` | Playlist on all enabled providers | +| 3.4.3 | Specific providers list | Auto | `create_playlist_from_ids(ids, provider_ids=[1,3])` | Playlist on providers 1 and 3 | +| 3.4.4 | Cross-provider ID remapping | Auto | Create playlist with IDs from provider A on provider B | file_path hash lookup maps IDs correctly | +| 3.4.5 | Unmapped track handling | Auto | Create playlist with track missing on target provider | Track skipped, warning logged | +| 3.4.6 | Provider selector UI | Manual | Open instant playlist, select target providers | Dropdown shows enabled providers | + +--- + +## 4. 
GUI Setup Wizard + +> **All Manual** - Interactive multi-step wizard UI + +| # | Test Case | Steps | Expected | +|---|-----------|-------|----------| +| 4.1 | Wizard loads on first run | Navigate to app URL | Setup wizard renders with welcome step | +| 4.2 | Step 1: Welcome | Read welcome text, click Next | Advances to provider selection | +| 4.3 | Step 2: Provider selection | Select provider type from dropdown | Configuration form appears for selected type | +| 4.4 | Step 2: Jellyfin config form | Select Jellyfin | Shows URL, User ID, Token fields | +| 4.5 | Step 2: Navidrome config form | Select Navidrome | Shows URL, Username, Password fields | +| 4.6 | Step 2: Lyrion config form | Select Lyrion | Shows URL field | +| 4.7 | Step 2: Emby config form | Select Emby | Shows URL, User ID, Token fields | +| 4.8 | Step 2: LocalFiles config form | Select LocalFiles | Shows Music Directory, Formats, Scan Subdirs fields | +| 4.9 | Step 3: Connection test success | Enter valid credentials, click Test | Green checkmark, "Connection successful" | +| 4.10 | Step 3: Connection test failure | Enter invalid credentials, click Test | Red X, error message displayed | +| 4.11 | Step 3: Library discovery | After successful test | Music libraries listed for selection | +| 4.12 | Step 3: Path prefix auto-detect | Click auto-detect button | Prefix field populated | +| 4.13 | Step 4: Add another provider | Click "Add Another Provider" | Returns to provider selection | +| 4.14 | Step 5: Complete setup | Click Complete | Redirects to main app, setup marked complete | +| 4.15 | Wizard skipped after setup | Return to app after completion | No redirect to `/setup` | +| 4.16 | Settings page access | Navigate to `/settings` | Settings page loads with current config | +| 4.17 | Settings: update AI provider | Change AI provider dropdown | Saves, applies to next instant playlist | +| 4.18 | Settings: update clustering | Change clustering algorithm | Saves, applies to next clustering run | +| 4.19 
| Settings: disable provider | Toggle provider off | Provider excluded from playlist creation | +| 4.20 | Settings: re-enable provider | Toggle provider back on | Provider included in playlist creation | +| 4.21 | Form validation | Submit empty required fields | Client-side validation error shown | +| 4.22 | XSS prevention | Enter ` + + + +{% endblock %} diff --git a/templates/setup.html b/templates/setup.html new file mode 100644 index 00000000..ac63a814 --- /dev/null +++ b/templates/setup.html @@ -0,0 +1,2265 @@ +{% extends "includes/layout.html" %} + +{% block headAdditions %} + +{% endblock %} + +{% block content %} +
+
+

AudioMuse-AI Setup

+

Configure your music analysis system

+
+ + +
+
+
1
+ Welcome +
+
+
2
+ Providers +
+
+
3
+ Settings +
+
+
4
+ Complete +
+
+ + +
+
+

Welcome to AudioMuse-AI

+

+ This setup wizard will help you configure AudioMuse-AI to analyze your music library + and create intelligent playlists. Let's get started! +

+ + +
+ +
+

Select Deployment Mode

+

+ Choose how you want to deploy AudioMuse-AI. +

+ +
+
+
📦
+
Unified
+
Server and worker on the same machine. Best for most users.
+
+
+
🖧
+
Split
+
Run workers on separate machines. For distributed setups or dedicated GPU servers.
+
+
+
+ + +
+

Current Hardware Configuration

+
+ +

Detecting hardware...

+
+
+ + + + +
+
+ +
+
+ + +
+
+

Configure Media Providers

+

+ Select one or more media providers. You can add multiple providers to analyze music + from different sources. Tracks are linked by file path, so the same song in different + providers will share analysis data. +

+ +
+ +
+
+ + + + + +
+ + +
+
+ + +
+
+

Analysis Settings

+

+ Configure optional analysis features. The defaults work well for most users. +

+ +
+ +
+ Allows searching your music using natural language queries like "upbeat summer songs". + Uses additional memory (~750MB). +
+
+ +
+ +
+ Accelerate clustering with NVIDIA RAPIDS cuML. Only available with NVIDIA GPU. +
+
+
+ + + +
+

AI Integration (Optional)

+

+ AudioMuse-AI can use AI to generate creative playlist names based on the songs in each cluster. + This is purely cosmetic - it doesn't affect the analysis or clustering, just the names given to playlists. +

+ +
+ + +
AI service for generating creative playlist names
+
+ +
+ +
+
+ +
+ + Advanced Settings +
+ +
+
+

Database Configuration

+

+ These settings are typically configured via environment variables in Docker. + Only change if you know what you're doing. +

+ +
+ + +
Database host (set via POSTGRES_HOST env var)
+
+ +
+ + +
Redis connection URL (set via REDIS_URL env var)
+
+
+
+ +
+ + +
+
+ + +
+
+

Setup Summary

+

+ Review your configuration before completing the setup. +

+ +
+ +
+
+ +
+

What's Next?

+
    +
  • Run the analysis to scan and analyze your music library
  • +
  • Generate intelligent playlists based on audio fingerprints
  • +
  • Explore your music with similarity search and visualization tools
  • +
+
+ +
+ + +
+
+
+ + + +{% endblock %} + +{% block bodyAdditions %} + + + +{% endblock %} diff --git a/templates/sidebar_navi.html b/templates/sidebar_navi.html index 8dfa842e..2b11fb4d 100644 --- a/templates/sidebar_navi.html +++ b/templates/sidebar_navi.html @@ -12,7 +12,9 @@
  • Cleaning
  • Scheduled Tasks
  • +
  • Settings
  • +
  • Setup Wizard
  • {% if auth_enabled %}
  • 🔓 Logout
  • -{% endif %} \ No newline at end of file +{% endif %} diff --git a/templates/similarity.html b/templates/similarity.html index e9b4e532..38585040 100644 --- a/templates/similarity.html +++ b/templates/similarity.html @@ -101,7 +101,7 @@

    Create a Playlist from Results

    - +
    @@ -111,6 +111,7 @@

    Create a Playlist from Results

    {% endblock %} {% block bodyAdditions %} + + +""" + + with open(output_path, 'w') as f: + f.write(html) + + return output_path + + +def _status_class(status) -> str: + """Get CSS class for a status value.""" + if isinstance(status, TestStatus): + status = status.value + return f"status-{status.lower()}" + + +def _badge_html(status) -> str: + """Generate a badge HTML element for a status.""" + if isinstance(status, TestStatus): + status = status.value + return f'{status}' + + +def _build_category_section(cat_name: str, cat_data: dict) -> str: + """Build HTML section for a test category.""" + total = cat_data.get("total", 0) + passed = cat_data.get("passed", 0) + failed = cat_data.get("failed", 0) + warned = cat_data.get("warned", 0) + skipped = cat_data.get("skipped", 0) + errors = cat_data.get("errors", 0) + results = cat_data.get("results", []) + + # Category display name + display_names = { + "api": "API Endpoints", + "database": "Database Quality", + "docker": "Docker & Infrastructure", + "performance": "Performance Benchmarks", + "existing_tests": "Existing Test Suite", + } + display_name = display_names.get(cat_name, cat_name.replace("_", " ").title()) + + # Build table rows + rows = "" + for r in results: + status = r.get("status", "SKIP") + duration = r.get("duration_seconds", 0) + duration_str = f"{duration:.2f}s" if duration else "-" + + # Format values for display + val_a = r.get("instance_a_value", "") + val_b = r.get("instance_b_value", "") + if isinstance(val_a, (dict, list)): + val_a = json.dumps(val_a, indent=1, default=str)[:200] + if isinstance(val_b, (dict, list)): + val_b = json.dumps(val_b, indent=1, default=str)[:200] + + rows += f""" + + {_badge_html(status)} + {r.get('name', '')} + {_escape_html(r.get('message', ''))} + {_escape_html(str(val_a)[:150])} + {_escape_html(str(val_b)[:150])} + {duration_str} + """ + + return f""" +
    +
    +
    + + {display_name} +
    +
    + {passed} passed + {failed} failed + {warned} warn + {skipped} skip + {errors} err + ({total} total) +
    +
    +
    +
    + Filter: + + + + + + +
    + + + + + + + + + + + + + {rows} + +
    StatusTest NameMessageInstance AInstance BDuration
    +
    +
    """ + + +def _build_performance_chart_data(perf_category: dict) -> list: + """Extract performance comparison data for visualization.""" + if not perf_category: + return [] + + chart_data = [] + for result in perf_category.get("results", []): + name = result.get("name", "") + if not name.startswith("Latency:"): + continue + + val_a = result.get("instance_a_value", {}) + val_b = result.get("instance_b_value", {}) + + if isinstance(val_a, dict) and isinstance(val_b, dict): + mean_a = val_a.get("mean", 0) + mean_b = val_b.get("mean", 0) + if mean_a > 0 or mean_b > 0: + chart_data.append({ + "name": name.replace("Latency: ", ""), + "mean_a": mean_a, + "mean_b": mean_b, + "p95_a": val_a.get("p95", 0), + "p95_b": val_b.get("p95", 0), + }) + + return chart_data + + +def _build_perf_visual(chart_data: list) -> str: + """Build a visual performance comparison section.""" + if not chart_data: + return "" + + max_val = max( + max(d["mean_a"], d["mean_b"], d["p95_a"], d["p95_b"]) + for d in chart_data + ) or 1 + + bars = "" + for d in chart_data: + width_a = max(2, int(d["mean_a"] / max_val * 400)) + width_b = max(2, int(d["mean_b"] / max_val * 400)) + + bars += f""" +
    +
    {d['name']}
    +
    +
    A
    +
    + {d['mean_a']*1000:.1f}ms +
    +
    +
    B
    +
    + {d['mean_b']*1000:.1f}ms +
    +
    """ + + return f""" +

    Performance Visual Comparison (Mean Latency)

    +
    + Instance A + Instance B +
    +
    + {bars} +
    """ + + +def _escape_html(text: str) -> str: + """Escape HTML special characters.""" + return (str(text) + .replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace('"', """) + .replace("'", "'")) diff --git a/testing_suite/reports/output/.gitignore b/testing_suite/reports/output/.gitignore new file mode 100644 index 00000000..58f3ef52 --- /dev/null +++ b/testing_suite/reports/output/.gitignore @@ -0,0 +1,3 @@ +# Report output files are generated at runtime - do not commit +* +!.gitignore diff --git a/testing_suite/requirements.txt b/testing_suite/requirements.txt new file mode 100644 index 00000000..4842b788 --- /dev/null +++ b/testing_suite/requirements.txt @@ -0,0 +1,8 @@ +# Requirements for the AudioMuse-AI Testing & Comparison Suite +requests>=2.28.0 +psycopg2-binary>=2.9.0 +pyyaml>=6.0 +pytest>=7.0.0 +pytest-json-report>=1.5.0 +pytest-timeout>=2.1.0 +ftfy diff --git a/testing_suite/run_comparison.py b/testing_suite/run_comparison.py new file mode 100644 index 00000000..b2b7f579 --- /dev/null +++ b/testing_suite/run_comparison.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python3 +""" +AudioMuse-AI Testing & Comparison Suite - CLI Entry Point + +Comprehensive tool to test all features, database quality, API results, +and performance between two AudioMuse-AI instances (e.g., main branch vs feature branch). 
"""AudioMuse-AI Testing & Comparison Suite - command line entry point.

Usage:
    # Quick comparison with two API URLs (minimal config):
    python -m testing_suite.run_comparison \
        --url-a http://main-instance:8000 \
        --url-b http://feature-instance:8000

    # Full comparison with database and Docker access:
    python -m testing_suite.run_comparison \
        --url-a http://main:8000 --url-b http://feature:8000 \
        --pg-host-a main-db-host --pg-host-b feature-db-host \
        --flask-container-a audiomuse-main-flask --flask-container-b audiomuse-feature-flask

    # From YAML config file:
    python -m testing_suite.run_comparison --config comparison_config.yaml

    # Only run specific test categories:
    python -m testing_suite.run_comparison \
        --url-a http://main:8000 --url-b http://feature:8000 \
        --only api,performance

    # Skip slow tests:
    python -m testing_suite.run_comparison \
        --url-a http://main:8000 --url-b http://feature:8000 \
        --skip docker,existing_tests

    # Discover available tests:
    python -m testing_suite.run_comparison --discover

    # Verbose output:
    python -m testing_suite.run_comparison \
        --url-a http://main:8000 --url-b http://feature:8000 -v
"""

import argparse
import logging
import os
import sys

# Make the repository root importable when this file is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# NOTE(review): the project imports (testing_suite.config, .orchestrator,
# .test_runner.existing_tests) are deferred into the functions that use them
# so that argument parsing / --help works without importing the whole package.


def _add_instance_group(parser: argparse.ArgumentParser, suffix: str, title: str,
                        name_default: str, branch_default: str,
                        example_url: str) -> None:
    """Add the per-instance option group (--url-X, --pg-host-X, ...).

    Both instances take an identical set of options, so the group is built
    once here instead of being spelled out twice.
    """
    grp = parser.add_argument_group(title)
    tag = f"instance {suffix.upper()}"
    grp.add_argument(f"--url-{suffix}", type=str, default="",
                     help=f"API URL for {tag} (e.g., {example_url})")
    grp.add_argument(f"--name-{suffix}", type=str, default=name_default,
                     help=f"Name for {tag} (default: {name_default})")
    grp.add_argument(f"--branch-{suffix}", type=str, default=branch_default,
                     help=f"Branch name for {tag}")
    grp.add_argument(f"--pg-host-{suffix}", type=str, default="",
                     help=f"PostgreSQL host for {tag}")
    grp.add_argument(f"--pg-port-{suffix}", type=int, default=5432,
                     help=f"PostgreSQL port for {tag}")
    grp.add_argument(f"--pg-user-{suffix}", type=str, default="audiomuse",
                     help=f"PostgreSQL user for {tag}")
    grp.add_argument(f"--pg-pass-{suffix}", type=str, default="audiomusepassword",
                     help=f"PostgreSQL password for {tag}")
    grp.add_argument(f"--pg-db-{suffix}", type=str, default="audiomusedb",
                     help=f"PostgreSQL database for {tag}")
    grp.add_argument(f"--redis-{suffix}", type=str, default="",
                     help=f"Redis URL for {tag}")
    grp.add_argument(f"--flask-container-{suffix}", type=str, default="audiomuse-ai-flask-app",
                     help=f"Docker flask container name for {suffix.upper()}")
    grp.add_argument(f"--worker-container-{suffix}", type=str, default="audiomuse-ai-worker-instance",
                     help=f"Docker worker container name for {suffix.upper()}")
    grp.add_argument(f"--ssh-host-{suffix}", type=str, default="",
                     help=f"SSH host for remote Docker access ({tag})")
    grp.add_argument(f"--ssh-user-{suffix}", type=str, default="",
                     help=f"SSH user for remote Docker access ({tag})")
    grp.add_argument(f"--ssh-key-{suffix}", type=str, default="",
                     help=f"SSH key file for remote Docker access ({tag})")


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: config source, two instance groups, test
    selection, performance settings, and output options."""
    parser = argparse.ArgumentParser(
        description="AudioMuse-AI Testing & Comparison Suite",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    # Config source
    parser.add_argument("--config", "-c", type=str, default="",
                        help="Path to YAML config file")

    # Discovery mode
    parser.add_argument("--discover", action="store_true",
                        help="Discover and list all available tests, then exit")

    _add_instance_group(parser, "a", "Instance A (main/baseline)",
                        name_default="main", branch_default="main",
                        example_url="http://localhost:8000")
    _add_instance_group(parser, "b", "Instance B (feature/test)",
                        name_default="feature", branch_default="feature",
                        example_url="http://localhost:8001")

    # Test selection
    grp_t = parser.add_argument_group("Test Selection")
    grp_t.add_argument("--only", type=str, default="",
                       help="Only run these categories (comma-separated: api,db,docker,performance,existing_tests)")
    grp_t.add_argument("--skip", type=str, default="",
                       help="Skip these categories (comma-separated)")
    grp_t.add_argument("--skip-setup-crud", action="store_true",
                       help="Skip setup wizard CRUD tests (provider create/update/delete)")
    grp_t.add_argument("--enable-task-starts", action="store_true",
                       help="Enable task start smoke tests (analysis, clustering, cleaning)")

    # Performance settings
    grp_p = parser.add_argument_group("Performance Settings")
    grp_p.add_argument("--warmup", type=int, default=3,
                       help="Warmup requests before benchmarking (default: 3)")
    grp_p.add_argument("--bench-requests", type=int, default=10,
                       help="Benchmark requests per endpoint (default: 10)")
    grp_p.add_argument("--concurrent", type=int, default=5,
                       help="Concurrent users for load test (default: 5)")

    # Output
    grp_o = parser.add_argument_group("Output")
    grp_o.add_argument("--output-dir", "-o", type=str, default="testing_suite/reports/output",
                       help="Output directory for reports")
    grp_o.add_argument("--format", type=str, default="both", choices=["html", "json", "both"],
                       help="Report format (default: both)")
    grp_o.add_argument("-v", "--verbose", action="store_true",
                       help="Verbose output")

    return parser


def _apply_instance_overrides(inst, args, suffix: str) -> None:
    """Copy non-empty CLI values for one instance onto its config object.

    Empty/zero CLI values are treated as "not given" so values loaded from a
    YAML/env config are preserved.  NOTE(review): options with non-empty
    defaults (--name-X, --pg-user-X, ...) always pass this truthiness test,
    so their CLI defaults override YAML values - preexisting behavior, kept.
    """
    def val(stem):
        return getattr(args, f"{stem}_{suffix}")

    if val("url"):
        inst.api_url = val("url")
    if val("name"):
        inst.name = val("name")
    if val("branch"):
        inst.branch = val("branch")
    if val("pg_host"):
        inst.pg_host = val("pg_host")
        inst.pg_port = val("pg_port")  # port only takes effect with a host
    if val("pg_user"):
        inst.pg_user = val("pg_user")
    if val("pg_pass"):
        inst.pg_password = val("pg_pass")
    if val("pg_db"):
        inst.pg_database = val("pg_db")
    if val("redis"):
        inst.redis_url = val("redis")
    if val("flask_container"):
        inst.docker_flask_container = val("flask_container")
    if val("worker_container"):
        inst.docker_worker_container = val("worker_container")
    if val("ssh_host"):
        inst.ssh_host = val("ssh_host")
    if val("ssh_user"):
        inst.ssh_user = val("ssh_user")
    if val("ssh_key"):
        inst.ssh_key = val("ssh_key")


def build_config(args) -> "ComparisonConfig":
    """Build a ComparisonConfig from YAML file or environment, then apply
    CLI overrides for both instances, performance, output, and selection."""
    from testing_suite.config import load_config_from_yaml, load_config_from_env

    config = load_config_from_yaml(args.config) if args.config else load_config_from_env()

    _apply_instance_overrides(config.instance_a, args, "a")
    _apply_instance_overrides(config.instance_b, args, "b")

    # Performance settings
    config.perf_warmup_requests = args.warmup
    config.perf_benchmark_requests = args.bench_requests
    config.perf_concurrent_users = args.concurrent

    # Output settings
    config.output_dir = args.output_dir
    config.report_format = args.format
    config.verbose = args.verbose

    # --only is an allow-list across every category ...
    if args.only:
        categories = set(args.only.split(","))
        config.run_api_tests = "api" in categories
        config.run_db_tests = "db" in categories or "database" in categories
        config.run_docker_tests = "docker" in categories
        config.run_performance_tests = "performance" in categories or "perf" in categories
        config.run_existing_unit_tests = "existing_tests" in categories or "unit" in categories
        config.run_existing_integration_tests = "existing_tests" in categories or "integration" in categories

    # ... while --skip only disables the named categories.
    if args.skip:
        skip = set(args.skip.split(","))
        if "api" in skip:
            config.run_api_tests = False
        if "db" in skip or "database" in skip:
            config.run_db_tests = False
        if "docker" in skip:
            config.run_docker_tests = False
        if "performance" in skip or "perf" in skip:
            config.run_performance_tests = False
        if "existing_tests" in skip or "unit" in skip:
            config.run_existing_unit_tests = False
        if "existing_tests" in skip or "integration" in skip:
            config.run_existing_integration_tests = False

    # Advanced test group flags
    if args.skip_setup_crud:
        config.run_setup_crud_tests = False
    if args.enable_task_starts:
        config.run_task_start_tests = True

    return config


def _run_discovery() -> int:
    """Print every discoverable unit/integration/e2e test; return exit code 0."""
    from testing_suite.test_runner.existing_tests import ExistingTestRunner

    discovery = ExistingTestRunner.discover_tests()
    print("\n=== AudioMuse-AI Test Discovery ===\n")

    print(f"Unit Tests ({len(discovery['unit_tests'])} files):")
    for t in discovery["unit_tests"]:
        status = "OK" if t["exists"] else "MISSING"
        print(f"  [{status}] {t['file']}")

    print(f"\nIntegration Tests ({len(discovery['integration_tests'])} files):")
    for t in discovery["integration_tests"]:
        status = "OK" if t["exists"] else "MISSING"
        print(f"  [{status}] {t['file']}")

    print(f"\nE2E API Tests ({len(discovery['e2e_tests'])} tests):")
    for t in discovery["e2e_tests"]:
        print(f"  [OK] {t['name']} ({t['file']})")

    total = (len(discovery["unit_tests"])
             + len(discovery["integration_tests"])
             + len(discovery["e2e_tests"]))
    print(f"\nTotal: {total} test files/entries discovered.\n")
    return 0


def main() -> int:
    """CLI entry point; returns a process exit code (0 = all tests passed)."""
    parser = build_parser()
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%H:%M:%S",
    )

    if args.discover:
        return _run_discovery()

    config = build_config(args)

    # BUG FIX: previously this error fired only when no --config file was
    # given, so a YAML config with missing api_url values silently ran the
    # comparison against empty URLs.  Require both URLs regardless of source.
    if not config.instance_a.api_url or not config.instance_b.api_url:
        parser.error("At least --url-a and --url-b are required "
                     "(or set both instance API URLs in the --config YAML)")

    from testing_suite.orchestrator import ComparisonOrchestrator

    orchestrator = ComparisonOrchestrator(config)
    report = orchestrator.run()
    orchestrator.print_summary()

    # Non-zero exit code when anything failed or errored.
    return 1 if (report.total_failed > 0 or report.total_errors > 0) else 0


if __name__ == "__main__":
    sys.exit(main())
#!/usr/bin/env python3
"""
AudioMuse-AI - AI Playlist Naming Performance Test

Compares how different AI models perform on the same playlist naming prompt.
Sends identical song lists to multiple Ollama + OpenRouter models, runs N times
each, and produces a comparison report (console, TXT, HTML, JSON).

Usage:
    python testing_suite/test_ai_naming.py
    python testing_suite/test_ai_naming.py --config path/to/config.yaml
    python testing_suite/test_ai_naming.py --runs 10
    python testing_suite/test_ai_naming.py --dry-run
"""

import argparse
import json
import os
import re
import subprocess
import sys
import time
import unicodedata
from datetime import datetime

import ftfy
import requests
import yaml


# ---------------------------------------------------------------------------
# Prompt template (inlined from ai.py:16-28)
# ---------------------------------------------------------------------------
CREATIVE_PROMPT_TEMPLATE = (
    "You are an expert music collector and MUST give a title to this playlist.\n"
    "The title MUST represent the mood and the activity of when you are listening to the playlist.\n"
    "The title MUST use ONLY standard ASCII (a-z, A-Z, 0-9, spaces, and - & ' ! . , ? ( ) [ ]).\n"
    "The title MUST be within the range of 5 to 40 characters long.\n"
    "No special fonts or emojis.\n"
    "* BAD EXAMPLES: 'Ambient Electronic Space - Electric Soundscapes - Emotional Waves' (Too long/descriptive)\n"
    "* BAD EXAMPLES: 'Blues Rock Fast Tracks' (Too direct/literal, not evocative enough)\n"
    "* BAD EXAMPLES: '\U0001d46f\U0001d4f0\U0001d4ea \U0001d4ea\U0001d4fb\U0001d4f8\U0001d4f7\U0001d4f2 \U0001d4ed\U0001d4ea\U0001d4fd\U0001d4fc' (Non-standard characters)\n\n"
    "CRITICAL: Your response MUST be ONLY the single playlist name. No explanations, no 'Playlist Name:', no numbering, no extra text or formatting whatsoever.\n\n"
    "This is the playlist:\n{song_list_sample}\n\n"
)

# Valid playlist-name length bounds (inclusive).
MIN_NAME_LENGTH = 5
MAX_NAME_LENGTH = 40


# ---------------------------------------------------------------------------
# Name cleaning (inlined from ai.py:30-43)
# ---------------------------------------------------------------------------
def clean_playlist_name(name: str) -> str:
    """Normalise an AI-generated name to the allowed character set.

    Fixes mojibake, NFKC-normalises, drops disallowed characters, removes a
    trailing " (N)" duplicate suffix, and collapses whitespace.
    """
    if not isinstance(name, str):
        return ""
    fixed = unicodedata.normalize('NFKC', ftfy.fix_text(name))
    out = re.sub(r'[^a-zA-Z0-9\s\-\&\'!\.\,\?\(\)\[\]]', '', fixed)
    out = re.sub(r'\s\(\d+\)$', '', out)
    return re.sub(r'\s+', ' ', out).strip()


# ---------------------------------------------------------------------------
# Song data fetching
# ---------------------------------------------------------------------------
def fetch_songs_from_db(container: str, user: str, database: str, total: int) -> list[dict]:
    """Fetch random songs from the PostgreSQL database via docker exec."""
    query = (
        f"SELECT title, author FROM score "
        f"WHERE title IS NOT NULL AND author IS NOT NULL "
        f"ORDER BY RANDOM() LIMIT {total}"
    )
    cmd = [
        "docker", "exec", container,
        "psql", "-U", user, "-d", database,
        "-t", "-A", "-F", "|",
        "-c", query,
    ]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
    except FileNotFoundError:
        print("ERROR: 'docker' command not found. Is Docker installed and in PATH?")
        sys.exit(1)
    except subprocess.TimeoutExpired:
        print("ERROR: Database query timed out after 30 seconds.")
        sys.exit(1)

    if proc.returncode != 0:
        print(f"ERROR: Database query failed.\n  Command: {' '.join(cmd)}\n  Stderr: {proc.stderr.strip()}")
        sys.exit(1)

    # psql -A -F'|' prints one "title|author" row per line; split on the
    # first pipe only so pipes inside the author survive.
    songs = []
    for row in proc.stdout.strip().split('\n'):
        row = row.strip()
        if not row:
            continue
        fields = row.split('|', 1)
        if len(fields) == 2:
            songs.append({"title": fields[0].strip(), "author": fields[1].strip()})

    if len(songs) < total:
        print(f"ERROR: Not enough songs in database. Found {len(songs)}, need {total}.")
        sys.exit(1)

    return songs


def apply_defaults(config: dict) -> None:
    """Merge provider defaults (url, api_key) into each model entry."""
    defaults = config.get("defaults", {})
    for model in config.get("models", []):
        provider_defaults = defaults.get(model.get("provider", ""), {})
        for key, value in provider_defaults.items():
            # Explicit per-model values win over provider defaults.
            model.setdefault(key, value)

    # Allow environment variable override for API keys
    env_api_key = os.environ.get('OPENROUTER_API_KEY')
    if env_api_key:
        for model in config.get("models", []):
            if model.get("provider") == "openrouter":
                model["api_key"] = env_api_key


def split_into_playlists(songs: list[dict], num_playlists: int, per_playlist: int) -> list[list[dict]]:
    """Split a flat song list into N playlists of M songs each."""
    return [
        songs[idx * per_playlist:(idx + 1) * per_playlist]
        for idx in range(num_playlists)
    ]


# ---------------------------------------------------------------------------
# Prompt building
# ---------------------------------------------------------------------------
def build_prompt(songs: list[dict], template: str | None = None) -> str:
    """Build the full prompt from a list of songs and an optional template."""
    song_lines = "\n".join(f"- {s['title']} by {s['author']}" for s in songs)
    return (template or CREATIVE_PROMPT_TEMPLATE).format(song_list_sample=song_lines)
# ---------------------------------------------------------------------------
# API calling (inlined from ai.py:47-183)
# ---------------------------------------------------------------------------
def _error_result(start_time: float, message: str) -> dict:
    """Uniform result dict for a failed model call."""
    return {
        "raw_response": "",
        "cleaned_name": "",
        "valid": False,
        "elapsed": time.time() - start_time,
        "error": message,
    }


def call_model(model_cfg: dict, prompt: str, timeout: int) -> dict:
    """Call one AI model with a streaming request and score the reply.

    Args:
        model_cfg: model entry from the YAML config (url, model_id, api_key).
        prompt: full prompt text to send.
        timeout: per-request timeout in seconds.

    Returns:
        dict with keys raw_response, cleaned_name, valid, elapsed, error.
    """
    url = model_cfg["url"]
    model_id = model_cfg["model_id"]
    api_key = model_cfg.get("api_key", "")

    # Anything with an API key or an OpenAI/OpenRouter-looking URL speaks the
    # OpenAI chat-completions protocol; everything else is treated as Ollama.
    is_openai_format = (
        bool(api_key) or
        "openai" in url.lower() or
        "openrouter" in url.lower()
    )

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    if "openrouter" in url.lower():
        headers["HTTP-Referer"] = "https://github.com/NeptuneHub/AudioMuse-AI"
        headers["X-Title"] = "AudioMuse-AI"

    if is_openai_format:
        payload = {
            "model": model_id,
            "messages": [{"role": "user", "content": prompt}],
            "stream": True,
            "temperature": 0.7,
            "max_tokens": 8000,
        }
    else:
        payload = {
            "model": model_id,
            "prompt": prompt,
            "stream": True,
            "options": {
                "num_predict": 8000,
                "temperature": 0.7,
            },
        }

    start_time = time.time()
    try:
        response = requests.post(
            url, headers=headers, data=json.dumps(payload),
            stream=True, timeout=timeout,
        )
        response.raise_for_status()

        full_raw = ""
        for line in response.iter_lines():
            if not line:
                continue
            line_str = line.decode('utf-8', errors='ignore').strip()

            # SSE comment / keep-alive lines start with ':'
            if line_str.startswith(':'):
                continue

            if line_str.startswith('data: '):
                line_str = line_str[6:]
                if line_str == '[DONE]':
                    break

            try:
                chunk = json.loads(line_str)
            except json.JSONDecodeError:
                continue

            if is_openai_format:
                if 'choices' in chunk and len(chunk['choices']) > 0:
                    choice = chunk['choices'][0]
                    finish_reason = choice.get('finish_reason')
                    if finish_reason in ('stop', 'length'):
                        # Grab any final content before breaking
                        if 'delta' in choice:
                            c = choice['delta'].get('content')
                            if c:
                                full_raw += c
                        break
                    if 'delta' in choice:
                        c = choice['delta'].get('content')
                        if c is not None:
                            full_raw += c
                    elif 'text' in choice:
                        t = choice.get('text')
                        if t is not None:
                            full_raw += t
            else:
                # Ollama JSON-lines format
                if 'response' in chunk:
                    full_raw += chunk['response']
                if chunk.get('done'):
                    break

        elapsed = time.time() - start_time

        # Strip reasoning preambles: keep only text after the closing tag.
        # BUG FIX: the tag list previously began with an empty string --
        # '"" in s' is always True and s.split('', 1) raises
        # ValueError('empty separator'), so every successful request was
        # reported as an error.  The first entry was evidently meant to be
        # the closing </think> tag emitted by reasoning models.
        extracted = full_raw.strip()
        for tag in ("</think>", "[/INST]", "[/THOUGHT]"):
            if tag in extracted:
                extracted = extracted.split(tag, 1)[-1].strip()

        if not extracted:
            return {
                "raw_response": full_raw,
                "cleaned_name": "",
                "valid": False,
                "elapsed": elapsed,
                "error": "Empty response after think-tag stripping",
            }

        cleaned = clean_playlist_name(extracted)
        valid = MIN_NAME_LENGTH <= len(cleaned) <= MAX_NAME_LENGTH

        return {
            "raw_response": extracted,
            "cleaned_name": cleaned,
            "valid": valid,
            "elapsed": elapsed,
            "error": None,
        }

    except requests.exceptions.ConnectionError:
        return _error_result(start_time, "Connection refused")
    except requests.exceptions.Timeout:
        return _error_result(start_time, f"Timeout after {timeout}s")
    except requests.exceptions.HTTPError as e:
        detail = ""
        try:
            detail = e.response.text[:200]
        except Exception:
            pass
        # getattr guards against e.response being None (no AttributeError).
        status = getattr(e.response, "status_code", "?")
        return _error_result(start_time, f"HTTP {status}: {detail}")
    except Exception as e:
        return _error_result(start_time, str(e))


# ---------------------------------------------------------------------------
# Report generation
# ---------------------------------------------------------------------------
def generate_summary_table(results: dict, timestamp: str) -> str:
    """Generate the ASCII summary table (one row per model: counts + timings)."""
    lines = []
    lines.append("=" * 70)
    lines.append(f" RESULTS - AI Playlist Naming Test ({timestamp})")
    lines.append("=" * 70)
    lines.append(f" {'Model':<22} {'Tests':>5} {'Valid':>5} {'Rate':>6} {'Avg':>6} {'Min':>6} {'Max':>6}")
    lines.append("-" * 70)

    for model_name, model_data in results.items():
        all_runs = model_data["runs"]
        total = len(all_runs)
        valid = sum(1 for r in all_runs if r["valid"])
        rate = (valid / total * 100) if total > 0 else 0
        # Timing stats only cover runs that completed without an error.
        times = [r["elapsed"] for r in all_runs if r["error"] is None]
        avg_t = sum(times) / len(times) if times else 0
        min_t = min(times) if times else 0
        max_t = max(times) if times else 0

        lines.append(
            f" {model_name:<22} {total:>5} {valid:>5} {rate:>5.1f}% {avg_t:>5.1f}s {min_t:>5.1f}s {max_t:>5.1f}s"
        )

    lines.append("-" * 70)
    return "\n".join(lines)


def generate_names_table(results: dict, playlists: list[list[dict]], num_runs: int) -> str:
    """Generate the per-playlist detail table: one row per model, one column
    per run, with error/invalid markers and names truncated to fit."""
    lines = []
    model_names = list(results.keys())
    num_playlists = len(playlists)

    for pi in range(num_playlists):
        lines.append(f"\nPlaylist {pi + 1} - Generated Names:")
        header = f" {'Model':<22}"
        for ri in range(num_runs):
            header += f" {'Run ' + str(ri + 1):<24}"
        lines.append(header)
        lines.append(" " + "-" * (22 + num_runs * 26))

        for model_name in model_names:
            row = f" {model_name:<22}"
            runs = results[model_name]["runs"]
            # Filter runs for this playlist
            playlist_runs = [r for r in runs if r["playlist_index"] == pi]
            for r in playlist_runs:
                name = r.get("cleaned_name", "")
                if r.get("error"):
                    name = f"[ERR: {r['error'][:15]}]"
                elif not r["valid"]:
                    name = f"[INVALID: {name[:12]}]"
                # Truncate to fit column
                if len(name) > 23:
                    name = name[:20] + "..."
                row += f" {name:<24}"
            lines.append(row)

    return "\n".join(lines)
      \n" + for s in playlists[pi]: + song_list_html += f"
    • {s['title']} — {s['author']}
    • \n" + song_list_html += "
    " + + detail_rows = "" + for model_name in model_names: + runs = [r for r in results[model_name]["runs"] if r["playlist_index"] == pi] + valid_count = sum(1 for r in runs if r["valid"]) + total_count = len(runs) + times = [r["elapsed"] for r in runs if r["error"] is None] + avg_t = sum(times) / len(times) if times else 0 + rate = (valid_count / total_count * 100) if total_count else 0 + rate_class = "pass" if rate >= 80 else ("warn" if rate >= 50 else "fail") + + # Build all names into a single cell + names_html = "" + for ri, r in enumerate(runs): + name = r.get("cleaned_name", "") + error = r.get("error") + raw = r.get("raw_response", "").replace("&", "&").replace("<", "<").replace(">", ">") + + if error: + names_html += f'
    Run {ri + 1}: {error} ({r["elapsed"]:.1f}s)
    \n' + elif r["valid"]: + raw_detail = f'
    raw
    {raw}
    ' if save_raw and raw else "" + names_html += f'
    Run {ri + 1}: {name} ({r["elapsed"]:.1f}s){raw_detail}
    \n' + else: + raw_detail = f'
    raw
    {raw}
    ' if save_raw and raw else "" + names_html += f'
    Run {ri + 1}: {name} ({len(name)} chars) ({r["elapsed"]:.1f}s){raw_detail}
    \n' + + detail_rows += f""" + {model_name}
    {results[model_name].get('provider', '')} + {valid_count}/{total_count} ({rate:.0f}%) + {avg_t:.2f}s + {names_html} + \n""" + + playlist_sections += f""" +
    +

    Playlist {pi + 1}

    +
    + Songs used: + {song_list_html} +
    + + + + + + {detail_rows} +
    ModelValidAvg TimeGenerated Names
    +
    + """ + + html = f""" + + + + +AI Playlist Naming Test - {timestamp} + + + +

    AI Playlist Naming Test

    +

    Date: {timestamp}  |  + Runs per model: {num_runs}  |  + Playlists: {num_playlists}  |  + Songs per playlist: {len(playlists[0]) if playlists else 0}

    + +

    Summary

    + + + + + + {summary_rows} +
    ModelProviderTestsValidErrorsValid RateAvg TimeMin TimeMax Time
    + +

    Prompt Used

    +
    + Show prompt template +
    {(config.get('prompt') or CREATIVE_PROMPT_TEMPLATE).replace('&', '&').replace('<', '<').replace('>', '>').replace('{song_list_sample}', '{song_list_sample}')}
    +
    + +

    Detailed Results

    +{playlist_sections} + +

    Test Configuration

    +
    {json.dumps(config, indent=2, default=str)}
    + + + +""" + return html + + +def generate_json_report(results: dict, playlists: list[list[dict]], + timestamp: str, config: dict) -> dict: + """Generate the full JSON report.""" + report = { + "timestamp": timestamp, + "config": config, + "playlists": [ + [{"title": s["title"], "author": s["author"]} for s in pl] + for pl in playlists + ], + "models": {}, + } + + for model_name, model_data in results.items(): + all_runs = model_data["runs"] + total = len(all_runs) + valid = sum(1 for r in all_runs if r["valid"]) + errors = sum(1 for r in all_runs if r["error"]) + times = [r["elapsed"] for r in all_runs if r["error"] is None] + + report["models"][model_name] = { + "provider": model_data.get("provider", ""), + "model_id": model_data.get("model_id", ""), + "url": model_data.get("url", ""), + "summary": { + "total_tests": total, + "valid": valid, + "invalid": total - valid - errors, + "errors": errors, + "valid_rate": round(valid / total * 100, 1) if total > 0 else 0, + "avg_time": round(sum(times) / len(times), 3) if times else 0, + "min_time": round(min(times), 3) if times else 0, + "max_time": round(max(times), 3) if times else 0, + }, + "runs": [ + { + "playlist_index": r["playlist_index"], + "run_index": r["run_index"], + "cleaned_name": r.get("cleaned_name", ""), + "raw_response": r.get("raw_response", ""), + "valid": r["valid"], + "elapsed": round(r["elapsed"], 3), + "error": r.get("error"), + "name_length": len(r.get("cleaned_name", "")), + } + for r in all_runs + ], + } + + return report + + +# --------------------------------------------------------------------------- +# Main test loop +# --------------------------------------------------------------------------- +def run_tests(config: dict, dry_run: bool = False) -> tuple[dict, list[list[dict]]]: + """ + Execute the full test suite. + + Returns: + (results_dict, playlists) + results_dict keys are model names, values have 'runs' list and metadata. 
# ---------------------------------------------------------------------------
# Main test loop
# ---------------------------------------------------------------------------
def run_tests(config: dict, dry_run: bool = False) -> tuple[dict, list[list[dict]]]:
    """
    Execute the full test suite.

    Returns:
        (results_dict, playlists)
        results_dict keys are model names, values have 'runs' list and metadata.
    """
    pg = config["postgres"]
    tc = config["test_config"]
    enabled_models = [m for m in config["models"] if m.get("enabled", False)]

    if not enabled_models:
        print("ERROR: No models enabled in configuration.")
        sys.exit(1)

    num_runs = tc["num_runs_per_model"]
    num_playlists = tc["num_playlists"]
    songs_per = tc["songs_per_playlist"]
    timeout = tc.get("timeout_per_request", 120)
    total_songs = num_playlists * songs_per

    # Use sample_songs from config if present, otherwise fetch from DB
    sample_songs = config.get("sample_songs")
    if sample_songs:
        # Single playlist when using hardcoded songs (no point repeating identical lists)
        num_playlists = 1
        playlists = [sample_songs[:songs_per]]
        print(f"Using {len(sample_songs)} songs from config (sample_songs, 1 playlist)...\n")
    else:
        print(f"Fetching {total_songs} songs from database ({pg['container_name']})...")
        fetched = fetch_songs_from_db(
            pg["container_name"], pg["user"], pg["database"], total_songs,
        )
        playlists = split_into_playlists(fetched, num_playlists, songs_per)
        print(f"  OK - {len(fetched)} songs split into {num_playlists} playlists of {songs_per}\n")

    # One prompt per playlist, identical for every model; uses the config's
    # prompt when given, the hardcoded default otherwise.
    prompts = [build_prompt(pl, config.get("prompt")) for pl in playlists]

    if dry_run:
        print("=== DRY RUN MODE ===")
        print(f"Would test {len(enabled_models)} model(s), {num_playlists} playlist(s), {num_runs} run(s) each\n")
        for mi, m in enumerate(enabled_models):
            print(f"  Model {mi + 1}: {m['name']} ({m['provider']}) - {m['model_id']}")
        print(f"\nPlaylist 1 prompt preview (first 500 chars):")
        print(prompts[0][:500])
        print("...")
        return {}, playlists

    def skipped_run(pi: int, ri: int) -> dict:
        # Placeholder recorded when the model's endpoint is known unreachable.
        return {
            "playlist_index": pi,
            "run_index": ri,
            "raw_response": "",
            "cleaned_name": "",
            "valid": False,
            "elapsed": 0,
            "error": "Skipped (connection failed)",
        }

    results = {}
    total_models = len(enabled_models)
    connection_failures = set()

    for mi, model in enumerate(enabled_models):
        model_name = model["name"]
        print(f"[{mi + 1}/{total_models}] Testing: {model_name} ({model['provider']})")

        entry = {
            "provider": model["provider"],
            "model_id": model["model_id"],
            "url": model["url"],
            "runs": [],
        }
        results[model_name] = entry

        # Skip outright if a previous model on the same URL failed to connect.
        if model["url"] in connection_failures:
            print(f"  Skipping (connection to {model['url']} already failed)\n")
            for pi in range(num_playlists):
                for ri in range(num_runs):
                    entry["runs"].append(skipped_run(pi, ri))
            continue

        model_valid = 0
        model_total = 0
        model_times = []
        abort_model = False

        for pi in range(num_playlists):
            for ri in range(num_runs):
                if abort_model:
                    entry["runs"].append(skipped_run(pi, ri))
                    continue

                model_total += 1
                status_prefix = f"  Playlist {pi + 1}: Run {ri + 1}/{num_runs}..."

                result = call_model(model, prompts[pi], timeout)
                result["playlist_index"] = pi
                result["run_index"] = ri
                entry["runs"].append(result)

                if result["error"] == "Connection refused":
                    print(f"{status_prefix} FAIL (connection refused)")
                    connection_failures.add(model["url"])
                    abort_model = True
                    continue

                if result["error"]:
                    print(f"{status_prefix} ERR {result['elapsed']:.1f}s {result['error']}")
                elif result["valid"]:
                    model_valid += 1
                    model_times.append(result["elapsed"])
                    print(f"{status_prefix} OK {result['elapsed']:.1f}s \"{result['cleaned_name']}\"")
                else:
                    bad_name = result["cleaned_name"]
                    print(f"{status_prefix} INVALID {result['elapsed']:.1f}s \"{bad_name}\" ({len(bad_name)} chars)")
                    model_times.append(result["elapsed"])

        # Model summary
        if model_total > 0 and not abort_model:
            avg_t = sum(model_times) / len(model_times) if model_times else 0
            rate = model_valid / model_total * 100
            print(f"  Result: {model_valid}/{model_total} valid ({rate:.1f}%), avg {avg_t:.1f}s\n")
        elif abort_model:
            print(f"  Result: Aborted (connection failed)\n")

    return results, playlists


def save_reports(results: dict, playlists: list[list[dict]], config: dict,
                 num_runs: int, output_dir: str, save_raw: bool):
    """Save TXT, HTML, and JSON reports to disk."""
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
    file_ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Console / TXT summary
    report_text = (generate_summary_table(results, timestamp) + "\n"
                   + generate_names_table(results, playlists, num_runs) + "\n")
    print("\n" + report_text)

    txt_path = os.path.join(output_dir, f"ai_naming_{file_ts}.txt")
    with open(txt_path, "w", encoding="utf-8") as fh:
        fh.write(report_text)
    print(f"TXT report saved: {txt_path}")

    # HTML report
    html_path = os.path.join(output_dir, f"ai_naming_{file_ts}.html")
    with open(html_path, "w", encoding="utf-8") as fh:
        fh.write(generate_html_report(results, playlists, num_runs, timestamp, config, save_raw))
    print(f"HTML report saved: {html_path}")

    # JSON report
    json_path = os.path.join(output_dir, f"ai_naming_{file_ts}.json")
    with open(json_path, "w", encoding="utf-8") as fh:
        json.dump(generate_json_report(results, playlists, timestamp, config), fh,
                  indent=2, default=str)
    print(f"JSON report saved: {json_path}")


# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main():
    """Parse CLI options, load the YAML config, run the tests, save reports."""
    parser = argparse.ArgumentParser(
        description="AudioMuse-AI - AI Playlist Naming Performance Test",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # Prefer the user's real config; fall back to the tracked example file.
    _default_cfg = "testing_suite/ai_naming_test_config.yaml"
    if not os.path.exists(_default_cfg):
        _default_cfg = "testing_suite/ai_naming_test_config.example.yaml"
    parser.add_argument("--config", "-c", type=str,
                        default=_default_cfg,
                        help="Path to YAML config file (default: ai_naming_test_config.yaml)")
    parser.add_argument("--runs", "-n", type=int, default=None,
                        help="Override num_runs_per_model from config")
    parser.add_argument("--dry-run", action="store_true",
                        help="Fetch songs and build prompts, but don't call any APIs")

    args = parser.parse_args()

    # Load config
    if not os.path.exists(args.config):
        print(f"ERROR: Config file not found: {args.config}")
        print(f"Usage: python testing_suite/test_ai_naming.py --config path/to/config.yaml")
        sys.exit(1)

    with open(args.config, "r", encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    # Merge provider defaults (url, api_key) into each model entry
    apply_defaults(config)

    # Apply CLI overrides
    if args.runs is not None:
        config["test_config"]["num_runs_per_model"] = args.runs

    num_runs = config["test_config"]["num_runs_per_model"]
    output_cfg = config.get("output", {})
    output_dir = output_cfg.get("directory", "testing_suite/reports/ai_naming")
    save_raw = output_cfg.get("save_raw_responses", True)

    print("=" * 60)
    print(" AudioMuse-AI - AI Playlist Naming Performance Test")
    print("=" * 60)

    enabled = [m for m in config["models"] if m.get("enabled", False)]
    print(f" Models: {len(enabled)} enabled")
    print(f" Playlists: {config['test_config']['num_playlists']}")
    print(f" Runs/model: {num_runs}")
    print(f" Songs/playlist: {config['test_config']['songs_per_playlist']}")
    print("=" * 60 + "\n")

    # Run tests
    results, playlists = run_tests(config, dry_run=args.dry_run)

    if args.dry_run:
        print("\nDry run complete. No API calls were made.")
        return

    if not results:
        print("No results to report.")
        return

    # Generate and save reports
    save_reports(results, playlists, config, num_runs, output_dir, save_raw)


if __name__ == "__main__":
    main()
+ +Usage: + python testing_suite/test_instant_playlist.py + python testing_suite/test_instant_playlist.py --config path/to/config.yaml + python testing_suite/test_instant_playlist.py --runs 5 + python testing_suite/test_instant_playlist.py --dry-run +""" + +import argparse +import json +import os +import re +import sys +import time +from datetime import datetime + +import requests +import yaml + + +# --------------------------------------------------------------------------- +# Valid tool names (authoritative list) +# --------------------------------------------------------------------------- +VALID_TOOL_NAMES = [ + "song_similarity", + "text_search", + "artist_similarity", + "song_alchemy", + "ai_brainstorm", + "search_database", +] + +# search_database filter keys checked during pre-execution validation +SEARCH_DB_FILTER_KEYS = [ + "genres", "moods", "tempo_min", "tempo_max", "energy_min", "energy_max", + "key", "scale", "year_min", "year_max", "min_rating", +] + + +# --------------------------------------------------------------------------- +# Tool definitions (inlined from ai_mcp_client.py:674-904) +# --------------------------------------------------------------------------- +def get_tool_definitions(clap_enabled: bool) -> list[dict]: + """Return the 6 MCP tool definitions. Mirrors get_mcp_tools().""" + tools = [ + { + "name": "song_similarity", + "description": "PRIORITY #1: MOST SPECIFIC - Find songs similar to a specific song (requires exact title+artist). 
USE when user mentions a SPECIFIC SONG TITLE.", + "inputSchema": { + "type": "object", + "properties": { + "song_title": { + "type": "string", + "description": "Song title" + }, + "song_artist": { + "type": "string", + "description": "Artist name" + }, + "get_songs": { + "type": "integer", + "description": "Number of songs", + "default": 100 + } + }, + "required": ["song_title", "song_artist"] + } + } + ] + + if clap_enabled: + tools.append({ + "name": "text_search", + "description": "PRIORITY #2: HIGH PRIORITY - Natural language search using CLAP. USE for: INSTRUMENTS (piano, guitar, ukulele), SOUND DESCRIPTIONS (romantic, dreamy, chill vibes), DESCRIPTIVE QUERIES ('energetic workout'). Supports optional tempo/energy filters for hybrid search.", + "inputSchema": { + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "Natural language description (e.g., 'piano music', 'romantic pop', 'ukulele songs', 'energetic guitar rock')" + }, + "tempo_filter": { + "type": "string", + "enum": ["slow", "medium", "fast"], + "description": "Optional: Filter CLAP results by tempo (hybrid mode)" + }, + "energy_filter": { + "type": "string", + "enum": ["low", "medium", "high"], + "description": "Optional: Filter CLAP results by energy (hybrid mode)" + }, + "get_songs": { + "type": "integer", + "description": "Number of songs", + "default": 100 + } + }, + "required": ["description"] + } + }) + + p = '3' if clap_enabled else '2' + tools.append({ + "name": "artist_similarity", + "description": f"PRIORITY #{p}: Find songs BY an artist AND similar artists. USE for: 'songs by/from/like Artist X' including the artist's own songs (call once per artist). 
DON'T USE for: 'sounds LIKE multiple artists blended' (use song_alchemy).", + "inputSchema": { + "type": "object", + "properties": { + "artist": { + "type": "string", + "description": "Artist name" + }, + "get_songs": { + "type": "integer", + "description": "Number of songs", + "default": 100 + } + }, + "required": ["artist"] + } + }) + + p2 = '4' if clap_enabled else '3' + tools.append({ + "name": "song_alchemy", + "description": f"PRIORITY #{p2}: VECTOR ARITHMETIC - Blend or subtract MULTIPLE artists/songs. REQUIRES 2+ items. Keywords: 'meets', 'combined', 'blend', 'mix of', 'but not', 'without'. BEST for: 'play like A + B' ('play like Iron Maiden, Metallica, Deep Purple'), 'like X but NOT Y', 'Artist A meets Artist B', 'mix of A and B'. DON'T USE for: single artist (use artist_similarity), genre/mood (use search_database). Examples: 'play like Iron Maiden + Metallica + Deep Purple' = add all 3; 'Beatles but not ballads' = add Beatles, subtract ballads.", + "inputSchema": { + "type": "object", + "properties": { + "add_items": { + "type": "array", + "description": "Items to ADD (blend into result). Each item: {type: 'song' or 'artist', id: 'artist_name' or 'song_title by artist'}", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["song", "artist"], + "description": "Item type: 'song' or 'artist'" + }, + "id": { + "type": "string", + "description": "For artist: 'Artist Name'; For song: 'Song Title by Artist Name'" + } + }, + "required": ["type", "id"] + } + }, + "subtract_items": { + "type": "array", + "description": "Items to SUBTRACT (remove from result). 
Same format as add_items.", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["song", "artist"], + "description": "Item type: 'song' or 'artist'" + }, + "id": { + "type": "string", + "description": "For artist: 'Artist Name'; For song: 'Song Title by Artist Name'" + } + }, + "required": ["type", "id"] + } + }, + "get_songs": { + "type": "integer", + "description": "Number of songs", + "default": 100 + } + }, + "required": ["add_items"] + } + }) + + p3 = '5' if clap_enabled else '4' + tools.append({ + "name": "ai_brainstorm", + "description": f"PRIORITY #{p3}: AI world knowledge - Use ONLY when other tools CAN'T work. USE for: named events (Grammy, Billboard, festivals), cultural knowledge (trending, viral, classic hits), historical significance (best of decade, iconic albums), songs NOT in library. DON'T USE for: artist's own songs (use artist_similarity), 'sounds like' (use song_alchemy), genre/mood (use search_database), instruments/moods (use text_search if available).", + "inputSchema": { + "type": "object", + "properties": { + "user_request": { + "type": "string", + "description": "User's request" + }, + "get_songs": { + "type": "integer", + "description": "Number of songs", + "default": 100 + } + }, + "required": ["user_request"] + } + }) + + p4 = '6' if clap_enabled else '5' + tools.append({ + "name": "search_database", + "description": f"PRIORITY #{p4}: MOST GENERAL (last resort) - Search by genre/mood/tempo/energy/year/rating/scale filters. USE for: genre/mood/tempo combinations when NO specific artists/songs mentioned AND text_search not available/suitable. DON'T USE if you can use other more specific tools. 
COMBINE all filters in ONE call!", + "inputSchema": { + "type": "object", + "properties": { + "genres": { + "type": "array", + "items": {"type": "string"}, + "description": "Genres (rock, pop, metal, jazz, etc.)" + }, + "moods": { + "type": "array", + "items": {"type": "string"}, + "description": "Moods (danceable, aggressive, happy, party, relaxed, sad)" + }, + "tempo_min": { + "type": "number", + "description": "Min BPM (40-200)" + }, + "tempo_max": { + "type": "number", + "description": "Max BPM (40-200)" + }, + "energy_min": { + "type": "number", + "description": "Min energy 0.0 (calm) to 1.0 (intense)" + }, + "energy_max": { + "type": "number", + "description": "Max energy 0.0 (calm) to 1.0 (intense)" + }, + "key": { + "type": "string", + "description": "Musical key (C, D, E, F, G, A, B with # or b)" + }, + "scale": { + "type": "string", + "enum": ["major", "minor"], + "description": "Musical scale: major or minor" + }, + "year_min": { + "type": "integer", + "description": "Earliest release year (e.g. 1990)" + }, + "year_max": { + "type": "integer", + "description": "Latest release year (e.g. 
1999)" + }, + "min_rating": { + "type": "integer", + "description": "Minimum user rating 1-5" + }, + "get_songs": { + "type": "integer", + "description": "Number of songs", + "default": 100 + } + } + } + }) + + return tools + + +# --------------------------------------------------------------------------- +# System prompt builder (inlined from ai_mcp_client.py:13-83) +# --------------------------------------------------------------------------- +_FALLBACK_GENRES = "rock, pop, metal, jazz, electronic, dance, alternative, indie, punk, blues, hard rock, heavy metal, hip-hop, funk, country, soul" +_FALLBACK_MOODS = "danceable, aggressive, happy, party, relaxed, sad" + + +def _get_dynamic_genres(library_context: dict | None) -> str: + """Return genre list from library context, falling back to defaults.""" + if library_context and library_context.get('top_genres'): + return ', '.join(library_context['top_genres'][:15]) + return _FALLBACK_GENRES + + +def _get_dynamic_moods(library_context: dict | None) -> str: + """Return mood list from library context, falling back to defaults.""" + if library_context and library_context.get('top_moods'): + return ', '.join(library_context['top_moods'][:10]) + return _FALLBACK_MOODS + + +def build_system_prompt(tools: list[dict], library_context: dict | None = None) -> str: + """Build the unified system prompt used by ALL AI providers.""" + tool_names = [t['name'] for t in tools] + has_text_search = 'text_search' in tool_names + + # Build library context section + lib_section = "" + if library_context and library_context.get('total_songs', 0) > 0: + ctx = library_context + year_range = '' + if ctx.get('year_min') and ctx.get('year_max'): + year_range = f"\n- Year range: {ctx['year_min']}-{ctx['year_max']}" + rating_info = '' + if ctx.get('has_ratings'): + rating_info = f"\n- {ctx['rated_songs_pct']}% of songs have ratings (0-5 scale)" + scale_info = '' + if ctx.get('scales'): + scale_info = f"\n- Scales available: {', 
'.join(ctx['scales'])}" + + lib_section = f""" +=== USER'S MUSIC LIBRARY === +- {ctx['total_songs']} songs from {ctx['unique_artists']} artists{year_range}{rating_info}{scale_info} +""" + + # Build tool decision tree + decision_tree = [] + decision_tree.append("1. Specific song+artist mentioned? -> song_similarity") + if has_text_search: + decision_tree.append("2. Instruments (piano, guitar, ukulele) or SOUND DESCRIPTIONS (romantic, dreamy, chill vibes)? -> text_search") + decision_tree.append("3. 'songs by/from/like [ARTIST]'? -> artist_similarity (returns artist's own + similar)") + decision_tree.append("4. MULTIPLE artists blended ('A meets B', 'A + B', 'like A and B combined') OR negation ('X but not Y', 'X without Y')? -> song_alchemy (REQUIRES 2+ items)") + decision_tree.append("5. Songs NOT in library, trending, award winners (Grammy, Billboard), cultural knowledge? -> ai_brainstorm") + decision_tree.append("6. Genre/mood/tempo/energy/year/rating filters? -> search_database (last resort)") + else: + decision_tree.append("2. 'songs by/from/like [ARTIST]'? -> artist_similarity (returns artist's own + similar)") + decision_tree.append("3. MULTIPLE artists blended ('A meets B', 'A + B', 'like A and B combined') OR negation ('X but not Y', 'X without Y')? -> song_alchemy (REQUIRES 2+ items)") + decision_tree.append("4. Songs NOT in library, trending, award winners (Grammy, Billboard), cultural knowledge? -> ai_brainstorm") + decision_tree.append("5. Genre/mood/tempo/energy/year/rating filters? -> search_database (last resort)") + + decision_text = '\n'.join(decision_tree) + + prompt = f"""You are an expert music playlist curator. Analyze the user's request and call the appropriate tools to build a playlist of 100 songs. +{lib_section} +=== TOOL SELECTION (most specific -> most general) === +{decision_text} + +=== RULES === +1. Call one or more tools - each returns songs with item_id, title, and artist +2. 
song_similarity REQUIRES both title AND artist - never leave empty +3. artist_similarity returns the artist's OWN songs + songs from SIMILAR artists +4. search_database: COMBINE all filters in ONE call. Use for genre/mood/tempo/energy/year/rating +5. For multiple artists: call artist_similarity once per artist, or use song_alchemy to blend +6. Prefer tool calls over text explanations +7. For complex requests, call MULTIPLE tools in ONE turn for better coverage: + - "relaxing piano jazz" -> text_search("relaxing piano") + search_database(genres=["jazz"]) + - "energetic songs by Metallica and AC/DC" -> artist_similarity("Metallica") + artist_similarity("AC/DC") +8. When a query has BOTH a genre AND a mood from the MOODS list, prefer search_database over text_search: + - "sad jazz" -> search_database(genres=["jazz"], moods=["sad"]) NOT text_search + - But "dreamy atmospheric" -> text_search (no specific genre, sound description) + +=== VALID search_database VALUES === +GENRES: {_get_dynamic_genres(library_context)} +MOODS: {_get_dynamic_moods(library_context)} +TEMPO: 40-200 BPM +ENERGY: 0.0 (calm) to 1.0 (intense) - use 0.0-0.35 for low, 0.35-0.65 for medium, 0.65-1.0 for high +SCALE: major, minor +YEAR: year_min/year_max (e.g., 1990-1999 for 90s). For decade requests (80s, 90s), prefer year filters over genres. 
+RATING: min_rating 1-5 (user's personal ratings)""" + + return prompt + + +# --------------------------------------------------------------------------- +# Ollama prompt builder (inlined from ai_mcp_client.py:426-466) +# --------------------------------------------------------------------------- +def build_ollama_prompt(user_query: str, tools: list[dict], + library_context: dict | None = None) -> str: + """Build the full Ollama prompt with JSON output instructions.""" + has_text_search = 'text_search' in [t['name'] for t in tools] + + # Build tool parameter descriptions + tools_list = [] + for tool in tools: + props = tool['inputSchema'].get('properties', {}) + params_desc = ", ".join([f"{k} ({v.get('type')})" for k, v in props.items()]) + tools_list.append(f"- {tool['name']}: {params_desc}") + tools_text = "\n".join(tools_list) + + system_prompt = build_system_prompt(tools, library_context) + + # Build examples + examples = [] + examples.append('"Similar to By the Way by Red Hot Chili Peppers"\n{{"tool_calls": [{{"name": "song_similarity", "arguments": {{"song_title": "By the Way", "song_artist": "Red Hot Chili Peppers", "get_songs": 100}}}}]}}') + if has_text_search: + examples.append('"calm piano song"\n{{"tool_calls": [{{"name": "text_search", "arguments": {{"description": "calm piano", "get_songs": 100}}}}]}}') + examples.append('"songs like blink-182"\n{{"tool_calls": [{{"name": "artist_similarity", "arguments": {{"artist": "blink-182", "get_songs": 100}}}}]}}') + examples.append('"blink-182 songs"\n{{"tool_calls": [{{"name": "artist_similarity", "arguments": {{"artist": "blink-182", "get_songs": 100}}}}]}}') + examples.append('"energetic rock"\n{{"tool_calls": [{{"name": "search_database", "arguments": {{"genres": ["rock"], "energy_min": 0.65, "get_songs": 100}}}}]}}') + examples_text = "\n\n".join(examples) + + prompt = f"""{system_prompt} + +=== TOOL PARAMETERS === +{tools_text} + +=== OUTPUT FORMAT (CRITICAL) === +Return ONLY a valid JSON object with 
this EXACT format: +{{ + "tool_calls": [ + {{"name": "tool_name", "arguments": {{"param": "value"}}}} + ] +}} + +=== EXAMPLES === +{examples_text} + +Now analyze this request and return ONLY the JSON: +Request: "{user_query}" +""" + return prompt + + +# --------------------------------------------------------------------------- +# OpenAI-format payload builder (inlined from ai_mcp_client.py:257-301) +# --------------------------------------------------------------------------- +def build_openai_payload(user_query: str, tools: list[dict], model_id: str, + library_context: dict | None = None) -> dict: + """Build the OpenAI/OpenRouter chat-completion payload with tools.""" + functions = [] + for tool in tools: + functions.append({ + "type": "function", + "function": { + "name": tool['name'], + "description": tool['description'], + "parameters": tool['inputSchema'] + } + }) + + system_prompt = build_system_prompt(tools, library_context) + + return { + "model": model_id, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_query} + ], + "tools": functions, + "tool_choice": "auto" + } + + +# --------------------------------------------------------------------------- +# API callers +# --------------------------------------------------------------------------- +def call_ollama_model(model_cfg: dict, user_query: str, tools: list[dict], + library_context: dict | None, timeout: int) -> dict: + """Call an Ollama model and return parsed tool calls.""" + url = model_cfg["url"] + model_id = model_cfg["model_id"] + + prompt = build_ollama_prompt(user_query, tools, library_context) + + payload = { + "model": model_id, + "prompt": prompt, + "stream": False, + "format": "json" + } + + start_time = time.time() + try: + response = requests.post(url, json=payload, timeout=timeout) + response.raise_for_status() + result = response.json() + elapsed = time.time() - start_time + + if 'response' not in result: + return {"error": "Invalid Ollama 
response", "elapsed": elapsed, "raw_response": str(result)[:500]} + + response_text = result['response'] + cleaned = response_text.strip() + + # Remove markdown code blocks if present + if "```json" in cleaned: + cleaned = cleaned.split("```json")[1].split("```")[0] + elif "```" in cleaned: + cleaned = cleaned.split("```")[1].split("```")[0] + cleaned = cleaned.strip() + + # Strip think tags + for tag in ["", "[/INST]", "[/THOUGHT]"]: + if tag in cleaned: + cleaned = cleaned.split(tag, 1)[-1].strip() + + parsed = json.loads(cleaned) + + # Extract tool_calls from various response shapes + if isinstance(parsed, dict) and 'tool_calls' in parsed: + tool_calls = parsed['tool_calls'] + elif isinstance(parsed, list): + tool_calls = parsed + elif isinstance(parsed, dict) and 'name' in parsed: + tool_calls = [parsed] + else: + return {"error": "Missing 'tool_calls' field", "elapsed": elapsed, + "raw_response": cleaned[:500]} + + if not isinstance(tool_calls, list): + tool_calls = [tool_calls] + + # Validate structure + valid_calls = [] + for tc in tool_calls: + if isinstance(tc, dict) and 'name' in tc: + if 'arguments' not in tc: + tc['arguments'] = {} + valid_calls.append(tc) + + if not valid_calls: + return {"error": "No valid tool calls found", "elapsed": elapsed, + "raw_response": cleaned[:500]} + + return {"tool_calls": valid_calls, "elapsed": elapsed, "raw_response": response_text} + + except json.JSONDecodeError as e: + elapsed = time.time() - start_time + return {"error": f"JSON parse error: {e}", "elapsed": elapsed, + "raw_response": response_text[:500] if 'response_text' in locals() else ""} + except requests.exceptions.ConnectionError: + elapsed = time.time() - start_time + return {"error": "Connection refused", "elapsed": elapsed, "raw_response": ""} + except requests.exceptions.Timeout: + elapsed = time.time() - start_time + return {"error": f"Timeout after {timeout}s", "elapsed": elapsed, "raw_response": ""} + except requests.exceptions.HTTPError as e: + 
elapsed = time.time() - start_time + detail = "" + try: + detail = e.response.text[:200] + except Exception: + pass + return {"error": f"HTTP {e.response.status_code}: {detail}", "elapsed": elapsed, "raw_response": ""} + except Exception as e: + elapsed = time.time() - start_time + return {"error": str(e), "elapsed": elapsed, "raw_response": ""} + + +def call_openai_model(model_cfg: dict, user_query: str, tools: list[dict], + library_context: dict | None, timeout: int) -> dict: + """Call an OpenAI-compatible API and return parsed tool calls.""" + url = model_cfg["url"] + model_id = model_cfg["model_id"] + api_key = model_cfg.get("api_key", "") + + payload = build_openai_payload(user_query, tools, model_id, library_context) + + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + if "openrouter" in url.lower(): + headers["HTTP-Referer"] = "https://github.com/NeptuneHub/AudioMuse-AI" + headers["X-Title"] = "AudioMuse-AI" + + start_time = time.time() + try: + response = requests.post(url, headers=headers, json=payload, timeout=timeout) + response.raise_for_status() + result = response.json() + elapsed = time.time() - start_time + + raw_response = json.dumps(result, indent=2) + + tool_calls = [] + if 'choices' in result and result['choices']: + message = result['choices'][0].get('message', {}) + if 'tool_calls' in message: + for tc in message['tool_calls']: + if tc.get('type') == 'function': + try: + args = json.loads(tc['function']['arguments']) + except (json.JSONDecodeError, KeyError): + args = {} + tool_calls.append({ + "name": tc['function']['name'], + "arguments": args + }) + + if not tool_calls: + text_response = result.get('choices', [{}])[0].get('message', {}).get('content', '') + return {"error": "No tool calls returned", "elapsed": elapsed, + "raw_response": text_response[:500] if text_response else raw_response[:500]} + + return {"tool_calls": tool_calls, "elapsed": elapsed, "raw_response": 
raw_response}

    except requests.exceptions.ConnectionError:
        elapsed = time.time() - start_time
        return {"error": "Connection refused", "elapsed": elapsed, "raw_response": ""}
    except requests.exceptions.Timeout:
        elapsed = time.time() - start_time
        return {"error": f"Timeout after {timeout}s", "elapsed": elapsed, "raw_response": ""}
    except requests.exceptions.HTTPError as e:
        elapsed = time.time() - start_time
        detail = ""
        # Best-effort: include a snippet of the error body when available
        try:
            detail = e.response.text[:200]
        except Exception:
            pass
        return {"error": f"HTTP {e.response.status_code}: {detail}", "elapsed": elapsed, "raw_response": ""}
    except Exception as e:
        elapsed = time.time() - start_time
        return {"error": str(e), "elapsed": elapsed, "raw_response": ""}


def call_model(model_cfg: dict, user_query: str, tools: list[dict],
               library_context: dict | None, timeout: int) -> dict:
    """Dispatch to the correct API caller based on provider config.

    Heuristic: a non-empty api_key, or "openai"/"openrouter" in the URL,
    selects the OpenAI chat-completion format; everything else is treated
    as a raw Ollama /api/generate endpoint.
    """
    url = model_cfg.get("url", "")
    api_key = model_cfg.get("api_key", "")
    # NOTE(review): any non-empty api_key forces the OpenAI path — an
    # authenticated Ollama endpoint would be misrouted; confirm intended.
    is_openai_format = (
        bool(api_key) or
        "openai" in url.lower() or
        "openrouter" in url.lower()
    )

    if is_openai_format:
        return call_openai_model(model_cfg, user_query, tools, library_context, timeout)
    else:
        return call_ollama_model(model_cfg, user_query, tools, library_context, timeout)


# ---------------------------------------------------------------------------
# Validation
# ---------------------------------------------------------------------------
def _score_args_quality(selected_tool: str, selected_args: dict, expected_args: dict) -> float:
    """Score argument quality against expected_args from YAML config.

    Returns a float 0.0-1.0 representing how well the arguments match expectations.
    Uses case-insensitive matching to be tolerant of formatting variations.
+ """ + if not expected_args: + return 1.0 # No expected args defined = full marks + + if not isinstance(selected_args, dict): + return 0.0 + + checks = [] + + if selected_tool == "song_similarity": + # Check song_title and song_artist + if 'song_title' in expected_args: + actual = (selected_args.get('song_title') or '').lower() + expected = expected_args['song_title'].lower() + checks.append(expected in actual or actual in expected) + if 'song_artist' in expected_args: + actual = (selected_args.get('song_artist') or '').lower() + expected = expected_args['song_artist'].lower() + checks.append(expected in actual or actual in expected) + + elif selected_tool == "search_database": + # Check genres + if 'genres' in expected_args: + actual_genres = [g.lower() for g in (selected_args.get('genres') or [])] + for exp_genre in expected_args['genres']: + checks.append(any(exp_genre.lower() in ag for ag in actual_genres)) + # Check moods + if 'moods' in expected_args: + actual_moods = [m.lower() for m in (selected_args.get('moods') or [])] + for exp_mood in expected_args['moods']: + checks.append(any(exp_mood.lower() in am for am in actual_moods)) + # Check scale + if 'scale' in expected_args: + checks.append((selected_args.get('scale') or '').lower() == expected_args['scale'].lower()) + # Check min_rating + if 'min_rating' in expected_args: + checks.append(selected_args.get('min_rating') == expected_args['min_rating']) + + elif selected_tool == "artist_similarity": + if 'artist' in expected_args: + actual = (selected_args.get('artist') or '').lower() + expected = expected_args['artist'].lower() + checks.append(expected in actual or actual in expected) + + elif selected_tool == "song_alchemy": + # Check that expected artists appear in add_items + if 'add_items_artists' in expected_args: + add_items = selected_args.get('add_items') or [] + actual_ids = [] + for item in add_items: + if isinstance(item, dict): + actual_ids.append((item.get('id') or '').lower()) + elif 
isinstance(item, str): + actual_ids.append(item.lower()) + for exp_artist in expected_args['add_items_artists']: + checks.append(any(exp_artist.lower() in aid for aid in actual_ids)) + + if not checks: + return 1.0 # No checkable fields for this tool = full marks + return sum(1 for c in checks if c) / len(checks) + + +def validate_result(result: dict, test_query: dict, tools: list[dict]) -> dict: + """ + Validate a single model response against the expected outcome. + + Returns a dict with: + json_valid, correct_tool, valid_args, pre_exec_valid, + args_quality, composite_score, selected_tool, selected_args, + all_tools_called + """ + tool_names = [t['name'] for t in tools] + expected_tool = test_query['expected_tool'] + acceptable_tools = test_query.get('acceptable_tools', [expected_tool]) + expected_args = test_query.get('expected_args', {}) + + # Default: everything fails + validation = { + "json_valid": False, + "correct_tool": False, + "valid_args": False, + "pre_exec_valid": False, + "args_quality": 0.0, + "composite_score": 0.0, + "selected_tool": "", + "selected_args": {}, + "all_tools_called": [], + } + + if 'error' in result or 'tool_calls' not in result: + return validation + + tc_list = result.get('tool_calls', []) + if not tc_list or not isinstance(tc_list, list): + return validation + + # JSON is valid if we got parseable tool calls + validation["json_valid"] = True + + # Record all tools called + all_called = [tc.get('name', '') for tc in tc_list] + validation["all_tools_called"] = all_called + + # Check if expected tool appears ANYWHERE in tool calls (not just first) + matched_tc = None + for tc in tc_list: + if tc.get('name') == expected_tool: + matched_tc = tc + break + + # If ideal tool not found, check acceptable_tools + if matched_tc is None: + for tc in tc_list: + if tc.get('name') in acceptable_tools: + matched_tc = tc + break + + # Use matched tool call for scoring, fall back to first + scoring_tc = matched_tc if matched_tc else tc_list[0] + 
selected_tool = scoring_tc.get('name', '') + selected_args = scoring_tc.get('arguments', {}) + + validation["selected_tool"] = selected_tool + validation["selected_args"] = selected_args + + # Correct tool? Check against acceptable_tools list + validation["correct_tool"] = (selected_tool in acceptable_tools) + + # Valid args? (check required args and types for the SELECTED tool) + validation["valid_args"] = _check_args_valid(selected_tool, selected_args, tool_names) + + # Pre-execution valid? (mirrors app_chat.py:448-475) + validation["pre_exec_valid"] = _check_pre_exec_valid(selected_tool, selected_args) + + # Args quality scoring (0.0-1.0) using expected_args from YAML + if validation["correct_tool"] and expected_args: + validation["args_quality"] = _score_args_quality(selected_tool, selected_args, expected_args) + elif validation["correct_tool"]: + validation["args_quality"] = 1.0 # Correct tool, no expected_args to check + + # Composite score: tool_correct (50) + args_quality (25) + pre_exec_valid (15) + json_valid (10) + score = 0.0 + if validation["json_valid"]: + score += 10.0 + if validation["correct_tool"]: + score += 50.0 + if validation["correct_tool"]: + score += 25.0 * validation["args_quality"] + if validation["pre_exec_valid"]: + score += 15.0 + validation["composite_score"] = round(score, 1) + + return validation + + +def _check_args_valid(tool_name: str, args: dict, available_tools: list[str]) -> bool: + """Check that required arguments are present and of correct type.""" + if not isinstance(args, dict): + return False + if tool_name not in available_tools and tool_name not in VALID_TOOL_NAMES: + return False + + if tool_name == "song_similarity": + return (isinstance(args.get('song_title'), str) and len(args['song_title']) > 0 and + isinstance(args.get('song_artist'), str) and len(args['song_artist']) > 0) + + elif tool_name == "text_search": + return isinstance(args.get('description'), str) and len(args['description']) > 0 + + elif tool_name == 
"artist_similarity":
        return isinstance(args.get('artist'), str) and len(args['artist']) > 0

    elif tool_name == "song_alchemy":
        add_items = args.get('add_items', [])
        if not isinstance(add_items, list) or len(add_items) == 0:
            return False
        # Accept both structured dicts and simple strings
        for item in add_items:
            if isinstance(item, dict):
                # Structured form requires both 'type' and 'id' keys
                if not item.get('type') or not item.get('id'):
                    return False
            elif not isinstance(item, str):
                return False
        return True

    elif tool_name == "ai_brainstorm":
        return isinstance(args.get('user_request'), str) and len(args['user_request']) > 0

    elif tool_name == "search_database":
        # search_database has no required args (but pre_exec checks for filters)
        return True

    # Unknown tool name: no argument schema to satisfy
    return False


def _check_pre_exec_valid(tool_name: str, args: dict) -> bool:
    """Mirror pre-execution validation from app_chat.py.

    Returns True when the tool call would survive the server-side checks
    performed just before tool execution.
    """
    if not isinstance(args, dict):
        return False
    if tool_name == "song_similarity":
        # Both title and artist must be non-blank after stripping whitespace
        title = args.get('song_title', '')
        artist = args.get('song_artist', '')
        if isinstance(title, str) and isinstance(artist, str):
            return bool(title.strip()) and bool(artist.strip())
        return False

    elif tool_name == "search_database":
        # At least one filter must be present
        # NOTE(review): any() treats falsy filter values (e.g. tempo_min=0,
        # empty genres list) as absent — confirm that matches app_chat.py.
        return any(args.get(k) for k in SEARCH_DB_FILTER_KEYS)

    # All other tools pass pre-execution validation if they have valid args
    return _check_args_valid(tool_name, args, VALID_TOOL_NAMES)


# ---------------------------------------------------------------------------
# Config helpers (mirrors test_ai_naming.py)
# ---------------------------------------------------------------------------
def apply_defaults(config: dict) -> None:
    """Merge provider defaults (url, api_key) into each model entry."""
    defaults = config.get("defaults", {})
    for model in config.get("models", []):
        provider = model.get("provider", "")
        provider_defaults = defaults.get(provider, {})
        # Per-model values win; defaults only fill missing keys
        for key, value in provider_defaults.items():
            if key not in
model: + model[key] = value + + # Allow environment variable override for API keys + env_api_key = os.environ.get('OPENROUTER_API_KEY') + if env_api_key: + for model in config.get("models", []): + if model.get("provider") == "openrouter": + model["api_key"] = env_api_key + + +# --------------------------------------------------------------------------- +# Main test loop +# --------------------------------------------------------------------------- +def run_tests(config: dict, dry_run: bool = False) -> tuple[dict, list[dict]]: + """ + Execute the full test suite. + + Returns: + (results_dict, test_queries) + results_dict keys are model names, values have 'runs' list and metadata. + """ + tc = config["test_config"] + models = [m for m in config["models"] if m.get("enabled", False)] + clap_enabled = tc.get("clap_enabled", True) + library_context = config.get("library_context") + + if not models: + print("ERROR: No models enabled in configuration.") + sys.exit(1) + + num_runs = tc["num_runs_per_model"] + timeout = tc.get("timeout_per_request", 120) + + # Build tool definitions + tools = get_tool_definitions(clap_enabled) + tool_names_available = [t['name'] for t in tools] + + # Filter test queries + all_queries = config.get("test_queries", []) + test_queries = [] + for q in all_queries: + if q.get("skip_if_clap_disabled") and not clap_enabled: + continue + test_queries.append(q) + + if not test_queries: + print("ERROR: No test queries after filtering.") + sys.exit(1) + + num_queries = len(test_queries) + + if dry_run: + print("=== DRY RUN MODE ===") + print(f"Would test {len(models)} model(s), {num_queries} queries, {num_runs} run(s) each") + print(f"CLAP enabled: {clap_enabled}") + print(f"Tools available: {', '.join(tool_names_available)}") + print(f"Library context: {'yes' if library_context else 'no'}\n") + + for mi, m in enumerate(models): + print(f" Model {mi + 1}: {m['name']} ({m['provider']}) - {m['model_id']}") + + print(f"\n--- System Prompt Preview ---") + 
        # Previews are truncated to keep the dry-run output readable.
        sys_prompt = build_system_prompt(tools, library_context)
        print(sys_prompt[:800])
        print("...\n")

        # Show Ollama prompt for first query
        print(f"--- Ollama Prompt Preview (query 1: \"{test_queries[0]['query']}\") ---")
        ollama_prompt = build_ollama_prompt(test_queries[0]['query'], tools, library_context)
        print(ollama_prompt[:1200])
        print("...\n")

        # Show OpenAI payload for first query
        print(f"--- OpenAI Payload Preview (query 1) ---")
        openai_payload = build_openai_payload(test_queries[0]['query'], tools, "example-model", library_context)
        # Show just messages, not the full tool defs
        print(json.dumps(openai_payload['messages'], indent=2)[:600])
        print(f" ... + {len(openai_payload['tools'])} tool definitions\n")

        print(f"--- Test Queries ({num_queries}) ---")
        for qi, q in enumerate(test_queries):
            print(f" {qi + 1:2d}. [{q['category']}] \"{q['query']}\" -> {q['expected_tool']}")

        return {}, test_queries

    # Run tests
    results = {}
    total_models = len(models)
    # URLs that already refused a connection; later models on the same URL
    # are skipped wholesale instead of timing out again.
    connection_failures = set()

    for mi, model in enumerate(models):
        model_name = model["name"]
        print(f"[{mi + 1}/{total_models}] Testing: {model_name} ({model['provider']})")

        results[model_name] = {
            "provider": model["provider"],
            "model_id": model["model_id"],
            "url": model["url"],
            "runs": [],
        }

        # Skip if previous connection to same URL failed
        if model["url"] in connection_failures:
            print(f" Skipping (connection to {model['url']} already failed)\n")
            # Emit one placeholder record per (query, run) pair so the runs
            # list keeps the same shape as a fully executed model.
            for qi in range(num_queries):
                for ri in range(num_runs):
                    results[model_name]["runs"].append({
                        "query_index": qi,
                        "query": test_queries[qi]["query"],
                        "expected_tool": test_queries[qi]["expected_tool"],
                        "category": test_queries[qi]["category"],
                        "run_index": ri,
                        "json_valid": False,
                        "correct_tool": False,
                        "valid_args": False,
                        "pre_exec_valid": False,
                        "args_quality": 0.0,
                        "composite_score": 0.0,
                        "selected_tool": "",
                        "selected_args": {},
                        "all_tools_called": [],
                        "raw_response": "",
                        "elapsed": 0,
                        "error": "Skipped (connection failed)",
                    })
            continue

        model_correct = 0
        model_total = 0
        model_times = []
        # Set on the first "Connection refused"; remaining (query, run) pairs
        # for this model are then recorded as skipped placeholders.
        abort_model = False

        for qi, tq in enumerate(test_queries):
            for ri in range(num_runs):
                if abort_model:
                    results[model_name]["runs"].append({
                        "query_index": qi,
                        "query": tq["query"],
                        "expected_tool": tq["expected_tool"],
                        "category": tq["category"],
                        "run_index": ri,
                        "json_valid": False,
                        "correct_tool": False,
                        "valid_args": False,
                        "pre_exec_valid": False,
                        "args_quality": 0.0,
                        "composite_score": 0.0,
                        "selected_tool": "",
                        "selected_args": {},
                        "all_tools_called": [],
                        "raw_response": "",
                        "elapsed": 0,
                        "error": "Skipped (connection failed)",
                    })
                    continue

                model_total += 1
                status_prefix = f" Q{qi + 1} Run {ri + 1}/{num_runs}:"

                # One API call per run; validation grades the raw result.
                api_result = call_model(model, tq["query"], tools, library_context, timeout)
                validation = validate_result(api_result, tq, tools)

                elapsed = api_result.get("elapsed", 0)
                error = api_result.get("error")
                raw_resp = api_result.get("raw_response", "")

                run_result = {
                    "query_index": qi,
                    "query": tq["query"],
                    "expected_tool": tq["expected_tool"],
                    "category": tq["category"],
                    "run_index": ri,
                    "json_valid": validation["json_valid"],
                    "correct_tool": validation["correct_tool"],
                    "valid_args": validation["valid_args"],
                    "pre_exec_valid": validation["pre_exec_valid"],
                    "args_quality": validation["args_quality"],
                    "composite_score": validation["composite_score"],
                    "selected_tool": validation["selected_tool"],
                    "selected_args": validation["selected_args"],
                    "all_tools_called": validation["all_tools_called"],
                    # Coerce non-string raw responses so the report stays serializable.
                    "raw_response": raw_resp if isinstance(raw_resp, str) else str(raw_resp),
                    "elapsed": elapsed,
                    "error": error,
                }
                results[model_name]["runs"].append(run_result)

                if error == "Connection refused":
                    print(f"{status_prefix} FAIL (connection refused)")
                    connection_failures.add(model["url"])
                    abort_model = True
                    continue

                if error:
                    print(f"{status_prefix} ERR {elapsed:.1f}s {error}")
                elif validation["correct_tool"]:
                    model_correct += 1
                    model_times.append(elapsed)
                    args_ok = "args OK" if validation["valid_args"] else "args INVALID"
                    print(f"{status_prefix} OK {elapsed:.1f}s {validation['selected_tool']} ({args_ok})")
                else:
                    # Wrong tool still counts toward timing stats (the call succeeded).
                    model_times.append(elapsed)
                    print(f"{status_prefix} WRONG {elapsed:.1f}s got={validation['selected_tool']} expected={tq['expected_tool']}")

        # Model summary
        if model_total > 0 and not abort_model:
            avg_t = sum(model_times) / len(model_times) if model_times else 0
            rate = model_correct / model_total * 100
            print(f" Result: {model_correct}/{model_total} correct ({rate:.1f}%), avg {avg_t:.1f}s\n")
        elif abort_model:
            print(f" Result: Aborted (connection failed)\n")

    return results, test_queries


# ---------------------------------------------------------------------------
# Report generation
# ---------------------------------------------------------------------------
def generate_summary_table(results: dict, timestamp: str) -> str:
    """Generate the ASCII summary table.

    Args:
        results: Per-model results dict as produced by ``run_tests``.
        timestamp: Human-readable timestamp printed in the table header.

    Returns:
        The full table as a single newline-joined string.
    """
    lines = []
    lines.append("=" * 105)
    lines.append(f" RESULTS - Instant Playlist Tool-Calling Test ({timestamp})")
    lines.append("=" * 105)
    lines.append(f" {'Model':<22} {'Total':>5} {'JSON OK':>7} {'Tool OK':>7} {'Args OK':>7} {'Rate':>6} {'Score':>6} {'Avg Time':>8}")
    lines.append("-" * 105)

    for model_name, model_data in results.items():
        all_runs = model_data["runs"]
        total = len(all_runs)
        json_ok = sum(1 for r in all_runs if r["json_valid"])
        tool_ok = sum(1 for r in all_runs if r["correct_tool"])
        # Args are only counted when the tool choice itself was correct.
        args_ok = sum(1 for r in all_runs if r["correct_tool"] and r["valid_args"])
        rate = (tool_ok / total * 100) if total > 0 else 0
        avg_composite = sum(r.get("composite_score", 0) for r in all_runs) / total if total > 0 else 0
        # Timing averages exclude errored runs.
        times = [r["elapsed"] for r in all_runs if r["error"] is None]
        avg_t = sum(times) / len(times) if times else 0
lines.append( + f" {model_name:<22} {total:>5} {json_ok:>7} {tool_ok:>7} {args_ok:>7} {rate:>5.1f}% {avg_composite:>5.1f} {avg_t:>7.1f}s" + ) + + lines.append("-" * 105) + return "\n".join(lines) + + +def generate_query_detail_table(results: dict, test_queries: list[dict], num_runs: int) -> str: + """Generate per-query detail table.""" + lines = [] + model_names = list(results.keys()) + + # Group queries by category + categories = {} + for qi, tq in enumerate(test_queries): + cat = tq["category"] + if cat not in categories: + categories[cat] = [] + categories[cat].append((qi, tq)) + + for cat, queries in categories.items(): + lines.append(f"\n=== Category: {cat} ===") + for qi, tq in queries: + lines.append(f"\n Q{qi + 1}: \"{tq['query']}\" -> expected: {tq['expected_tool']}") + for model_name in model_names: + runs = [r for r in results[model_name]["runs"] if r["query_index"] == qi] + correct = sum(1 for r in runs if r["correct_tool"]) + total = len(runs) + tools_selected = [r["selected_tool"] or "(none)" for r in runs] + tools_str = ", ".join(tools_selected[:5]) + lines.append(f" {model_name:<22} {correct}/{total} [{tools_str}]") + + return "\n".join(lines) + + +def generate_html_report(results: dict, test_queries: list[dict], + num_runs: int, timestamp: str, config: dict, + save_raw: bool, system_prompt: str) -> str: + """Generate a self-contained HTML report.""" + model_names = list(results.keys()) + num_queries = len(test_queries) + + # Build summary rows + summary_rows = "" + for model_name, model_data in results.items(): + all_runs = model_data["runs"] + total = len(all_runs) + json_ok = sum(1 for r in all_runs if r["json_valid"]) + tool_ok = sum(1 for r in all_runs if r["correct_tool"]) + args_ok = sum(1 for r in all_runs if r["correct_tool"] and r["valid_args"]) + pre_exec = sum(1 for r in all_runs if r["correct_tool"] and r["pre_exec_valid"]) + errors = sum(1 for r in all_runs if r["error"]) + rate = (tool_ok / total * 100) if total > 0 else 0 + 
avg_composite = sum(r.get("composite_score", 0) for r in all_runs) / total if total > 0 else 0 + times = [r["elapsed"] for r in all_runs if r["error"] is None] + avg_t = sum(times) / len(times) if times else 0 + min_t = min(times) if times else 0 + max_t = max(times) if times else 0 + + rate_class = "pass" if rate >= 80 else ("warn" if rate >= 50 else "fail") + score_class = "pass" if avg_composite >= 75 else ("warn" if avg_composite >= 50 else "fail") + provider = model_data.get("provider", "") + + summary_rows += f""" + {model_name}{provider} + {total}{json_ok}{tool_ok}{args_ok}{pre_exec} + {errors} + {rate:.1f}% + {avg_composite:.1f} + {avg_t:.2f}s{min_t:.2f}s{max_t:.2f}s + \n""" + + # Build category breakdown + categories = {} + for qi, tq in enumerate(test_queries): + cat = tq["category"] + if cat not in categories: + categories[cat] = [] + categories[cat].append(qi) + + category_rows = "" + for cat, query_indices in categories.items(): + for model_name, model_data in results.items(): + cat_runs = [r for r in model_data["runs"] if r["query_index"] in query_indices] + cat_total = len(cat_runs) + cat_correct = sum(1 for r in cat_runs if r["correct_tool"]) + cat_rate = (cat_correct / cat_total * 100) if cat_total > 0 else 0 + cat_composite = sum(r.get("composite_score", 0) for r in cat_runs) / cat_total if cat_total > 0 else 0 + cat_class = "pass" if cat_rate >= 80 else ("warn" if cat_rate >= 50 else "fail") + score_class = "pass" if cat_composite >= 75 else ("warn" if cat_composite >= 50 else "fail") + category_rows += f""" + {cat}{model_name} + {cat_correct}/{cat_total} + {cat_rate:.0f}% + {cat_composite:.1f} + \n""" + + # Calculate per-query difficulty based on aggregate success rate + query_difficulty = {} + for qi, tq in enumerate(test_queries): + all_query_runs = [] + for model_data in results.values(): + all_query_runs.extend(r for r in model_data["runs"] if r["query_index"] == qi) + total_runs = len(all_query_runs) + correct_runs = sum(1 for r in 
all_query_runs if r["correct_tool"]) + success_rate = (correct_runs / total_runs * 100) if total_runs > 0 else 0 + if success_rate >= 90: + difficulty = "Easy" + diff_class = "pass" + elif success_rate >= 70: + difficulty = "Medium" + diff_class = "warn" + elif success_rate >= 50: + difficulty = "Hard" + diff_class = "fail" + else: + difficulty = "Very Hard" + diff_class = "fail" + query_difficulty[qi] = {"label": difficulty, "class": diff_class, "rate": success_rate} + + # Build query difficulty summary table + difficulty_rows = "" + sorted_queries = sorted(query_difficulty.items(), key=lambda x: x[1]["rate"]) + for qi, diff in sorted_queries: + tq = test_queries[qi] + acceptable = tq.get('acceptable_tools') + accept_str = f' (also accepts: {", ".join(t for t in acceptable if t != tq["expected_tool"])})' if acceptable and len(acceptable) > 1 else "" + difficulty_rows += f""" + Q{qi + 1}{tq['query']}{tq['category']} + {tq['expected_tool']}{accept_str} + {diff['rate']:.0f}% + {diff['label']} + \n""" + + # Build per-query detail sections + query_sections = "" + for qi, tq in enumerate(test_queries): + diff = query_difficulty[qi] + detail_rows = "" + for model_name in model_names: + runs = [r for r in results[model_name]["runs"] if r["query_index"] == qi] + correct_count = sum(1 for r in runs if r["correct_tool"]) + total_count = len(runs) + times = [r["elapsed"] for r in runs if r["error"] is None] + avg_t = sum(times) / len(times) if times else 0 + avg_score = sum(r.get("composite_score", 0) for r in runs) / total_count if total_count else 0 + rate = (correct_count / total_count * 100) if total_count else 0 + rate_class = "pass" if rate >= 80 else ("warn" if rate >= 50 else "fail") + + runs_html = "" + for ri, r in enumerate(runs): + error = r.get("error") + selected = r.get("selected_tool", "") + all_tools = r.get("all_tools_called", []) + args_str = json.dumps(r.get("selected_args", {}), indent=1) + raw = str(r.get("raw_response", "")).replace("&", 
"&").replace("<", "<").replace(">", ">") + score = r.get("composite_score", 0) + multi_tool_str = f' [{", ".join(all_tools)}]' if len(all_tools) > 1 else "" + + if error: + runs_html += f'
    Run {ri + 1}: {error} ({r["elapsed"]:.1f}s)
    \n' + elif r["correct_tool"]: + args_ok_str = "args OK" if r["valid_args"] else "args INVALID" + aq = r.get("args_quality", 0) + aq_str = f" aq={aq:.0%}" if aq < 1.0 else "" + raw_detail = f'
    raw
    {raw[:1000]}
    ' if save_raw and raw else "" + runs_html += f'
    Run {ri + 1}: {selected} ({args_ok_str}{aq_str}) score={score:.0f}{multi_tool_str} ({r["elapsed"]:.1f}s){raw_detail}
    \n' + else: + raw_detail = f'
    raw
    {raw[:1000]}
    ' if save_raw and raw else "" + runs_html += f'
    Run {ri + 1}: {selected or "(none)"} (expected {tq["expected_tool"]}) score={score:.0f}{multi_tool_str} ({r["elapsed"]:.1f}s){raw_detail}
    \n' + + detail_rows += f""" + {model_name}
    {results[model_name].get('provider', '')} + {correct_count}/{total_count} ({rate:.0f}%) + {avg_score:.1f} + {avg_t:.2f}s + {runs_html} + \n""" + + acceptable = tq.get('acceptable_tools') + accept_note = f' (also accepts: {", ".join(t for t in acceptable if t != tq["expected_tool"])})' if acceptable and len(acceptable) > 1 else "" + + query_sections += f""" +
    +

    Q{qi + 1} [{tq['category']}]: "{tq['query']}" → {tq['expected_tool']}{accept_note} {diff['label']}

    + + + + + + {detail_rows} +
    ModelCorrectAvg ScoreAvg TimeResults
    +
    + """ + + # Escape system prompt for HTML + sys_prompt_escaped = system_prompt.replace("&", "&").replace("<", "<").replace(">", ">") + + # Sanitize config for display (hide API keys) + display_config = json.loads(json.dumps(config, default=str)) + if "defaults" in display_config: + for provider in display_config["defaults"]: + if "api_key" in display_config["defaults"][provider]: + display_config["defaults"][provider]["api_key"] = "***hidden***" + for m in display_config.get("models", []): + if "api_key" in m: + m["api_key"] = "***hidden***" + + html = f""" + + + + +Instant Playlist Tool-Calling Test - {timestamp} + + + +

    Instant Playlist Tool-Calling Test

    +

    Date: {timestamp}  |  + Runs per model: {num_runs}  |  + Queries: {num_queries}  |  + CLAP enabled: {config['test_config'].get('clap_enabled', True)}

    + +

    Summary

    + + + + + + + {summary_rows} +
    ModelProviderTotalJSON OKTool OKArgs OKPre-Exec OKErrorsTool RateCompositeAvg TimeMin TimeMax Time
    + +

    Query Difficulty

    +

    Difficulty is auto-calculated from aggregate success rate across all models. Queries below 50% may need prompt improvement or reclassification.

    + + + + + {difficulty_rows} +
    QueryTextCategoryExpected ToolSuccess RateDifficulty
    + +

    Category Breakdown

    + + + + + {category_rows} +
    CategoryModelCorrectRateAvg Score
    + +

    System Prompt Used

    +
    + Show system prompt +
    {sys_prompt_escaped}
    +
    + +

    Per-Query Details

    +{query_sections} + +

    Test Configuration

    +
    {json.dumps(display_config, indent=2, default=str)}
    + + + +""" + return html + + +def generate_json_report(results: dict, test_queries: list[dict], + timestamp: str, config: dict) -> dict: + """Generate the full JSON report.""" + # Group queries by category for per-category stats + categories = {} + for qi, tq in enumerate(test_queries): + cat = tq["category"] + if cat not in categories: + categories[cat] = [] + categories[cat].append(qi) + + # Calculate per-query difficulty + query_stats = [] + for qi, tq in enumerate(test_queries): + all_query_runs = [] + for model_data in results.values(): + all_query_runs.extend(r for r in model_data["runs"] if r["query_index"] == qi) + total_runs = len(all_query_runs) + correct_runs = sum(1 for r in all_query_runs if r["correct_tool"]) + success_rate = round(correct_runs / total_runs * 100, 1) if total_runs > 0 else 0 + if success_rate >= 90: + difficulty = "easy" + elif success_rate >= 70: + difficulty = "medium" + elif success_rate >= 50: + difficulty = "hard" + else: + difficulty = "very_hard" + query_stats.append({ + "index": qi, "query": tq["query"], "expected_tool": tq["expected_tool"], + "acceptable_tools": tq.get("acceptable_tools", [tq["expected_tool"]]), + "category": tq["category"], + "success_rate": success_rate, + "difficulty": difficulty, + }) + + report = { + "timestamp": timestamp, + "test_type": "instant_playlist_tool_calling", + "config": { + "clap_enabled": config["test_config"].get("clap_enabled", True), + "num_runs_per_model": config["test_config"]["num_runs_per_model"], + "timeout": config["test_config"].get("timeout_per_request", 120), + "num_queries": len(test_queries), + }, + "queries": query_stats, + "models": {}, + } + + for model_name, model_data in results.items(): + all_runs = model_data["runs"] + total = len(all_runs) + json_ok = sum(1 for r in all_runs if r["json_valid"]) + tool_ok = sum(1 for r in all_runs if r["correct_tool"]) + args_ok = sum(1 for r in all_runs if r["correct_tool"] and r["valid_args"]) + pre_exec = sum(1 for r in all_runs 
if r["correct_tool"] and r["pre_exec_valid"]) + errors = sum(1 for r in all_runs if r["error"]) + times = [r["elapsed"] for r in all_runs if r["error"] is None] + + avg_composite = sum(r.get("composite_score", 0) for r in all_runs) / total if total > 0 else 0 + + # Per-category breakdown + per_category = {} + for cat, query_indices in categories.items(): + cat_runs = [r for r in all_runs if r["query_index"] in query_indices] + cat_total = len(cat_runs) + cat_correct = sum(1 for r in cat_runs if r["correct_tool"]) + cat_composite = sum(r.get("composite_score", 0) for r in cat_runs) / cat_total if cat_total > 0 else 0 + per_category[cat] = { + "total": cat_total, + "correct": cat_correct, + "rate": round(cat_correct / cat_total * 100, 1) if cat_total > 0 else 0, + "avg_composite": round(cat_composite, 1), + } + + report["models"][model_name] = { + "provider": model_data.get("provider", ""), + "model_id": model_data.get("model_id", ""), + "url": model_data.get("url", ""), + "summary": { + "total_tests": total, + "json_valid": json_ok, + "tool_correct": tool_ok, + "args_valid": args_ok, + "pre_exec_valid": pre_exec, + "errors": errors, + "tool_rate": round(tool_ok / total * 100, 1) if total > 0 else 0, + "avg_composite": round(avg_composite, 1), + "avg_time": round(sum(times) / len(times), 3) if times else 0, + "min_time": round(min(times), 3) if times else 0, + "max_time": round(max(times), 3) if times else 0, + }, + "per_category": per_category, + "runs": [ + { + "query_index": r["query_index"], + "query": r["query"], + "expected_tool": r["expected_tool"], + "category": r["category"], + "run_index": r["run_index"], + "json_valid": r["json_valid"], + "correct_tool": r["correct_tool"], + "valid_args": r["valid_args"], + "pre_exec_valid": r["pre_exec_valid"], + "args_quality": r.get("args_quality", 0), + "composite_score": r.get("composite_score", 0), + "selected_tool": r["selected_tool"], + "selected_args": r["selected_args"], + "all_tools_called": 
r.get("all_tools_called", []), + "elapsed": round(r["elapsed"], 3), + "error": r.get("error"), + } + for r in all_runs + ], + } + + return report + + +# --------------------------------------------------------------------------- +# Save reports +# --------------------------------------------------------------------------- +def save_reports(results: dict, test_queries: list[dict], config: dict, + num_runs: int, output_dir: str, save_raw: bool, + system_prompt: str): + """Save TXT, HTML, and JSON reports to disk.""" + os.makedirs(output_dir, exist_ok=True) + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M") + file_ts = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Console / TXT summary + summary = generate_summary_table(results, timestamp) + query_detail = generate_query_detail_table(results, test_queries, num_runs) + full_txt = summary + "\n" + query_detail + "\n" + + print("\n" + full_txt) + + txt_path = os.path.join(output_dir, f"instant_playlist_{file_ts}.txt") + with open(txt_path, "w", encoding="utf-8") as f: + f.write(full_txt) + print(f"TXT report saved: {txt_path}") + + # HTML report + html = generate_html_report(results, test_queries, num_runs, timestamp, + config, save_raw, system_prompt) + html_path = os.path.join(output_dir, f"instant_playlist_{file_ts}.html") + with open(html_path, "w", encoding="utf-8") as f: + f.write(html) + print(f"HTML report saved: {html_path}") + + # JSON report + json_data = generate_json_report(results, test_queries, timestamp, config) + json_path = os.path.join(output_dir, f"instant_playlist_{file_ts}.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump(json_data, f, indent=2, default=str) + print(f"JSON report saved: {json_path}") + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- +def main(): + parser = argparse.ArgumentParser( + description="AudioMuse-AI - Instant Playlist 
Tool-Calling Performance Test", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + _default_cfg = "testing_suite/instant_playlist_test_config.yaml" + if not os.path.exists(_default_cfg): + _default_cfg = "testing_suite/instant_playlist_test_config.example.yaml" + parser.add_argument("--config", "-c", type=str, + default=_default_cfg, + help="Path to YAML config file (default: instant_playlist_test_config.yaml)") + parser.add_argument("--runs", "-n", type=int, default=None, + help="Override num_runs_per_model from config") + parser.add_argument("--dry-run", action="store_true", + help="Build prompts and show config, but don't call any APIs") + + args = parser.parse_args() + + # Load config + if not os.path.exists(args.config): + print(f"ERROR: Config file not found: {args.config}") + print(f"Usage: python testing_suite/test_instant_playlist.py --config path/to/config.yaml") + sys.exit(1) + + with open(args.config, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + + # Merge provider defaults into each model entry + apply_defaults(config) + + # Apply CLI overrides + if args.runs is not None: + config["test_config"]["num_runs_per_model"] = args.runs + + num_runs = config["test_config"]["num_runs_per_model"] + clap_enabled = config["test_config"].get("clap_enabled", True) + output_cfg = config.get("output", {}) + output_dir = output_cfg.get("directory", "testing_suite/reports/instant_playlist") + save_raw = output_cfg.get("save_raw_responses", True) + + # Build tools and system prompt for display + tools = get_tool_definitions(clap_enabled) + library_context = config.get("library_context") + system_prompt = build_system_prompt(tools, library_context) + + # Filter queries for count + all_queries = config.get("test_queries", []) + active_queries = [q for q in all_queries + if not (q.get("skip_if_clap_disabled") and not clap_enabled)] + + print("=" * 60) + print(" AudioMuse-AI - Instant Playlist Tool-Calling Test") + print("=" * 60) + + enabled = [m 
for m in config["models"] if m.get("enabled", False)] + print(f" Models: {len(enabled)} enabled") + print(f" Queries: {len(active_queries)}") + print(f" Runs/model: {num_runs}") + print(f" CLAP: {'enabled' if clap_enabled else 'disabled'}") + print(f" Tools: {', '.join(t['name'] for t in tools)}") + print("=" * 60 + "\n") + + # Run tests + results, test_queries = run_tests(config, dry_run=args.dry_run) + + if args.dry_run: + print("\nDry run complete. No API calls were made.") + return + + if not results: + print("No results to report.") + return + + # Generate and save reports + save_reports(results, test_queries, config, num_runs, output_dir, save_raw, + system_prompt) + + +if __name__ == "__main__": + main() diff --git a/testing_suite/test_runner/__init__.py b/testing_suite/test_runner/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/testing_suite/test_runner/existing_tests.py b/testing_suite/test_runner/existing_tests.py new file mode 100644 index 00000000..c8a388fd --- /dev/null +++ b/testing_suite/test_runner/existing_tests.py @@ -0,0 +1,469 @@ +""" +Existing Test Integration Module for AudioMuse-AI Testing Suite. + +Discovers and runs existing unit and integration tests from the codebase +against both instances, collecting and comparing results. 
+ +Integrates: + - tests/unit/ (17 unit test modules via pytest) + - test/test.py (E2E API integration tests) + - test/test_analysis_integration.py (ONNX model integration) + - test/test_clap_analysis_integration.py (CLAP model integration) +""" + +import json +import logging +import os +import subprocess +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from testing_suite.config import ComparisonConfig, InstanceConfig +from testing_suite.utils import ( + ComparisonReport, TestResult, TestStatus, pct_diff +) + +logger = logging.getLogger(__name__) + +PROJECT_ROOT = Path(__file__).resolve().parents[2] + +# Known unit test files +UNIT_TEST_DIR = PROJECT_ROOT / "tests" / "unit" +UNIT_TEST_FILES = [ + "test_ai.py", + "test_analysis.py", + "test_app_analysis.py", + "test_artist_gmm_manager.py", + "test_clap_text_search.py", + "test_clustering.py", + "test_clustering_helper.py", + "test_clustering_postprocessing.py", + "test_commons.py", + "test_mediaserver.py", + "test_memory_cleanup.py", + "test_memory_utils.py", + "test_path_manager.py", + "test_song_alchemy.py", + "test_sonic_fingerprint_manager.py", + "test_string_sanitization.py", + "test_voyager_manager.py", +] + +# Integration test files +INTEGRATION_TEST_DIR = PROJECT_ROOT / "test" +INTEGRATION_TEST_FILES = [ + "test_analysis_integration.py", + "test_clap_analysis_integration.py", +] + +# E2E API test (requires a running instance) +E2E_TEST_FILE = PROJECT_ROOT / "test" / "test.py" + +# Individual E2E test names (from test/test.py) +E2E_TEST_NAMES = [ + "test_analysis_smoke_flow", + "test_instant_playlist_functionality", + "test_sonic_fingerprint_and_playlist", + "test_song_alchemy_and_playlist", + "test_map_visualization", + "test_annoy_similarity_and_playlist", + "test_song_path_and_playlist", + "test_clustering_smoke_flow", +] + + +def _parse_pytest_json(json_path: str) -> dict: + """Parse pytest JSON report.""" + try: + with open(json_path, 'r') as f: + return 
json.load(f) + except Exception as e: + logger.warning(f"Could not parse pytest JSON report: {e}") + return {} + + +def _run_pytest(test_path: str, extra_args: list = None, + env_override: dict = None, timeout: int = 600, + json_report: bool = True) -> Tuple[dict, str, int]: + """ + Run pytest and capture results. + Returns (parsed_json_result, stdout, returncode). + """ + cmd = ["python", "-m", "pytest", "-v", "--tb=short"] + + json_path = None + if json_report: + json_path = f"/tmp/pytest_report_{int(time.time() * 1000)}.json" + cmd += [f"--json-report", f"--json-report-file={json_path}"] + + cmd.append(str(test_path)) + + if extra_args: + cmd.extend(extra_args) + + env = os.environ.copy() + if env_override: + env.update(env_override) + + # Ensure project root is in PYTHONPATH + python_path = env.get("PYTHONPATH", "") + env["PYTHONPATH"] = f"{PROJECT_ROOT}:{python_path}" if python_path else str(PROJECT_ROOT) + + try: + proc = subprocess.run( + cmd, capture_output=True, text=True, timeout=timeout, + cwd=str(PROJECT_ROOT), env=env, + ) + stdout = proc.stdout + proc.stderr + returncode = proc.returncode + + # Parse JSON report if available + result = {} + if json_path and os.path.exists(json_path): + result = _parse_pytest_json(json_path) + os.unlink(json_path) + + return result, stdout, returncode + + except subprocess.TimeoutExpired: + return {}, f"pytest timed out after {timeout}s", -1 + except Exception as e: + return {}, str(e), -2 + + +def _parse_stdout_results(stdout: str) -> dict: + """ + Parse pytest stdout for test results when JSON report is not available. + Returns dict with passed, failed, error, skipped counts and test names. 
+ """ + import re + + results = { + "passed": 0, + "failed": 0, + "errors": 0, + "skipped": 0, + "tests": [], + } + + # Parse individual test results: PASSED, FAILED, ERROR, SKIPPED + for line in stdout.split('\n'): + line = line.strip() + if " PASSED" in line: + results["passed"] += 1 + results["tests"].append({"name": line.split(" PASSED")[0].strip(), "status": "passed"}) + elif " FAILED" in line: + results["failed"] += 1 + results["tests"].append({"name": line.split(" FAILED")[0].strip(), "status": "failed"}) + elif " ERROR" in line: + results["errors"] += 1 + results["tests"].append({"name": line.split(" ERROR")[0].strip(), "status": "error"}) + elif " SKIPPED" in line: + results["skipped"] += 1 + results["tests"].append({"name": line.split(" SKIPPED")[0].strip(), "status": "skipped"}) + + # Try to parse summary line like "5 passed, 1 failed, 2 skipped" + summary_match = re.search( + r'(\d+)\s+passed.*?(?:(\d+)\s+failed)?.*?(?:(\d+)\s+skipped)?.*?(?:(\d+)\s+error)?', + stdout + ) + if summary_match: + if summary_match.group(1): + results["passed"] = max(results["passed"], int(summary_match.group(1))) + if summary_match.group(2): + results["failed"] = max(results["failed"], int(summary_match.group(2))) + if summary_match.group(3): + results["skipped"] = max(results["skipped"], int(summary_match.group(3))) + if summary_match.group(4): + results["errors"] = max(results["errors"], int(summary_match.group(4))) + + return results + + +class ExistingTestRunner: + """Runs existing tests and integrates results into the comparison report.""" + + def __init__(self, config: ComparisonConfig): + self.config = config + + def run_all(self, report: ComparisonReport): + """Run all existing test suites.""" + logger.info("Starting existing test integration...") + + if self.config.run_existing_unit_tests: + self._run_unit_tests(report) + + if self.config.run_existing_integration_tests: + self._run_integration_tests(report) + self._run_e2e_tests(report) + + logger.info("Existing 
test integration complete.") + + # ------------------------------------------------------------------ + # Unit tests + # ------------------------------------------------------------------ + + def _run_unit_tests(self, report: ComparisonReport): + """Run unit tests from tests/unit/ directory.""" + t0 = time.time() + + if not UNIT_TEST_DIR.exists(): + report.add_result(TestResult( + category="existing_tests", + name="Unit Tests: directory check", + status=TestStatus.ERROR, + message=f"Unit test directory not found: {UNIT_TEST_DIR}", + duration_seconds=time.time() - t0, + )) + return + + # Run entire unit test suite + logger.info("Running unit test suite...") + result, stdout, rc = _run_pytest( + str(UNIT_TEST_DIR), + extra_args=["-x", "--timeout=120"], + timeout=600, + json_report=True, + ) + + # Parse results + if result and "summary" in result: + summary = result["summary"] + passed = summary.get("passed", 0) + failed = summary.get("failed", 0) + errors = summary.get("error", 0) + skipped = summary.get("skipped", 0) + total = summary.get("total", passed + failed + errors + skipped) + else: + # Fallback to stdout parsing + parsed = _parse_stdout_results(stdout) + passed = parsed["passed"] + failed = parsed["failed"] + errors = parsed["errors"] + skipped = parsed["skipped"] + total = passed + failed + errors + skipped + + if failed == 0 and errors == 0: + status = TestStatus.PASS + elif errors > 0: + status = TestStatus.ERROR + else: + status = TestStatus.FAIL + + report.add_result(TestResult( + category="existing_tests", + name="Unit Tests: overall", + status=status, + message=( + f"Total={total}, Passed={passed}, Failed={failed}, " + f"Errors={errors}, Skipped={skipped} | " + f"Return code: {rc}" + ), + instance_a_value={ + "total": total, "passed": passed, "failed": failed, + "errors": errors, "skipped": skipped, + }, + duration_seconds=time.time() - t0, + details={"returncode": rc, "stdout_tail": stdout[-2000:] if stdout else ""}, + )) + + # Report individual 
test file results + for test_file in UNIT_TEST_FILES: + test_path = UNIT_TEST_DIR / test_file + if not test_path.exists(): + report.add_result(TestResult( + category="existing_tests", + name=f"Unit: {test_file}", + status=TestStatus.SKIP, + message=f"File not found: {test_path}", + )) + continue + + tf0 = time.time() + file_result, file_stdout, file_rc = _run_pytest( + str(test_path), + extra_args=["--timeout=60"], + timeout=120, + json_report=False, + ) + + parsed = _parse_stdout_results(file_stdout) + + if file_rc == 0: + file_status = TestStatus.PASS + elif file_rc == 1: + file_status = TestStatus.FAIL + elif file_rc == 5: + file_status = TestStatus.SKIP # No tests collected + else: + file_status = TestStatus.ERROR + + report.add_result(TestResult( + category="existing_tests", + name=f"Unit: {test_file}", + status=file_status, + message=( + f"Passed={parsed['passed']}, Failed={parsed['failed']}, " + f"Errors={parsed['errors']}, Skipped={parsed['skipped']}" + ), + instance_a_value=parsed, + duration_seconds=time.time() - tf0, + details={"returncode": file_rc}, + )) + + # ------------------------------------------------------------------ + # Integration tests + # ------------------------------------------------------------------ + + def _run_integration_tests(self, report: ComparisonReport): + """Run integration tests from test/ directory.""" + for test_file in INTEGRATION_TEST_FILES: + test_path = INTEGRATION_TEST_DIR / test_file + if not test_path.exists(): + report.add_result(TestResult( + category="existing_tests", + name=f"Integration: {test_file}", + status=TestStatus.SKIP, + message=f"File not found: {test_path}", + )) + continue + + t0 = time.time() + result, stdout, rc = _run_pytest( + str(test_path), + extra_args=["-m", "integration", "--timeout=300"], + timeout=600, + json_report=False, + ) + + parsed = _parse_stdout_results(stdout) + + if rc == 0: + status = TestStatus.PASS + elif rc == 5: + status = TestStatus.SKIP + elif rc == 1: + status = 
TestStatus.FAIL + else: + status = TestStatus.ERROR + + report.add_result(TestResult( + category="existing_tests", + name=f"Integration: {test_file}", + status=status, + message=( + f"Passed={parsed['passed']}, Failed={parsed['failed']}, " + f"Errors={parsed['errors']}, Skipped={parsed['skipped']}" + ), + instance_a_value=parsed, + duration_seconds=time.time() - t0, + details={"returncode": rc, "stdout_tail": stdout[-1000:] if stdout else ""}, + )) + + # ------------------------------------------------------------------ + # E2E API tests (against both instances) + # ------------------------------------------------------------------ + + def _run_e2e_tests(self, report: ComparisonReport): + """Run E2E API tests from test/test.py against both instances.""" + if not E2E_TEST_FILE.exists(): + report.add_result(TestResult( + category="existing_tests", + name="E2E Tests: file check", + status=TestStatus.SKIP, + message=f"E2E test file not found: {E2E_TEST_FILE}", + )) + return + + instances = [] + if self.config.instance_a.api_url: + instances.append(("A", self.config.instance_a)) + if self.config.instance_b.api_url: + instances.append(("B", self.config.instance_b)) + + for label, instance in instances: + # Run non-destructive E2E tests (skip analysis and clustering which modify state) + safe_tests = [ + "test_map_visualization", + "test_annoy_similarity_and_playlist", + ] + + for test_name in safe_tests: + t0 = time.time() + result, stdout, rc = _run_pytest( + str(E2E_TEST_FILE), + extra_args=["-k", test_name, "--timeout=300"], + env_override={"BASE_URL": instance.api_url}, + timeout=600, + json_report=False, + ) + + parsed = _parse_stdout_results(stdout) + + if rc == 0: + status = TestStatus.PASS + elif rc == 5: + status = TestStatus.SKIP + elif rc == 1: + status = TestStatus.FAIL + else: + status = TestStatus.ERROR + + report.add_result(TestResult( + category="existing_tests", + name=f"E2E ({label}): {test_name}", + status=status, + message=( + f"Instance {label} 
({instance.api_url}): " + f"Passed={parsed['passed']}, Failed={parsed['failed']}" + ), + instance_a_value=parsed if label == "A" else None, + instance_b_value=parsed if label == "B" else None, + duration_seconds=time.time() - t0, + details={"returncode": rc, "instance": label, + "api_url": instance.api_url, + "stdout_tail": stdout[-500:] if stdout else ""}, + )) + + # ------------------------------------------------------------------ + # Discovery: list all available tests + # ------------------------------------------------------------------ + + @staticmethod + def discover_tests() -> dict: + """Discover all available tests and return a structured summary.""" + discovery = { + "unit_tests": [], + "integration_tests": [], + "e2e_tests": [], + } + + # Unit tests + if UNIT_TEST_DIR.exists(): + for f in sorted(UNIT_TEST_DIR.glob("test_*.py")): + discovery["unit_tests"].append({ + "file": str(f.relative_to(PROJECT_ROOT)), + "name": f.stem, + "exists": True, + }) + + # Integration tests + for f in INTEGRATION_TEST_FILES: + path = INTEGRATION_TEST_DIR / f + discovery["integration_tests"].append({ + "file": str(path.relative_to(PROJECT_ROOT)), + "name": Path(f).stem, + "exists": path.exists(), + }) + + # E2E tests + if E2E_TEST_FILE.exists(): + for name in E2E_TEST_NAMES: + discovery["e2e_tests"].append({ + "file": str(E2E_TEST_FILE.relative_to(PROJECT_ROOT)), + "name": name, + "exists": True, + }) + + return discovery diff --git a/testing_suite/utils.py b/testing_suite/utils.py new file mode 100644 index 00000000..e296f471 --- /dev/null +++ b/testing_suite/utils.py @@ -0,0 +1,434 @@ +""" +Shared utilities for the AudioMuse-AI Testing & Comparison Suite. + +Provides HTTP helpers, database connectors, Docker log fetchers, +timing utilities, and result aggregation primitives. 
+""" + +import json +import logging +import subprocess +import time +from dataclasses import dataclass, field, asdict +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +import requests + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Result types +# --------------------------------------------------------------------------- + +class TestStatus(str, Enum): + PASS = "PASS" + FAIL = "FAIL" + WARN = "WARN" + SKIP = "SKIP" + ERROR = "ERROR" + + +@dataclass +class TestResult: + """A single test result entry.""" + category: str # e.g. "api", "database", "docker", "performance" + name: str # descriptive test name + status: TestStatus + message: str = "" + instance_a_value: Any = None + instance_b_value: Any = None + diff: Any = None # computed difference + duration_seconds: float = 0.0 + details: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict: + d = asdict(self) + d['status'] = self.status.value + # Ensure JSON-serializable + for k in ('instance_a_value', 'instance_b_value', 'diff', 'details'): + try: + json.dumps(d[k]) + except (TypeError, ValueError): + d[k] = str(d[k]) + return d + + +@dataclass +class CategorySummary: + """Summary for a test category.""" + category: str + total: int = 0 + passed: int = 0 + failed: int = 0 + warned: int = 0 + skipped: int = 0 + errors: int = 0 + results: List[TestResult] = field(default_factory=list) + + def add(self, result: TestResult): + self.results.append(result) + self.total += 1 + if result.status == TestStatus.PASS: + self.passed += 1 + elif result.status == TestStatus.FAIL: + self.failed += 1 + elif result.status == TestStatus.WARN: + self.warned += 1 + elif result.status == TestStatus.SKIP: + self.skipped += 1 + elif result.status == TestStatus.ERROR: + self.errors += 1 + + +@dataclass +class ComparisonReport: + """Full comparison report across all categories.""" + 
timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat()) + instance_a_name: str = "" + instance_b_name: str = "" + instance_a_branch: str = "" + instance_b_branch: str = "" + categories: Dict[str, CategorySummary] = field(default_factory=dict) + config_snapshot: Dict[str, Any] = field(default_factory=dict) + + def add_result(self, result: TestResult): + cat = result.category + if cat not in self.categories: + self.categories[cat] = CategorySummary(category=cat) + self.categories[cat].add(result) + + @property + def total_tests(self) -> int: + return sum(c.total for c in self.categories.values()) + + @property + def total_passed(self) -> int: + return sum(c.passed for c in self.categories.values()) + + @property + def total_failed(self) -> int: + return sum(c.failed for c in self.categories.values()) + + @property + def total_errors(self) -> int: + return sum(c.errors for c in self.categories.values()) + + @property + def overall_status(self) -> TestStatus: + if self.total_failed > 0 or self.total_errors > 0: + return TestStatus.FAIL + return TestStatus.PASS + + def to_dict(self) -> dict: + return { + "timestamp": self.timestamp, + "instance_a": {"name": self.instance_a_name, "branch": self.instance_a_branch}, + "instance_b": {"name": self.instance_b_name, "branch": self.instance_b_branch}, + "overall_status": self.overall_status.value, + "summary": { + "total": self.total_tests, + "passed": self.total_passed, + "failed": self.total_failed, + "errors": self.total_errors, + }, + "categories": { + name: { + "total": cat.total, + "passed": cat.passed, + "failed": cat.failed, + "warned": cat.warned, + "skipped": cat.skipped, + "errors": cat.errors, + "results": [r.to_dict() for r in cat.results], + } + for name, cat in self.categories.items() + }, + "config": self.config_snapshot, + } + + +# --------------------------------------------------------------------------- +# HTTP Helpers +# 
--------------------------------------------------------------------------- + +def http_get(url: str, params: dict = None, timeout: int = 120, + retries: int = 3, retry_delay: float = 2.0) -> requests.Response: + """HTTP GET with retries on connection errors.""" + last_exc = None + for attempt in range(1, retries + 1): + try: + resp = requests.get(url, params=params, timeout=timeout) + return resp + except requests.RequestException as e: + last_exc = e + if attempt < retries: + time.sleep(retry_delay) + logger.debug(f"Retry {attempt}/{retries} for GET {url}: {e}") + raise last_exc + + +def http_post(url: str, json_data: dict = None, timeout: int = 120, + retries: int = 3, retry_delay: float = 2.0) -> requests.Response: + """HTTP POST with retries on connection errors.""" + last_exc = None + for attempt in range(1, retries + 1): + try: + resp = requests.post(url, json=json_data, timeout=timeout) + return resp + except requests.RequestException as e: + last_exc = e + if attempt < retries: + time.sleep(retry_delay) + logger.debug(f"Retry {attempt}/{retries} for POST {url}: {e}") + raise last_exc + + +def timed_request(method: str, url: str, **kwargs) -> Tuple[requests.Response, float]: + """Execute an HTTP request and return (response, elapsed_seconds).""" + start = time.perf_counter() + if method.upper() == "GET": + resp = http_get(url, **kwargs) + else: + resp = http_post(url, **kwargs) + elapsed = time.perf_counter() - start + return resp, elapsed + + +def wait_for_task_success(base_url: str, task_id: str, timeout: int = 1200, + retries: int = 3, retry_delay: float = 2.0) -> dict: + """Poll active_tasks until task completes, then verify success via last_task.""" + start = time.time() + while time.time() - start < timeout: + act_resp = http_get(f'{base_url}/api/active_tasks', retries=retries, + retry_delay=retry_delay) + act_resp.raise_for_status() + active = act_resp.json() + + if active and active.get('task_id') == task_id: + time.sleep(2) + continue + + last_resp 
= http_get(f'{base_url}/api/last_task', retries=retries, + retry_delay=retry_delay) + last_resp.raise_for_status() + final = last_resp.json() + final_id = final.get('task_id') + final_state = (final.get('status') or final.get('state') or '').upper() + + if final_id == task_id: + return final + # Task might have been superseded; keep polling briefly + time.sleep(2) + + return {"status": "TIMEOUT", "task_id": task_id} + + +# --------------------------------------------------------------------------- +# Database Helpers +# --------------------------------------------------------------------------- + +def get_pg_connection(dsn: str): + """Create a psycopg2 connection from DSN.""" + import psycopg2 + return psycopg2.connect(dsn, connect_timeout=30, + options='-c statement_timeout=120000') + + +def pg_query(dsn: str, sql: str, params: tuple = None) -> List[tuple]: + """Execute a read-only query and return all rows.""" + import psycopg2 + conn = psycopg2.connect(dsn, connect_timeout=30, + options='-c statement_timeout=120000') + try: + with conn.cursor() as cur: + cur.execute(sql, params) + return cur.fetchall() + finally: + conn.close() + + +def pg_query_dict(dsn: str, sql: str, params: tuple = None) -> List[dict]: + """Execute a query and return rows as dicts.""" + import psycopg2 + from psycopg2.extras import RealDictCursor + conn = psycopg2.connect(dsn, connect_timeout=30, + options='-c statement_timeout=120000') + try: + with conn.cursor(cursor_factory=RealDictCursor) as cur: + cur.execute(sql, params) + return [dict(row) for row in cur.fetchall()] + finally: + conn.close() + + +def pg_scalar(dsn: str, sql: str, params: tuple = None): + """Execute a query that returns a single scalar value.""" + rows = pg_query(dsn, sql, params) + if rows and rows[0]: + return rows[0][0] + return None + + +# --------------------------------------------------------------------------- +# Docker Helpers +# --------------------------------------------------------------------------- + 
+def docker_exec(container: str, command: str, ssh_host: str = "",
+                ssh_user: str = "", ssh_key: str = "",
+                ssh_port: int = 22, timeout: int = 30) -> Tuple[str, str, int]:
+    """
+    Run a command inside a Docker container (locally or via SSH).
+    Returns (stdout, stderr, returncode).
+    """
+    if ssh_host:
+        ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no",
+                   "-p", str(ssh_port)]
+        if ssh_key:
+            ssh_cmd += ["-i", ssh_key]
+        ssh_cmd.append(f"{ssh_user}@{ssh_host}" if ssh_user else ssh_host)
+        ssh_cmd.append(f"docker exec {container} {command}")
+        full_cmd = ssh_cmd
+    else:
+        # shlex.split preserves quoted arguments (e.g. sh -c "echo hi");
+        # a naive str.split() would break such commands into wrong tokens.
+        import shlex
+        full_cmd = ["docker", "exec", container] + shlex.split(command)
+
+    try:
+        proc = subprocess.run(full_cmd, capture_output=True, text=True, timeout=timeout)
+        return proc.stdout, proc.stderr, proc.returncode
+    except subprocess.TimeoutExpired:
+        return "", "Command timed out", -1
+    except FileNotFoundError:
+        return "", "docker or ssh command not found", -2
+
+
+def docker_logs(container: str, tail: int = 500, since: str = "",
+                ssh_host: str = "", ssh_user: str = "", ssh_key: str = "",
+                ssh_port: int = 22, timeout: int = 30) -> Tuple[str, str, int]:
+    """
+    Fetch Docker container logs (locally or via SSH).
+    Returns (stdout, stderr, returncode). 
+ """ + cmd_parts = ["docker", "logs", f"--tail={tail}"] + if since: + cmd_parts += [f"--since={since}"] + cmd_parts.append(container) + + if ssh_host: + ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", + "-p", str(ssh_port)] + if ssh_key: + ssh_cmd += ["-i", ssh_key] + ssh_cmd.append(f"{ssh_user}@{ssh_host}" if ssh_user else ssh_host) + ssh_cmd.append(" ".join(cmd_parts)) + full_cmd = ssh_cmd + else: + full_cmd = cmd_parts + + try: + proc = subprocess.run(full_cmd, capture_output=True, text=True, timeout=timeout) + return proc.stdout, proc.stderr, proc.returncode + except subprocess.TimeoutExpired: + return "", "Logs fetch timed out", -1 + except FileNotFoundError: + return "", "docker or ssh command not found", -2 + + +def docker_inspect(container: str, ssh_host: str = "", ssh_user: str = "", + ssh_key: str = "", ssh_port: int = 22, + timeout: int = 15) -> Optional[dict]: + """ + Run docker inspect on a container and return the parsed JSON. + Returns None on failure. + """ + cmd_parts = ["docker", "inspect", container] + + if ssh_host: + ssh_cmd = ["ssh", "-o", "StrictHostKeyChecking=no", + "-p", str(ssh_port)] + if ssh_key: + ssh_cmd += ["-i", ssh_key] + ssh_cmd.append(f"{ssh_user}@{ssh_host}" if ssh_user else ssh_host) + ssh_cmd.append(" ".join(cmd_parts)) + full_cmd = ssh_cmd + else: + full_cmd = cmd_parts + + try: + proc = subprocess.run(full_cmd, capture_output=True, text=True, timeout=timeout) + if proc.returncode == 0: + data = json.loads(proc.stdout) + return data[0] if isinstance(data, list) and data else data + except Exception as e: + logger.debug(f"docker inspect failed for {container}: {e}") + return None + + +# --------------------------------------------------------------------------- +# Comparison Helpers +# --------------------------------------------------------------------------- + +def compare_values(a, b, tolerance_pct: float = 0.0) -> Tuple[bool, str]: + """ + Compare two values. For numeric types, allow a percentage tolerance. 
+ Returns (is_equal, description). + """ + if a is None and b is None: + return True, "Both None" + if a is None or b is None: + return False, f"One is None: A={a}, B={b}" + + if isinstance(a, (int, float)) and isinstance(b, (int, float)): + if a == 0 and b == 0: + return True, "Both zero" + if a == 0 or b == 0: + return False, f"A={a}, B={b}" + pct_diff = abs(a - b) / max(abs(a), abs(b)) * 100 + if pct_diff <= tolerance_pct: + return True, f"Within tolerance ({pct_diff:.2f}% <= {tolerance_pct}%)" + return False, f"Difference {pct_diff:.2f}% exceeds tolerance {tolerance_pct}%" + + if isinstance(a, str) and isinstance(b, str): + if a == b: + return True, "Strings match" + return False, f"Strings differ: '{a[:100]}' vs '{b[:100]}'" + + if isinstance(a, dict) and isinstance(b, dict): + keys_a = set(a.keys()) + keys_b = set(b.keys()) + if keys_a != keys_b: + missing_in_b = keys_a - keys_b + missing_in_a = keys_b - keys_a + return False, f"Key mismatch: missing_in_B={missing_in_b}, missing_in_A={missing_in_a}" + return True, "Dict keys match" + + if isinstance(a, list) and isinstance(b, list): + if len(a) == len(b): + return True, f"Lists same length ({len(a)})" + return False, f"List length differs: {len(a)} vs {len(b)}" + + # Fallback + if a == b: + return True, "Values equal" + return False, f"Values differ: {a} vs {b}" + + +def pct_diff(a: float, b: float) -> float: + """Calculate percentage difference between two values.""" + if a == 0 and b == 0: + return 0.0 + if a == 0 or b == 0: + return 100.0 + return abs(a - b) / max(abs(a), abs(b)) * 100 + + +def format_duration(seconds: float) -> str: + """Format seconds into a human-readable string.""" + if seconds < 1: + return f"{seconds*1000:.1f}ms" + if seconds < 60: + return f"{seconds:.2f}s" + minutes = int(seconds // 60) + secs = seconds % 60 + return f"{minutes}m {secs:.1f}s" diff --git a/tests/unit/test_app_chat.py b/tests/unit/test_app_chat.py index 7476f077..3fb898d3 100644 --- a/tests/unit/test_app_chat.py +++ 
b/tests/unit/test_app_chat.py @@ -1,27 +1,551 @@ -""" -Tests for app_chat.py::chat_playlist_api() — Instant Playlist pipeline. - -Tests verify: -- Pre-validation (song_similarity empty title/artist rejection, search_database no-filter rejection) -- Artist diversity enforcement (MAX_SONGS_PER_ARTIST_PLAYLIST cap, backfill) +"""Unit tests for app_chat.py instant playlist pipeline + +Tests cover the agentic playlist workflow: +- Artist diversity enforcement (max songs per artist, backfill) +- Proportional sampling from tool calls +- Pre-execution validation (empty song_similarity, filterless search_database) +- Ollama JSON extraction edge cases +- Iteration deduplication (song_ids_seen) +- Stopping conditions (target reached, no new songs, AI error) +- API key validation for cloud providers - Iteration message content (iteration 0 minimal, iteration > 0 rich feedback) """ +import json import pytest -from unittest.mock import Mock, patch, MagicMock, call +from unittest.mock import Mock, MagicMock, patch, call from tests.conftest import make_dict_row, make_mock_connection +flask = pytest.importorskip('flask', reason='Flask not installed') +from flask import Flask + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def app(): + """Create a Flask app with the chat blueprint registered.""" + with patch('config.OLLAMA_SERVER_URL', 'http://localhost:11434'), \ + patch('config.OLLAMA_MODEL_NAME', 'test-model'), \ + patch('config.OPENAI_SERVER_URL', 'http://localhost'), \ + patch('config.OPENAI_MODEL_NAME', 'gpt-4'), \ + patch('config.OPENAI_API_KEY', ''), \ + patch('config.GEMINI_MODEL_NAME', 'gemini-pro'), \ + patch('config.GEMINI_API_KEY', ''), \ + patch('config.MISTRAL_MODEL_NAME', 'mistral-7b'), \ + patch('config.MISTRAL_API_KEY', ''), \ + patch('config.AI_MODEL_PROVIDER', 'OLLAMA'): + from app_chat import chat_bp + flask_app = 
Flask(__name__) + flask_app.register_blueprint(chat_bp) + flask_app.config['TESTING'] = True + yield flask_app + + +@pytest.fixture +def client(app): + return app.test_client() + + +def _song(item_id, title="Song", artist="Artist"): + """Helper to create a song dict.""" + return {'item_id': item_id, 'title': title, 'artist': artist} + + +# --------------------------------------------------------------------------- +# Artist Diversity Enforcement +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestArtistDiversity: + """Test artist diversity enforcement (Phase 3B).""" + + def test_under_limit_keeps_all(self): + """Songs under the per-artist limit are all kept.""" + songs = [_song(f'id{i}', artist='ArtistA') for i in range(3)] + max_per = 5 + artist_counts = {} + diverse = [] + overflow = [] + for s in songs: + a = s.get('artist', 'Unknown') + artist_counts[a] = artist_counts.get(a, 0) + 1 + if artist_counts[a] <= max_per: + diverse.append(s) + else: + overflow.append(s) + assert len(diverse) == 3 + assert len(overflow) == 0 + + def test_over_limit_trims_excess(self): + """Songs over the per-artist limit go to overflow.""" + songs = [_song(f'id{i}', artist='ArtistA') for i in range(8)] + max_per = 5 + artist_counts = {} + diverse = [] + overflow = [] + for s in songs: + a = s.get('artist', 'Unknown') + artist_counts[a] = artist_counts.get(a, 0) + 1 + if artist_counts[a] <= max_per: + diverse.append(s) + else: + overflow.append(s) + assert len(diverse) == 5 + assert len(overflow) == 3 + + def test_multiple_artists_independent_limits(self): + """Each artist gets an independent limit.""" + songs = [_song(f'a{i}', artist='A') for i in range(6)] + songs += [_song(f'b{i}', artist='B') for i in range(4)] + max_per = 5 + artist_counts = {} + diverse = [] + for s in songs: + a = s.get('artist', 'Unknown') + artist_counts[a] = artist_counts.get(a, 0) + 1 + if artist_counts[a] <= max_per: + diverse.append(s) + assert 
sum(1 for s in diverse if s['artist'] == 'A') == 5 + assert sum(1 for s in diverse if s['artist'] == 'B') == 4 + + def test_backfill_from_overflow(self): + """Overflow songs backfill from least-represented artists.""" + # 10 songs from ArtistA, 2 from ArtistB, limit=3, target=8 + songs = [_song(f'a{i}', artist='ArtistA') for i in range(10)] + songs += [_song(f'b{i}', artist='ArtistB') for i in range(2)] + max_per = 3 + target = 8 + artist_counts = {} + diverse = [] + overflow = [] + for s in songs: + a = s.get('artist', 'Unknown') + artist_counts[a] = artist_counts.get(a, 0) + 1 + if artist_counts[a] <= max_per: + diverse.append(s) + else: + overflow.append(s) + # diverse has 3 A + 2 B = 5 songs, need 3 more from overflow + if len(diverse) < target and overflow: + diverse_artist_counts = {} + for s in diverse: + a = s.get('artist', 'Unknown') + diverse_artist_counts[a] = diverse_artist_counts.get(a, 0) + 1 + overflow.sort(key=lambda s: diverse_artist_counts.get(s.get('artist', ''), 0)) + backfill_needed = target - len(diverse) + diverse.extend(overflow[:backfill_needed]) + assert len(diverse) == 8 + + def test_default_max_per_artist_is_5(self): + """Config default MAX_SONGS_PER_ARTIST_PLAYLIST is 5.""" + from config import MAX_SONGS_PER_ARTIST_PLAYLIST + assert MAX_SONGS_PER_ARTIST_PLAYLIST == 5 + + +# --------------------------------------------------------------------------- +# Proportional Sampling +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestProportionalSampling: + """Test proportional sampling when more songs than target.""" + + def test_under_target_uses_all(self): + """When total < target, all songs are kept.""" + all_songs = [_song(f'id{i}') for i in range(50)] + target = 100 + assert len(all_songs) <= target + + def test_over_target_samples_proportionally(self): + """When total > target, songs are sampled proportionally by source.""" + song_sources = {} + songs_by_call = {0: [], 1: []} + for 
i in range(80): + s = _song(f'id{i}') + songs_by_call[0].append(s) + song_sources[f'id{i}'] = 0 + for i in range(80, 120): + s = _song(f'id{i}') + songs_by_call[1].append(s) + song_sources[f'id{i}'] = 1 + total = 120 + target = 100 + final = [] + for call_index, tool_songs in songs_by_call.items(): + proportion = len(tool_songs) / total + allocated = int(proportion * target) + if allocated == 0 and len(tool_songs) > 0: + allocated = 1 + final.extend(tool_songs[:allocated]) + # Call 0: 80/120*100=66, Call 1: 40/120*100=33 => 99 total + assert len(final) <= target + + def test_each_call_gets_at_least_one(self): + """Even a tool call with 1 song gets at least 1 in the final list.""" + songs_by_call = {0: [_song('majority')]*99, 1: [_song('tiny')]} + total = 100 + target = 50 + final = [] + for call_index, tool_songs in songs_by_call.items(): + proportion = len(tool_songs) / total + allocated = int(proportion * target) + if allocated == 0 and len(tool_songs) > 0: + allocated = 1 + final.extend(tool_songs[:allocated]) + # Check that call 1's song is included + assert any(s['item_id'] == 'tiny' for s in final) + + def test_rounding_backfill(self): + """Remaining songs are backfilled if proportional rounding falls short.""" + all_songs = [_song(f'id{i}') for i in range(120)] + song_sources = {f'id{i}': i % 3 for i in range(120)} + target = 100 + songs_by_call = {} + for s in all_songs: + ci = song_sources[s['item_id']] + songs_by_call.setdefault(ci, []).append(s) + final = [] + for ci, ts in songs_by_call.items(): + proportion = len(ts) / len(all_songs) + allocated = int(proportion * target) + if allocated == 0 and len(ts) > 0: + allocated = 1 + final.extend(ts[:allocated]) + if len(final) < target: + remaining = [s for s in all_songs if s not in final] + final.extend(remaining[:target - len(final)]) + assert len(final) == target + + +# --------------------------------------------------------------------------- +# Pre-Execution Validation +# 
--------------------------------------------------------------------------- + +@pytest.mark.unit +class TestPreExecutionValidation: + """Test pre-execution validation of tool calls.""" + + def test_song_similarity_empty_title_rejected(self): + """song_similarity with empty title is rejected.""" + tc = {'name': 'song_similarity', 'arguments': {'song_title': '', 'song_artist': 'Artist'}} + ta = tc['arguments'] + assert not ta.get('song_title', '').strip() + + def test_song_similarity_empty_artist_rejected(self): + """song_similarity with empty artist is rejected.""" + tc = {'name': 'song_similarity', 'arguments': {'song_title': 'Title', 'song_artist': ''}} + ta = tc['arguments'] + assert not ta.get('song_artist', '').strip() + + def test_song_similarity_whitespace_only_rejected(self): + """song_similarity with whitespace-only values is rejected.""" + tc = {'name': 'song_similarity', 'arguments': {'song_title': ' ', 'song_artist': ' '}} + ta = tc['arguments'] + assert not ta.get('song_title', '').strip() + assert not ta.get('song_artist', '').strip() + + def test_song_similarity_valid_passes(self): + """song_similarity with valid title and artist passes.""" + tc = {'name': 'song_similarity', 'arguments': {'song_title': 'Bohemian Rhapsody', 'song_artist': 'Queen'}} + ta = tc['arguments'] + assert ta.get('song_title', '').strip() + assert ta.get('song_artist', '').strip() + + def test_search_database_no_filters_rejected(self): + """search_database with no filters is rejected.""" + tc = {'name': 'search_database', 'arguments': {'get_songs': 50}} + ta = tc['arguments'] + filter_keys = ['genres', 'moods', 'tempo_min', 'tempo_max', 'energy_min', 'energy_max', + 'key', 'scale', 'year_min', 'year_max', 'min_rating'] + has_filter = any(ta.get(k) for k in filter_keys) + assert not has_filter + + def test_search_database_with_genres_passes(self): + """search_database with genres filter passes.""" + tc = {'name': 'search_database', 'arguments': {'genres': ['rock'], 'get_songs': 
50}} + ta = tc['arguments'] + filter_keys = ['genres', 'moods', 'tempo_min', 'tempo_max', 'energy_min', 'energy_max', + 'key', 'scale', 'year_min', 'year_max', 'min_rating'] + has_filter = any(ta.get(k) for k in filter_keys) + assert has_filter + + def test_search_database_with_energy_passes(self): + """search_database with energy filter passes.""" + tc = {'name': 'search_database', 'arguments': {'energy_min': 0.5, 'get_songs': 50}} + ta = tc['arguments'] + filter_keys = ['genres', 'moods', 'tempo_min', 'tempo_max', 'energy_min', 'energy_max', + 'key', 'scale', 'year_min', 'year_max', 'min_rating'] + has_filter = any(ta.get(k) for k in filter_keys) + assert has_filter + + def test_search_database_with_year_filter_passes(self): + """search_database with year_min filter passes.""" + tc = {'name': 'search_database', 'arguments': {'year_min': 2000, 'get_songs': 50}} + ta = tc['arguments'] + filter_keys = ['genres', 'moods', 'tempo_min', 'tempo_max', 'energy_min', 'energy_max', + 'key', 'scale', 'year_min', 'year_max', 'min_rating'] + has_filter = any(ta.get(k) for k in filter_keys) + assert has_filter + + +# --------------------------------------------------------------------------- +# Iteration Deduplication +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestIterationDeduplication: + """Test song deduplication across iterations.""" + + def test_duplicate_ids_filtered(self): + """Duplicate item_ids are not added twice.""" + song_ids_seen = set() + all_songs = [] + batch1 = [_song('id1'), _song('id2'), _song('id3')] + batch2 = [_song('id2'), _song('id3'), _song('id4')] # id2, id3 are dupes + for s in batch1: + if s['item_id'] not in song_ids_seen: + all_songs.append(s) + song_ids_seen.add(s['item_id']) + for s in batch2: + if s['item_id'] not in song_ids_seen: + all_songs.append(s) + song_ids_seen.add(s['item_id']) + assert len(all_songs) == 4 + assert song_ids_seen == {'id1', 'id2', 'id3', 'id4'} + + def 
test_all_duplicates_adds_zero(self): + """When all songs are duplicates, zero new songs are added.""" + song_ids_seen = {'id1', 'id2'} + all_songs = [_song('id1'), _song('id2')] + batch = [_song('id1'), _song('id2')] + new_count = 0 + for s in batch: + if s['item_id'] not in song_ids_seen: + all_songs.append(s) + song_ids_seen.add(s['item_id']) + new_count += 1 + assert new_count == 0 + + +# --------------------------------------------------------------------------- +# Stopping Conditions +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestStoppingConditions: + """Test the agentic loop stopping conditions.""" + + def test_target_reached_stops(self): + """Loop stops when target song count is reached.""" + target = 100 + current = 105 + assert current >= target + + def test_no_tool_calls_stops(self): + """Loop stops when AI returns no tool calls.""" + tool_calls = [] + assert not tool_calls + + def test_no_new_songs_stops(self): + """Loop stops when an iteration adds 0 new songs.""" + iteration_songs_added = 0 + assert iteration_songs_added == 0 + + def test_max_iterations_stops(self): + """Loop stops at max_iterations.""" + max_iterations = 5 + for iteration in range(max_iterations): + pass + assert iteration == max_iterations - 1 + + def test_ai_error_stops_after_first_iteration(self): + """AI error on iteration > 0 breaks the loop.""" + iteration = 2 + error = True + should_break = iteration > 0 and error + assert should_break + + +# --------------------------------------------------------------------------- +# API Key Validation +# --------------------------------------------------------------------------- + +@pytest.mark.integration +class TestAPIKeyValidation: + """Test API key validation for cloud providers.""" + + def test_missing_input_returns_400(self, client): + """Request without userInput returns 400.""" + resp = client.post('/api/chatPlaylist', json={}) + assert resp.status_code == 400 + assert 
resp.content_type.startswith('application/json') + assert 'error' in resp.get_json() + + def test_none_provider_returns_no_ai_message(self, client): + """Provider NONE returns informational message.""" + resp = client.post('/api/chatPlaylist', json={ + 'userInput': 'test', + 'ai_provider': 'NONE' + }) + assert resp.status_code == 200 + assert resp.content_type.startswith('application/json') + data = resp.get_json() + assert 'No AI provider selected' in data['response']['message'] + + def test_openai_missing_key_returns_400(self, client): + """OpenAI without API key returns 400.""" + resp = client.post('/api/chatPlaylist', json={ + 'userInput': 'test', + 'ai_provider': 'OPENAI', + 'openai_api_key': '' + }) + assert resp.status_code == 400 + assert resp.content_type.startswith('application/json') + assert 'error' in resp.get_json() or 'response' in resp.get_json() + + def test_gemini_placeholder_key_returns_400(self, client): + """Gemini with placeholder key returns 400.""" + resp = client.post('/api/chatPlaylist', json={ + 'userInput': 'test', + 'ai_provider': 'GEMINI', + 'gemini_api_key': 'YOUR-GEMINI-API-KEY-HERE' + }) + assert resp.status_code == 400 + assert resp.content_type.startswith('application/json') + assert 'error' in resp.get_json() or 'response' in resp.get_json() + + def test_mistral_placeholder_key_returns_400(self, client): + """Mistral with placeholder key returns 400.""" + resp = client.post('/api/chatPlaylist', json={ + 'userInput': 'test', + 'ai_provider': 'MISTRAL', + 'mistral_api_key': 'YOUR-MISTRAL-API-KEY-HERE' + }) + assert resp.status_code == 400 + assert resp.content_type.startswith('application/json') + assert 'error' in resp.get_json() or 'response' in resp.get_json() + + +# --------------------------------------------------------------------------- +# Create Playlist Endpoint +# --------------------------------------------------------------------------- + +@pytest.mark.integration +class TestCreatePlaylistEndpoint: + """Test the 
/api/create_playlist endpoint.""" + + def _mock_voyager(self): + """Context manager to mock tasks.voyager_manager for import.""" + mock_vm = MagicMock() + return patch.dict('sys.modules', {'tasks.voyager_manager': mock_vm}) + + def test_missing_params_returns_400(self, client): + """Missing playlist_name or item_ids returns 400.""" + with self._mock_voyager(): + resp = client.post('/api/create_playlist', json={'playlist_name': 'Test'}) + assert resp.status_code == 400 + + def test_empty_name_returns_400(self, client): + """Empty playlist name returns 400.""" + with self._mock_voyager(): + resp = client.post('/api/create_playlist', json={ + 'playlist_name': ' ', + 'item_ids': ['id1'] + }) + assert resp.status_code == 400 + + def test_empty_item_ids_returns_400(self, client): + """Empty item_ids list returns 400.""" + with self._mock_voyager(): + resp = client.post('/api/create_playlist', json={ + 'playlist_name': 'Test', + 'item_ids': [] + }) + assert resp.status_code == 400 + + def test_single_provider_success(self, client): + """Successful single-provider playlist creation.""" + mock_vm = MagicMock() + mock_vm.create_playlist_from_ids = Mock(return_value='playlist-123') + with patch.dict('sys.modules', {'tasks.voyager_manager': mock_vm}): + resp = client.post('/api/create_playlist', json={ + 'playlist_name': 'My Mix', + 'item_ids': ['id1', 'id2'] + }) + assert resp.status_code == 200 + data = resp.get_json() + assert 'Successfully created' in data['message'] + + def test_multi_provider_success(self, client): + """Successful multi-provider playlist creation.""" + mock_vm = MagicMock() + mock_vm.create_playlist_from_ids = Mock(return_value={ + 'jellyfin': {'success': True, 'id': 'jf-1'}, + 'navidrome': {'success': False, 'error': 'timeout'} + }) + with patch.dict('sys.modules', {'tasks.voyager_manager': mock_vm}): + resp = client.post('/api/create_playlist', json={ + 'playlist_name': 'My Mix', + 'item_ids': ['id1'], + 'provider_ids': 'all' + }) + assert 
resp.status_code == 200 + data = resp.get_json() + assert '1/2' in data['message'] + + +# --------------------------------------------------------------------------- +# Config Defaults Endpoint +# --------------------------------------------------------------------------- + +@pytest.mark.integration +class TestConfigDefaultsEndpoint: + """Test the /api/config_defaults endpoint.""" + + def test_returns_200(self, client): + """GET /api/config_defaults returns 200.""" + resp = client.get('/api/config_defaults') + assert resp.status_code == 200 + + def test_returns_json_content_type(self, client): + """Response has application/json content type.""" + resp = client.get('/api/config_defaults') + assert resp.content_type.startswith('application/json') + + def test_returns_json_with_expected_keys(self, client): + """Response includes provider configuration defaults.""" + resp = client.get('/api/config_defaults') + data = resp.get_json() + assert isinstance(data, dict) + assert 'default_ai_provider' in data + assert 'default_ollama_model_name' in data + assert 'ollama_server_url' in data + assert 'default_openai_model_name' in data + assert 'openai_server_url' in data + assert 'default_gemini_model_name' in data + assert 'default_mistral_model_name' in data + + def test_values_are_strings(self, client): + """All returned values are strings.""" + resp = client.get('/api/config_defaults') + data = resp.get_json() + for key, value in data.items(): + assert isinstance(value, str), f"Expected string for '{key}', got {type(value)}" + + +# --------------------------------------------------------------------------- +# Pre-Validation (from origin/main) +# --------------------------------------------------------------------------- class TestPreValidation: """Test the pre-validation block in chat_playlist_api() (lines ~466-493).""" def test_song_similarity_empty_title_rejected(self): """song_similarity with empty title should be skipped.""" - # This test validates the logic without 
calling the full endpoint - # It tests the rejection criteria: title must be non-empty title = "" artist = "Artist" - - # Check if title passes validation is_valid = bool(title.strip()) assert not is_valid @@ -29,8 +553,6 @@ def test_song_similarity_empty_artist_rejected(self): """song_similarity with empty artist should be skipped.""" title = "Song" artist = "" - - # Check if artist passes validation is_valid = bool(artist.strip()) assert not is_valid @@ -38,17 +560,14 @@ def test_song_similarity_whitespace_only_rejected(self): """song_similarity with whitespace-only title/artist should be skipped.""" title = " " artist = " \t " - assert not title.strip() assert not artist.strip() def test_search_database_zero_filters_rejected(self): """search_database with no filters specified should be skipped.""" - # Test the filter-checking logic filters = {} filter_keys = ['genres', 'moods', 'tempo_min', 'tempo_max', 'energy_min', 'energy_max', 'key', 'scale', 'year_min', 'year_max', 'min_rating', 'album'] - has_filter = any(filters.get(k) for k in filter_keys) assert not has_filter @@ -57,7 +576,6 @@ def test_search_database_album_only_filter_accepted(self): filters = {'album': 'Dark Side of the Moon'} filter_keys = ['genres', 'moods', 'tempo_min', 'tempo_max', 'energy_min', 'energy_max', 'key', 'scale', 'year_min', 'year_max', 'min_rating', 'album'] - has_filter = any(filters.get(k) for k in filter_keys) assert has_filter @@ -66,7 +584,6 @@ def test_search_database_genres_filter_accepted(self): filters = {'genres': ['rock', 'metal']} filter_keys = ['genres', 'moods', 'tempo_min', 'tempo_max', 'energy_min', 'energy_max', 'key', 'scale', 'year_min', 'year_max', 'min_rating', 'album'] - has_filter = any(filters.get(k) for k in filter_keys) assert has_filter @@ -75,29 +592,30 @@ def test_search_database_year_filter_accepted(self): filters = {'year_min': 1990} filter_keys = ['genres', 'moods', 'tempo_min', 'tempo_max', 'energy_min', 'energy_max', 'key', 'scale', 'year_min', 
'year_max', 'min_rating', 'album'] - has_filter = any(filters.get(k) for k in filter_keys) assert has_filter def test_song_similarity_both_title_and_artist_required(self): """song_similarity requires BOTH title AND artist non-empty.""" test_cases = [ - {"title": "Song", "artist": ""}, # Only title → invalid - {"title": "", "artist": "Artist"}, # Only artist → invalid - {"title": "Song", "artist": "Artist"}, # Both → valid + {"title": "Song", "artist": ""}, # Only title -> invalid + {"title": "", "artist": "Artist"}, # Only artist -> invalid + {"title": "Song", "artist": "Artist"}, # Both -> valid ] - for tc in test_cases: title_valid = bool(tc['title'].strip()) artist_valid = bool(tc['artist'].strip()) is_valid = title_valid and artist_valid - if tc['title'] == "Song" and tc['artist'] == "Artist": assert is_valid else: assert not is_valid +# --------------------------------------------------------------------------- +# Artist Diversity Enforcement (from origin/main) +# --------------------------------------------------------------------------- + class TestArtistDiversityEnforcement: """Test artist diversity cap and backfill logic (lines ~671-702 in app_chat.py).""" @@ -118,19 +636,16 @@ def _apply_diversity_logic(self, songs, max_per_artist, target_count): # Backfill if needed if len(diverse_list) < target_count and overflow_pool: - # Count how many unique artists in diverse_list diverse_artist_counts = {} for song in diverse_list: artist = song.get('artist', 'Unknown') diverse_artist_counts[artist] = diverse_artist_counts.get(artist, 0) + 1 - # Sort overflow by least-represented artists first def artist_rarity(song): artist = song.get('artist', 'Unknown') return diverse_artist_counts.get(artist, 0) overflow_sorted = sorted(overflow_pool, key=artist_rarity) - backfill_needed = target_count - len(diverse_list) backfill = overflow_sorted[:backfill_needed] diverse_list.extend(backfill) @@ -145,12 +660,9 @@ def test_songs_above_cap_moved_to_overflow(self): {'item_id': 
'3', 'artist': 'Beatles', 'title': 'A Day in Life'}, {'item_id': '4', 'artist': 'Beatles', 'title': 'Twist and Shout'}, {'item_id': '5', 'artist': 'Beatles', 'title': 'Love Me Do'}, - {'item_id': '6', 'artist': 'Beatles', 'title': 'Penny Lane'}, # 6th song should go to overflow + {'item_id': '6', 'artist': 'Beatles', 'title': 'Penny Lane'}, ] - result = self._apply_diversity_logic(songs, max_per_artist=5, target_count=5) - - # With target = 5 and only 5 Beatles fitting, we should have exactly 5 beatles_in_result = [s for s in result if s['artist'] == 'Beatles'] assert len(beatles_in_result) == 5 assert len(result) == 5 @@ -161,30 +673,23 @@ def test_exact_cap_songs_all_included(self): {'item_id': f'{i}', 'artist': 'Artist1', 'title': f'Song{i}'} for i in range(1, 6) ] - result = self._apply_diversity_logic(songs, max_per_artist=5, target_count=10) - assert len(result) == 5 assert all(s['artist'] == 'Artist1' for s in result) def test_backfill_from_overflow(self): """Overflow songs backfilled if target not met.""" songs = [ - # 5 Beatles (at cap) {'item_id': '1', 'artist': 'Beatles', 'title': 'A'}, {'item_id': '2', 'artist': 'Beatles', 'title': 'B'}, {'item_id': '3', 'artist': 'Beatles', 'title': 'C'}, {'item_id': '4', 'artist': 'Beatles', 'title': 'D'}, {'item_id': '5', 'artist': 'Beatles', 'title': 'E'}, - # 3 Rolling Stones (overflow) {'item_id': '6', 'artist': 'Rolling Stones', 'title': 'X'}, {'item_id': '7', 'artist': 'Rolling Stones', 'title': 'Y'}, {'item_id': '8', 'artist': 'Rolling Stones', 'title': 'Z'}, ] - result = self._apply_diversity_logic(songs, max_per_artist=5, target_count=8) - - # Should have 5 Beatles + 3 Rolling Stones = 8 assert len(result) == 8 beatles = [s for s in result if s['artist'] == 'Beatles'] stones = [s for s in result if s['artist'] == 'Rolling Stones'] @@ -194,31 +699,24 @@ def test_backfill_from_overflow(self): def test_backfill_prioritizes_underrepresented_artists(self): """Backfill prefers artists with fewer songs already in 
list.""" songs = [ - # 5 Artist1 (at cap) {'item_id': '1', 'artist': 'Artist1', 'title': 'A1'}, {'item_id': '2', 'artist': 'Artist1', 'title': 'A2'}, {'item_id': '3', 'artist': 'Artist1', 'title': 'A3'}, {'item_id': '4', 'artist': 'Artist1', 'title': 'A4'}, {'item_id': '5', 'artist': 'Artist1', 'title': 'A5'}, - # 1 Artist2 (underrepresented) {'item_id': '6', 'artist': 'Artist2', 'title': 'B1'}, - # 5 Artist3 (at cap) {'item_id': '7', 'artist': 'Artist3', 'title': 'C1'}, {'item_id': '8', 'artist': 'Artist3', 'title': 'C2'}, {'item_id': '9', 'artist': 'Artist3', 'title': 'C3'}, {'item_id': '10', 'artist': 'Artist3', 'title': 'C4'}, {'item_id': '11', 'artist': 'Artist3', 'title': 'C5'}, - # Overflows {'item_id': '12', 'artist': 'Artist2', 'title': 'B2'}, {'item_id': '13', 'artist': 'Artist3', 'title': 'C6'}, ] - result = self._apply_diversity_logic(songs, max_per_artist=5, target_count=12) - - # Should backfill Artist2 before Artist3 (more underrepresented) assert len(result) == 12 artist2_count = len([s for s in result if s['artist'] == 'Artist2']) - assert artist2_count >= 2 # B1 + B2 from backfill + assert artist2_count >= 2 def test_overflow_pool_not_used_when_target_met(self): """If diverse_list already meets target, don't add overflow.""" @@ -226,17 +724,18 @@ def test_overflow_pool_not_used_when_target_met(self): {'item_id': '1', 'artist': 'Artist1', 'title': 'A1'}, {'item_id': '2', 'artist': 'Artist1', 'title': 'A2'}, {'item_id': '3', 'artist': 'Artist2', 'title': 'B1'}, - {'item_id': '4', 'artist': 'Artist1', 'title': 'A3'}, # Overflow + {'item_id': '4', 'artist': 'Artist1', 'title': 'A3'}, ] - result = self._apply_diversity_logic(songs, max_per_artist=2, target_count=3) - - # Should have exactly 3: Artist1(2) + Artist2(1) assert len(result) == 3 artist1_count = len([s for s in result if s['artist'] == 'Artist1']) assert artist1_count == 2 +# --------------------------------------------------------------------------- +# Iteration Message (from origin/main) 
+# --------------------------------------------------------------------------- + class TestIterationMessage: """Test iteration 0 vs iteration > 0 message content.""" @@ -244,11 +743,7 @@ def test_iteration_0_message_is_minimal_request(self): """Iteration 0 should just be: 'Build a {target}-song playlist for: \"...\"'""" user_input = "songs like Radiohead" target = 100 - - # Iteration 0 message construction ai_context = f'Build a {target}-song playlist for: "{user_input}"' - - # Should be simple, no library stats assert "Build a 100-song playlist for:" in ai_context assert "Radiohead" in ai_context assert "Top artists:" not in ai_context @@ -256,23 +751,18 @@ def test_iteration_0_message_is_minimal_request(self): def test_iteration_gt0_contains_top_artists(self): """Iteration > 0 should include top artists and their counts.""" - # Simulate building the feedback message for iteration > 0 current_song_count = 45 target_song_count = 100 songs_needed = target_song_count - current_song_count - - # Simulated top artists artist_counts = {'Radiohead': 12, 'Thom Yorke': 8, 'The National': 6} top_5 = sorted(artist_counts.items(), key=lambda x: x[1], reverse=True)[:5] top_artists_str = ', '.join([f'{a}({c})' for a, c in top_5]) - ai_context = f"""Original request: "songs like Radiohead" Progress: {current_song_count}/{target_song_count} songs collected. Need {songs_needed} MORE. 
What we have so far: - Top artists: {top_artists_str} """ - assert f"{current_song_count}/{target_song_count}" in ai_context assert "Top artists:" in ai_context assert "Radiohead(12)" in ai_context @@ -282,9 +772,7 @@ def test_iteration_gt0_contains_diversity_ratio(self): current_song_count = 45 unique_artists = 15 diversity_ratio = unique_artists / max(current_song_count, 1) - ai_context = f"Artist diversity: {unique_artists} unique artists (ratio: {diversity_ratio:.2f})" - assert "Artist diversity:" in ai_context assert f"{unique_artists}" in ai_context @@ -295,8 +783,6 @@ def test_iteration_gt0_contains_tools_used_history(self): {'name': 'song_alchemy', 'songs': 20}, ] tools_str = ', '.join([f"{t['name']}({t['songs']})" for t in tools_used]) - ai_context = f"Tools used: {tools_str}" - assert "text_search(25)" in ai_context assert "song_alchemy(20)" in ai_context diff --git a/tests/unit/test_app_setup.py b/tests/unit/test_app_setup.py new file mode 100644 index 00000000..24540a31 --- /dev/null +++ b/tests/unit/test_app_setup.py @@ -0,0 +1,601 @@ +"""Unit tests for app_setup.py Flask blueprint + +Tests cover the setup wizard and provider management: +- Provider config validation (PROVIDER_SCHEMAS) +- Setup status detection (env auto-detect, DB flag) +- Provider CRUD operations +- Settings management (get/set/apply) +- API endpoint responses +- Multi-provider mode +""" +import json +import sys +import pytest +from datetime import datetime +from unittest.mock import Mock, MagicMock, patch, call + +flask = pytest.importorskip('flask', reason='Flask not installed') +from flask import Flask + +# Pre-register mock for tasks.mediaserver to avoid pydub/audioop import chain +if 'tasks.mediaserver' not in sys.modules: + _mock_mediaserver = MagicMock() + _mock_mediaserver.get_available_provider_types = Mock(return_value={}) + _mock_mediaserver.get_provider_info = Mock(return_value=None) + _mock_mediaserver.test_provider_connection = Mock(return_value=(True, 'OK')) + 
_mock_mediaserver.get_sample_tracks_from_provider = Mock(return_value=[]) + _mock_mediaserver.get_libraries_for_provider = Mock(return_value=[]) + _mock_mediaserver.PROVIDER_TYPES = { + 'jellyfin': {'name': 'Jellyfin', 'description': 'Jellyfin Server', + 'supports_user_auth': True, 'supports_play_history': True}, + 'navidrome': {'name': 'Navidrome', 'description': 'Navidrome Server', + 'supports_user_auth': True, 'supports_play_history': True}, + 'lyrion': {'name': 'Lyrion', 'description': 'Lyrion Music Server', + 'supports_user_auth': False, 'supports_play_history': True}, + 'emby': {'name': 'Emby', 'description': 'Emby Server', + 'supports_user_auth': True, 'supports_play_history': True}, + 'localfiles': {'name': 'Local Files', 'description': 'Local file system', + 'supports_user_auth': False, 'supports_play_history': False}, + } + sys.modules['tasks.mediaserver'] = _mock_mediaserver + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def app(): + """Create a Flask app with the setup blueprint registered.""" + with patch('app_setup.get_db') as _mock_get_db, \ + patch('app_setup.detect_music_path_prefix') as _mock_detect: + from app_setup import setup_bp + flask_app = Flask(__name__) + flask_app.register_blueprint(setup_bp) + flask_app.config['TESTING'] = True + yield flask_app + + +@pytest.fixture +def client(app): + """Create a Flask test client.""" + return app.test_client() + + +def _make_mock_cursor(rows=None, fetchone_val=None, rowcount=1): + """Helper to create a mock DB cursor with context-manager support.""" + mock_cur = MagicMock() + if rows is not None: + mock_cur.fetchall.return_value = rows + if fetchone_val is not None: + mock_cur.fetchone.return_value = fetchone_val + mock_cur.rowcount = rowcount + return mock_cur + + +def _make_mock_db(cursor): + """Helper to create a mock DB connection that yields the given 
cursor.""" + mock_db = MagicMock() + mock_db.cursor.return_value.__enter__ = Mock(return_value=cursor) + mock_db.cursor.return_value.__exit__ = Mock(return_value=False) + return mock_db + + +# --------------------------------------------------------------------------- +# Provider Config Validation (PROVIDER_SCHEMAS) +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestProviderConfigValidation: + """Test validate_provider_config() for all provider types.""" + + def _validate(self, provider_type, config_data): + from app_setup import validate_provider_config + return validate_provider_config(provider_type, config_data) + + def test_unknown_provider_type_invalid(self): + valid, errors = self._validate('unknown_type', {}) + assert not valid + assert 'Unknown provider type' in errors[0] + + def test_jellyfin_valid(self): + valid, errors = self._validate('jellyfin', { + 'url': 'http://localhost:8096', + 'user_id': 'user123', + 'token': 'abc123' + }) + assert valid + assert len(errors) == 0 + + def test_jellyfin_missing_required_fields(self): + valid, errors = self._validate('jellyfin', {'url': 'http://localhost'}) + assert not valid + assert any('user_id' in e for e in errors) + assert any('token' in e for e in errors) + + def test_jellyfin_invalid_url_scheme(self): + valid, errors = self._validate('jellyfin', { + 'url': 'ftp://localhost:8096', + 'user_id': 'user123', + 'token': 'abc123' + }) + assert not valid + assert any('http://' in e for e in errors) + + def test_navidrome_valid(self): + valid, errors = self._validate('navidrome', { + 'url': 'https://navidrome.local', + 'user': 'admin', + 'password': 'pass123' + }) + assert valid + + def test_navidrome_missing_password(self): + valid, errors = self._validate('navidrome', { + 'url': 'https://navidrome.local', + 'user': 'admin' + }) + assert not valid + assert any('password' in e for e in errors) + + def test_lyrion_valid(self): + valid, errors = 
self._validate('lyrion', { + 'url': 'http://lyrion.local:9000' + }) + assert valid + + def test_lyrion_missing_url(self): + valid, errors = self._validate('lyrion', {}) + assert not valid + assert any('url' in e for e in errors) + + def test_emby_valid(self): + valid, errors = self._validate('emby', { + 'url': 'http://emby.local:8096', + 'user_id': 'uid', + 'token': 'tok' + }) + assert valid + + def test_localfiles_valid(self): + with patch('os.path.isabs', return_value=True): + valid, errors = self._validate('localfiles', { + 'music_directory': '/music/library' + }) + assert valid + + def test_localfiles_relative_path_invalid(self): + valid, errors = self._validate('localfiles', { + 'music_directory': 'relative/path' + }) + assert not valid + assert any('absolute' in e for e in errors) + + def test_localfiles_missing_music_directory(self): + valid, errors = self._validate('localfiles', {}) + assert not valid + assert any('music_directory' in e for e in errors) + + def test_all_provider_types_in_schema(self): + """All known provider types have validation schemas.""" + from app_setup import PROVIDER_SCHEMAS + expected = {'jellyfin', 'navidrome', 'lyrion', 'emby', 'localfiles'} + assert set(PROVIDER_SCHEMAS.keys()) == expected + + +# --------------------------------------------------------------------------- +# Setup Status Detection +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestSetupStatus: + """Test is_setup_completed() with env-var auto-detection.""" + + @patch('app_setup.get_setting', return_value=True) + def test_completed_from_db_flag(self, mock_get): + from app_setup import is_setup_completed + assert is_setup_completed() is True + + @patch('app_setup.get_setting', return_value=None) + @patch('app_setup.create_default_provider_from_env') + @patch('app_setup.set_setting') + def test_auto_detect_jellyfin_env(self, mock_set, mock_create, mock_get): + """Jellyfin env vars auto-complete setup.""" + 
import config + with patch.object(config, 'MEDIASERVER_TYPE', 'jellyfin'), \ + patch.object(config, 'JELLYFIN_URL', 'http://jf:8096'), \ + patch.object(config, 'JELLYFIN_TOKEN', 'tok123'), \ + patch.object(config, 'JELLYFIN_USER_ID', 'user1'): + from app_setup import is_setup_completed + result = is_setup_completed() + assert result is True + mock_create.assert_called_once() + + @patch('app_setup.get_setting', return_value=None) + def test_localfiles_requires_wizard(self, mock_get): + """localfiles provider type requires the wizard.""" + import config + with patch.object(config, 'MEDIASERVER_TYPE', 'localfiles'): + from app_setup import is_setup_completed + result = is_setup_completed() + assert result is False + + @patch('app_setup.get_setting', return_value=None) + def test_placeholder_values_not_detected(self, mock_get): + """Placeholder values like 'your_...' are not auto-detected.""" + import config + with patch.object(config, 'MEDIASERVER_TYPE', 'jellyfin'), \ + patch.object(config, 'JELLYFIN_URL', 'http://your_jellyfin_url'), \ + patch.object(config, 'JELLYFIN_TOKEN', 'your_token'), \ + patch.object(config, 'JELLYFIN_USER_ID', 'your_user_id'): + from app_setup import is_setup_completed + result = is_setup_completed() + assert result is False + + +# --------------------------------------------------------------------------- +# Settings Management +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestSettingsManagement: + """Test get/set settings and apply_settings_to_config.""" + + @patch('app_setup.get_setting') + def test_apply_int_setting(self, mock_get_setting): + """Integer settings are type-cast correctly.""" + mock_get_setting.return_value = '10' + import config + original = config.MAX_SONGS_PER_ARTIST_PLAYLIST + try: + from app_setup import apply_settings_to_config + # Only the max_songs_per_artist_playlist key should match + mock_get_setting.side_effect = lambda key: '10' if key == 
'max_songs_per_artist_playlist' else None + apply_settings_to_config() + assert config.MAX_SONGS_PER_ARTIST_PLAYLIST == 10 + finally: + config.MAX_SONGS_PER_ARTIST_PLAYLIST = original + + @patch('app_setup.get_setting') + def test_apply_bool_setting(self, mock_get_setting): + """Boolean settings are type-cast correctly.""" + import config + original = config.PLAYLIST_ENERGY_ARC + try: + mock_get_setting.side_effect = lambda key: 'true' if key == 'playlist_energy_arc' else None + from app_setup import apply_settings_to_config + apply_settings_to_config() + assert config.PLAYLIST_ENERGY_ARC is True + finally: + config.PLAYLIST_ENERGY_ARC = original + + +# --------------------------------------------------------------------------- +# API Endpoints +# --------------------------------------------------------------------------- + +@pytest.mark.integration +class TestSetupEndpoints: + """Test setup API endpoints.""" + + def test_setup_page_renders(self, client): + """GET /setup returns 200.""" + with patch('app_setup.render_template', return_value='setup'): + resp = client.get('/setup') + assert resp.status_code == 200 + assert 'text/html' in resp.content_type + + def test_settings_page_renders(self, client): + """GET /settings returns 200.""" + with patch('app_setup.render_template', return_value='settings'): + resp = client.get('/settings') + assert resp.status_code == 200 + assert 'text/html' in resp.content_type + + @patch('app_setup.get_providers', return_value=[]) + @patch('app_setup.is_setup_completed', return_value=False) + @patch('app_setup.is_multi_provider_enabled', return_value=False) + @patch('app_setup.create_default_provider_from_env') + def test_status_endpoint(self, mock_create, mock_multi, mock_setup, mock_providers, client): + """GET /api/setup/status returns status JSON.""" + resp = client.get('/api/setup/status') + assert resp.status_code == 200 + assert resp.content_type.startswith('application/json') + data = resp.get_json() + assert 
'setup_completed' in data + assert 'provider_count' in data + + @patch('app_setup.get_available_provider_types') + @patch('app_setup.get_provider_info') + def test_provider_types_endpoint(self, mock_info, mock_types, client): + """GET /api/setup/providers/types returns provider type list.""" + mock_types.return_value = { + 'jellyfin': {'name': 'Jellyfin', 'description': 'Jellyfin Server', + 'supports_user_auth': True, 'supports_play_history': True} + } + mock_info.return_value = {'config_fields': [{'name': 'url'}]} + resp = client.get('/api/setup/providers/types') + assert resp.status_code == 200 + assert resp.content_type.startswith('application/json') + data = resp.get_json() + assert len(data) == 1 + assert data[0]['type'] == 'jellyfin' + + @patch('app_setup.get_providers', return_value=[]) + def test_list_providers_empty(self, mock_providers, client): + """GET /api/setup/providers returns empty list.""" + resp = client.get('/api/setup/providers') + assert resp.status_code == 200 + assert resp.content_type.startswith('application/json') + assert resp.get_json() == [] + + +# --------------------------------------------------------------------------- +# Provider CRUD +# --------------------------------------------------------------------------- + +@pytest.mark.integration +class TestProviderCRUD: + """Test provider create/update/delete endpoints.""" + + def test_create_provider_missing_data(self, client): + """POST /api/setup/providers with no body returns 400.""" + resp = client.post('/api/setup/providers', + data='', content_type='application/json') + assert resp.status_code in (400, 415) + + def test_create_provider_missing_type(self, client): + """POST /api/setup/providers without provider_type returns 400.""" + resp = client.post('/api/setup/providers', json={'name': 'Test'}) + assert resp.status_code == 400 + assert 'provider_type' in resp.get_json()['error'] + + def test_create_provider_missing_name(self, client): + """POST /api/setup/providers without name 
returns 400.""" + resp = client.post('/api/setup/providers', json={'provider_type': 'jellyfin'}) + assert resp.status_code == 400 + assert 'name' in resp.get_json()['error'] + + @patch('app_setup.PROVIDER_TYPES', {'jellyfin': {'name': 'Jellyfin'}}) + @patch('app_setup.validate_provider_config', return_value=(True, [])) + @patch('app_setup.get_providers', return_value=[]) + @patch('app_setup.add_provider', return_value=1) + def test_create_provider_success(self, mock_add, mock_get, mock_validate, client): + """Successful provider creation returns 201.""" + resp = client.post('/api/setup/providers', json={ + 'provider_type': 'jellyfin', + 'name': 'My Jellyfin', + 'config': {'url': 'http://jf:8096', 'user_id': 'u', 'token': 't'} + }) + assert resp.status_code == 201 + assert resp.get_json()['id'] == 1 + + @patch('app_setup.PROVIDER_TYPES', {'jellyfin': {'name': 'Jellyfin'}}) + @patch('app_setup.validate_provider_config', return_value=(True, [])) + @patch('app_setup.get_providers', return_value=[ + {'id': 1, 'provider_type': 'jellyfin', 'name': 'Old', 'config': {}} + ]) + @patch('app_setup.update_provider', return_value=True) + def test_create_provider_upserts_existing(self, mock_update, mock_get, mock_validate, client): + """Creating a provider of existing type upserts instead.""" + resp = client.post('/api/setup/providers', json={ + 'provider_type': 'jellyfin', + 'name': 'Updated Jellyfin', + 'config': {'url': 'http://jf:8096', 'user_id': 'u', 'token': 't'} + }) + assert resp.status_code == 200 + assert resp.get_json().get('was_update') is True + + @patch('app_setup.validate_provider_config', return_value=(False, ['Missing url'])) + @patch('app_setup.PROVIDER_TYPES', {'jellyfin': {'name': 'Jellyfin'}}) + def test_create_provider_validation_failure(self, mock_validate, client): + """Provider creation with invalid config returns 400.""" + resp = client.post('/api/setup/providers', json={ + 'provider_type': 'jellyfin', + 'name': 'Bad Config', + 'config': {} + }) + 
assert resp.status_code == 400 + assert 'Validation failed' in resp.get_json()['error'] + + @patch('app_setup.delete_provider', return_value=True) + def test_delete_provider_success(self, mock_delete, client): + """DELETE /api/setup/providers/ returns success.""" + resp = client.delete('/api/setup/providers/1') + assert resp.status_code == 200 + + @patch('app_setup.delete_provider', return_value=False) + def test_delete_provider_not_found(self, mock_delete, client): + """DELETE nonexistent provider returns 404.""" + resp = client.delete('/api/setup/providers/999') + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Complete Setup & Multi-Provider +# --------------------------------------------------------------------------- + +@pytest.mark.integration +class TestSetupCompletion: + """Test setup completion and multi-provider mode.""" + + @patch('app_setup.set_setting') + def test_complete_setup_marks_flag(self, mock_set, client): + """POST /api/setup/complete marks setup as completed.""" + resp = client.post('/api/setup/complete') + assert resp.status_code == 200 + data = resp.get_json() + assert data['setup_completed'] is True + # Verify set_setting was called with setup_completed=True + calls = [c for c in mock_set.call_args_list if c[0][0] == 'setup_completed'] + assert len(calls) >= 1 + + @patch('app_setup.set_setting') + def test_enable_multi_provider(self, mock_set, client): + """POST /api/setup/multi-provider enables multi-provider mode.""" + resp = client.post('/api/setup/multi-provider', json={'enabled': True}) + assert resp.status_code == 200 + data = resp.get_json() + assert data['multi_provider_enabled'] is True + + @patch('app_setup.set_setting') + def test_disable_multi_provider(self, mock_set, client): + """POST /api/setup/multi-provider disables multi-provider mode.""" + resp = client.post('/api/setup/multi-provider', json={'enabled': False}) + assert resp.status_code == 200 + data = 
resp.get_json() + assert data['multi_provider_enabled'] is False + + def test_multi_provider_no_data(self, client): + """POST /api/setup/multi-provider with no data returns 400.""" + resp = client.post('/api/setup/multi-provider', + data='', content_type='application/json') + assert resp.status_code in (400, 415) + + +# --------------------------------------------------------------------------- +# Settings Endpoints +# --------------------------------------------------------------------------- + +@pytest.mark.integration +class TestSettingsEndpoints: + """Test settings API endpoints.""" + + @patch('app_setup.get_all_settings', return_value={'general': {'key1': {'value': 'val1'}}}) + def test_get_settings(self, mock_all, client): + """GET /api/setup/settings returns grouped settings.""" + resp = client.get('/api/setup/settings') + assert resp.status_code == 200 + data = resp.get_json() + assert 'general' in data + + @patch('app_setup.set_setting') + @patch('app_setup.apply_settings_to_config') + def test_update_settings(self, mock_apply, mock_set, client): + """PUT /api/setup/settings updates settings and applies them.""" + resp = client.put('/api/setup/settings', json={'ai_provider': 'GEMINI'}) + assert resp.status_code == 200 + mock_set.assert_called_once_with('ai_provider', 'GEMINI') + mock_apply.assert_called_once() + + def test_update_settings_no_data(self, client): + """PUT /api/setup/settings with no data returns 400.""" + resp = client.put('/api/setup/settings', + data='', content_type='application/json') + assert resp.status_code in (400, 415) + + +# --------------------------------------------------------------------------- +# Provider Update Endpoint +# --------------------------------------------------------------------------- + +@pytest.mark.integration +class TestProviderUpdateEndpoint: + """Test PUT /api/setup/providers/ endpoint.""" + + @patch('app_setup.get_provider_by_id', return_value=None) + def test_update_nonexistent_provider_returns_404(self, 
mock_get, client): + """PUT on nonexistent provider returns 404.""" + resp = client.put('/api/setup/providers/999', json={'name': 'New Name'}) + assert resp.status_code == 404 + assert resp.content_type.startswith('application/json') + assert 'error' in resp.get_json() + + @patch('app_setup.get_provider_by_id', return_value={ + 'id': 1, 'provider_type': 'jellyfin', 'name': 'Jelly', + 'config': {'url': 'http://jf:8096', 'user_id': 'u', 'token': 't'}, + 'enabled': True, 'priority': 0, + }) + def test_update_provider_no_data_returns_400(self, mock_get, client): + """PUT with empty body returns 400.""" + resp = client.put('/api/setup/providers/1', + data='', content_type='application/json') + assert resp.status_code in (400, 415) + + @patch('app_setup.update_provider', return_value=True) + @patch('app_setup.validate_provider_config', return_value=(True, [])) + @patch('app_setup.get_provider_by_id', return_value={ + 'id': 1, 'provider_type': 'jellyfin', 'name': 'Jelly', + 'config': {'url': 'http://jf:8096', 'user_id': 'u', 'token': 't'}, + 'enabled': True, 'priority': 0, + }) + def test_update_provider_success(self, mock_get, mock_validate, mock_update, client): + """Successful provider update returns 200.""" + resp = client.put('/api/setup/providers/1', json={ + 'name': 'Updated Jellyfin', + 'config': {'url': 'http://jf:8096', 'user_id': 'u2', 'token': 't2'} + }) + assert resp.status_code == 200 + assert resp.content_type.startswith('application/json') + assert 'message' in resp.get_json() + + +# --------------------------------------------------------------------------- +# Provider Test Endpoints +# --------------------------------------------------------------------------- + +@pytest.mark.integration +class TestProviderTestEndpoints: + """Test provider connection test endpoints.""" + + @patch('app_setup.get_provider_by_id', return_value=None) + def test_test_saved_provider_not_found(self, mock_get, client): + """POST /api/setup/providers//test returns 404 for missing 
provider.""" + resp = client.post('/api/setup/providers/999/test') + assert resp.status_code == 404 + assert 'error' in resp.get_json() + + @patch('app_setup.test_provider_connection', return_value=(True, 'Connection OK')) + @patch('app_setup.get_provider_by_id', return_value={ + 'id': 1, 'provider_type': 'jellyfin', 'name': 'Jelly', + 'config': {'url': 'http://jf:8096', 'user_id': 'u', 'token': 't'}, + 'enabled': True, 'priority': 0, + }) + def test_test_saved_provider_success(self, mock_get, mock_test, client): + """POST /api/setup/providers//test returns success result.""" + resp = client.post('/api/setup/providers/1/test') + assert resp.status_code == 200 + data = resp.get_json() + assert data['success'] is True + assert 'message' in data + assert data['provider_type'] == 'jellyfin' + + def test_test_unsaved_provider_no_data(self, client): + """POST /api/setup/providers/test with no data returns 400.""" + resp = client.post('/api/setup/providers/test', + data='', content_type='application/json') + assert resp.status_code in (400, 415) + + def test_test_unsaved_provider_missing_type(self, client): + """POST /api/setup/providers/test without provider_type returns 400.""" + resp = client.post('/api/setup/providers/test', json={'config': {}}) + assert resp.status_code == 400 + assert 'error' in resp.get_json() + + @patch('app_setup.test_provider_connection', return_value=(False, 'Connection refused')) + def test_test_unsaved_provider_failure(self, mock_test, client): + """POST /api/setup/providers/test with failing connection.""" + resp = client.post('/api/setup/providers/test', json={ + 'provider_type': 'jellyfin', + 'config': {'url': 'http://bad:8096'}, + 'detect_prefix': False + }) + assert resp.status_code == 200 + data = resp.get_json() + assert data['success'] is False + assert 'message' in data + + +# --------------------------------------------------------------------------- +# Browse Directories Endpoint +# 
--------------------------------------------------------------------------- + +@pytest.mark.integration +class TestBrowseDirectoriesEndpoint: + """Test GET /api/setup/browse-directories endpoint.""" + + def test_path_traversal_rejected(self, client): + """Path traversal with '..' is rejected.""" + resp = client.get('/api/setup/browse-directories?path=/music/../etc') + assert resp.status_code == 400 + assert 'error' in resp.get_json() diff --git a/tests/unit/test_mediaserver_localfiles.py b/tests/unit/test_mediaserver_localfiles.py new file mode 100644 index 00000000..71f8f8ae --- /dev/null +++ b/tests/unit/test_mediaserver_localfiles.py @@ -0,0 +1,620 @@ +"""Unit tests for tasks/mediaserver_localfiles.py + +Tests cover the LocalFiles media provider: +- Path normalization (POSIX conversion, relative paths) +- File path hashing (SHA-256 stability) +- Supported format filtering +- Metadata extraction (tags, fallbacks) +- Rating extraction (POPM, TXXX, Vorbis, M4A) +- M3U playlist management (create, list, delete) +- Directory scanning (recursive, flat) +- Connection testing +""" +import os +import sys +import hashlib +import pytest +from unittest.mock import Mock, MagicMock, patch, mock_open +from pathlib import Path, PurePosixPath + + +# --------------------------------------------------------------------------- +# Import helpers (bypass tasks/__init__.py -> pydub -> audioop chain) +# --------------------------------------------------------------------------- + +def _import_localfiles(): + """Load tasks.mediaserver_localfiles directly without triggering tasks/__init__.py.""" + import importlib.util + import sys + mod_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), '..', '..', 'tasks', 'mediaserver_localfiles.py' + ) + mod_path = os.path.normpath(mod_path) + mod_name = 'tasks.mediaserver_localfiles' + if mod_name not in sys.modules: + spec = importlib.util.spec_from_file_location(mod_name, mod_path) + mod = importlib.util.module_from_spec(spec) + 
sys.modules[mod_name] = mod + spec.loader.exec_module(mod) + return sys.modules[mod_name] + + +# --------------------------------------------------------------------------- +# Path Normalization +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestPathNormalization: + """Test normalize_file_path().""" + + def test_posix_conversion(self): + mod = _import_localfiles() + result = mod.normalize_file_path('Artist\\Album\\song.mp3') + assert '\\' not in result + assert 'Artist/Album/song.mp3' == result + + @pytest.mark.skipif(sys.platform == 'win32', reason='POSIX absolute paths not valid on Windows') + def test_relative_to_base(self): + mod = _import_localfiles() + result = mod.normalize_file_path('/music/Artist/Album/song.mp3', '/music') + assert result == 'Artist/Album/song.mp3' + + @pytest.mark.skipif(sys.platform == 'win32', reason='POSIX absolute paths not valid on Windows') + def test_no_base_keeps_absolute(self): + mod = _import_localfiles() + result = mod.normalize_file_path('/music/Artist/song.mp3', '') + # Without base_path, absolute path stays (converted to POSIX) + assert result.startswith('/') + + def test_whitespace_stripped(self): + mod = _import_localfiles() + result = mod.normalize_file_path(' Artist/song.mp3 ') + assert result == 'Artist/song.mp3' + + def test_different_base_keeps_original(self): + mod = _import_localfiles() + # If path is not relative to base, keep as-is + result = mod.normalize_file_path('/other/Artist/song.mp3', '/music') + assert 'Artist/song.mp3' in result + + +# --------------------------------------------------------------------------- +# File Path Hash +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestFilePathHash: + """Test file_path_hash() SHA-256 generation.""" + + def test_deterministic(self): + mod = _import_localfiles() + h1 = mod.file_path_hash('Artist/Album/song.mp3') + h2 = 
mod.file_path_hash('Artist/Album/song.mp3') + assert h1 == h2 + + def test_different_paths_different_hashes(self): + mod = _import_localfiles() + h1 = mod.file_path_hash('Artist/Album/song1.mp3') + h2 = mod.file_path_hash('Artist/Album/song2.mp3') + assert h1 != h2 + + def test_is_sha256_hex(self): + mod = _import_localfiles() + h = mod.file_path_hash('test/path.mp3') + assert len(h) == 64 # SHA-256 hex = 64 chars + assert all(c in '0123456789abcdef' for c in h) + + def test_matches_manual_sha256(self): + mod = _import_localfiles() + path = 'Artist/Album/song.mp3' + expected = hashlib.sha256(path.encode('utf-8')).hexdigest() + assert mod.file_path_hash(path) == expected + + def test_utf8_paths(self): + mod = _import_localfiles() + h = mod.file_path_hash('Artiste/Café/chanson.mp3') + assert len(h) == 64 + + +# --------------------------------------------------------------------------- +# Supported Formats +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestFormatFiltering: + """Test SUPPORTED_FORMATS constant and format-related logic.""" + + def test_supported_formats_exist(self): + mod = _import_localfiles() + assert '.mp3' in mod.SUPPORTED_FORMATS + assert '.flac' in mod.SUPPORTED_FORMATS + assert '.ogg' in mod.SUPPORTED_FORMATS + assert '.m4a' in mod.SUPPORTED_FORMATS + + def test_wav_supported(self): + mod = _import_localfiles() + assert '.wav' in mod.SUPPORTED_FORMATS + + def test_opus_supported(self): + mod = _import_localfiles() + assert '.opus' in mod.SUPPORTED_FORMATS + + def test_unsupported_format_excluded(self): + mod = _import_localfiles() + assert '.pdf' not in mod.SUPPORTED_FORMATS + assert '.txt' not in mod.SUPPORTED_FORMATS + assert '.jpg' not in mod.SUPPORTED_FORMATS + + def test_get_config_default_formats(self): + """get_config returns SUPPORTED_FORMATS as default.""" + mod = _import_localfiles() + with patch.dict(os.environ, {}, clear=False): + cfg = mod.get_config() + # Formats should be 
a list of supported extensions + assert isinstance(cfg['supported_formats'], list) + assert len(cfg['supported_formats']) > 0 + + def test_get_config_override(self): + """get_config accepts overrides.""" + mod = _import_localfiles() + cfg = mod.get_config(overrides={'music_directory': '/custom/path'}) + assert cfg['music_directory'] == '/custom/path' + + +# --------------------------------------------------------------------------- +# Metadata Extraction +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestMetadataExtraction: + """Test extract_metadata() with mocked mutagen.""" + + def test_fallback_title_from_filename(self): + """When mutagen returns None, title defaults to filename.""" + mod = _import_localfiles() + with patch.object(mod, 'MUTAGEN_AVAILABLE', False): + meta = mod.extract_metadata('/music/Artist/My Song.mp3') + assert meta['title'] == 'My Song' + assert meta['artist'] == 'Unknown Artist' + assert meta['album'] == 'Unknown Album' + + def _inject_mutagen_mock(self, mod, mock_file): + """Inject MutagenFile into module if mutagen isn't installed.""" + if not hasattr(mod, 'MutagenFile'): + mod.MutagenFile = Mock() + return patch.object(mod, 'MutagenFile', mock_file) + + def test_mutagen_extracts_tags(self): + """When mutagen is available, tags are extracted.""" + mod = _import_localfiles() + mock_audio = MagicMock() + mock_audio.tags = { + 'title': ['Test Song'], + 'artist': ['Test Artist'], + 'album': ['Test Album'], + 'albumartist': ['Album Artist'], + 'date': ['2023'], + 'tracknumber': ['5/12'], + 'genre': ['Rock'], + } + mock_audio.info = MagicMock() + mock_audio.info.length = 180.5 + mock_mutagen = Mock(return_value=mock_audio) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen_mock(mod, mock_mutagen), \ + patch.object(mod, '_extract_rating', return_value=None): + meta = mod.extract_metadata('/music/test.mp3') + assert meta['title'] == 'Test Song' + assert 
meta['artist'] == 'Test Artist' + assert meta['album'] == 'Test Album' + assert meta['album_artist'] == 'Album Artist' + assert meta['year'] == 2023 + assert meta['track_number'] == 5 + assert meta['genre'] == 'Rock' + assert meta['duration'] == 180.5 + + def test_track_number_slash_format(self): + """Track number '3/12' extracts as 3.""" + mod = _import_localfiles() + mock_audio = MagicMock() + mock_audio.tags = {'tracknumber': ['3/12']} + mock_audio.info = None + mock_mutagen = Mock(return_value=mock_audio) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen_mock(mod, mock_mutagen), \ + patch.object(mod, '_extract_rating', return_value=None): + meta = mod.extract_metadata('/music/test.mp3') + assert meta['track_number'] == 3 + + def test_performer_fallback_for_artist(self): + """If 'artist' tag missing but 'performer' present, use performer.""" + mod = _import_localfiles() + mock_audio = MagicMock() + mock_audio.tags = {'performer': ['Performer Name']} + mock_audio.info = None + mock_mutagen = Mock(return_value=mock_audio) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen_mock(mod, mock_mutagen), \ + patch.object(mod, '_extract_rating', return_value=None): + meta = mod.extract_metadata('/music/test.mp3') + assert meta['artist'] == 'Performer Name' + + def test_mutagen_returns_none(self): + """When MutagenFile returns None, defaults are used.""" + mod = _import_localfiles() + mock_mutagen = Mock(return_value=None) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen_mock(mod, mock_mutagen): + meta = mod.extract_metadata('/music/test.mp3') + assert meta['title'] == 'test' + assert meta['artist'] == 'Unknown Artist' + + def test_exception_returns_defaults(self): + """Exceptions during extraction return default metadata.""" + mod = _import_localfiles() + mock_mutagen = Mock(side_effect=Exception('corrupt file')) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + 
self._inject_mutagen_mock(mod, mock_mutagen): + meta = mod.extract_metadata('/music/test.mp3') + assert meta['title'] == 'test' + assert meta['artist'] == 'Unknown Artist' + + +# --------------------------------------------------------------------------- +# Rating Extraction +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestRatingExtraction: + """Test _extract_rating() for various tag formats.""" + + def _inject_mutagen(self, mod, mock_file): + """Ensure MutagenFile exists on module so patch.object works.""" + if not hasattr(mod, 'MutagenFile'): + mod.MutagenFile = Mock() + return patch.object(mod, 'MutagenFile', mock_file) + + def _make_popm_audio(self, popm_rating): + """Build a mock audio object with POPM tag.""" + mock_popm = MagicMock() + mock_popm.rating = popm_rating + mock_audio = MagicMock() + mock_tags = MagicMock() + mock_tags.keys.return_value = ['POPM:no@email'] + mock_tags.__getitem__ = Mock(return_value=mock_popm) + mock_audio.tags = mock_tags + return mock_audio + + def _make_flac_audio(self, tag_dict): + """Build a mock audio object with Vorbis-style tags (dict-like).""" + mock_audio = MagicMock() + # Use a MagicMock that supports dict operations + mock_tags = MagicMock() + mock_tags.__contains__ = lambda self, key: key in tag_dict + mock_tags.__getitem__ = lambda self, key: tag_dict[key] + mock_tags.__bool__ = lambda self: bool(tag_dict) + mock_audio.tags = mock_tags + return mock_audio + + def test_no_mutagen_returns_none(self): + mod = _import_localfiles() + with patch.object(mod, 'MUTAGEN_AVAILABLE', False): + assert mod._extract_rating('/test.mp3') is None + + def test_popm_rating_zero(self): + """POPM rating 0 maps to 0.""" + mod = _import_localfiles() + mock_audio = self._make_popm_audio(0) + mock_mutagen = Mock(return_value=mock_audio) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen(mod, mock_mutagen): + result = mod._extract_rating('/test.mp3') + 
assert result == 0 + + def test_popm_rating_255_maps_to_5(self): + """POPM rating 255 maps to 5.""" + mod = _import_localfiles() + mock_audio = self._make_popm_audio(255) + mock_mutagen = Mock(return_value=mock_audio) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen(mod, mock_mutagen): + result = mod._extract_rating('/test.mp3') + assert result == 5 + + def test_popm_rating_128_maps_to_3(self): + """POPM rating 128 maps to 3.""" + mod = _import_localfiles() + mock_audio = self._make_popm_audio(128) + mock_mutagen = Mock(return_value=mock_audio) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen(mod, mock_mutagen): + result = mod._extract_rating('/test.mp3') + assert result == 3 + + def test_flac_fmps_rating_0_5(self): + """FLAC FMPS_RATING 0.5 maps to round(0.5*5)=2.""" + mod = _import_localfiles() + mock_audio = self._make_flac_audio({'FMPS_RATING': ['0.5']}) + mock_mutagen = Mock(return_value=mock_audio) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen(mod, mock_mutagen): + result = mod._extract_rating('/test.flac') + # 0.5 * 5 = 2.5, round() = 2 (banker's rounding) + assert result == 2 + + def test_flac_rating_direct_scale(self): + """FLAC RATING tag with direct 0-5 value.""" + mod = _import_localfiles() + mock_audio = self._make_flac_audio({'RATING': ['4']}) + mock_mutagen = Mock(return_value=mock_audio) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen(mod, mock_mutagen): + result = mod._extract_rating('/test.flac') + assert result == 4 + + def test_mutagen_file_none_returns_none(self): + """MutagenFile returning None gives None rating.""" + mod = _import_localfiles() + mock_mutagen = Mock(return_value=None) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen(mod, mock_mutagen): + assert mod._extract_rating('/test.mp3') is None + + def test_exception_returns_none(self): + """Exceptions during rating extraction return 
None.""" + mod = _import_localfiles() + mock_mutagen = Mock(side_effect=Exception('error')) + with patch.object(mod, 'MUTAGEN_AVAILABLE', True), \ + self._inject_mutagen(mod, mock_mutagen): + assert mod._extract_rating('/test.mp3') is None + + +# --------------------------------------------------------------------------- +# M3U Playlist Management +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestM3UPlaylistManagement: + """Test M3U playlist create/list/delete operations.""" + + def test_get_all_playlists_no_dir(self): + """Missing playlist directory returns empty list.""" + mod = _import_localfiles() + with patch.object(mod, 'get_config', return_value={'playlist_directory': '/nonexistent'}), \ + patch('os.path.isdir', return_value=False): + result = mod.get_all_playlists() + assert result == [] + + def test_get_all_playlists_finds_m3u(self): + """Lists .m3u and .m3u8 files.""" + mod = _import_localfiles() + with patch.object(mod, 'get_config', return_value={'playlist_directory': '/playlists'}), \ + patch('os.path.isdir', return_value=True), \ + patch('os.listdir', return_value=['rock.m3u', 'jazz.m3u8', 'notes.txt']): + result = mod.get_all_playlists() + names = [p['Name'] for p in result] + assert 'rock' in names + assert 'jazz' in names + assert len(result) == 2 + + def test_get_playlist_by_name(self): + """Find a playlist by exact name.""" + mod = _import_localfiles() + with patch.object(mod, 'get_all_playlists', return_value=[ + {'Id': 'rock.m3u', 'Name': 'rock', 'Path': '/p/rock.m3u'}, + {'Id': 'jazz.m3u', 'Name': 'jazz', 'Path': '/p/jazz.m3u'}, + ]): + result = mod.get_playlist_by_name('jazz') + assert result is not None + assert result['Name'] == 'jazz' + + def test_get_playlist_by_name_not_found(self): + """Non-existent playlist returns None.""" + mod = _import_localfiles() + with patch.object(mod, 'get_all_playlists', return_value=[]): + assert mod.get_playlist_by_name('nonexistent') is None + + 
def test_delete_playlist_success(self): + """Deleting an existing playlist returns True.""" + mod = _import_localfiles() + with patch.object(mod, 'get_config', return_value={'playlist_directory': '/playlists'}), \ + patch('os.path.exists', return_value=True), \ + patch('os.remove') as mock_rm: + result = mod.delete_playlist('rock.m3u') + assert result is True + mock_rm.assert_called_once() + + def test_delete_playlist_not_found(self): + """Deleting a nonexistent playlist returns False.""" + mod = _import_localfiles() + with patch.object(mod, 'get_config', return_value={'playlist_directory': '/playlists'}), \ + patch('os.path.exists', return_value=False): + result = mod.delete_playlist('missing.m3u') + assert result is False + + +# --------------------------------------------------------------------------- +# Connection Testing +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestConnectionTesting: + """Test test_connection().""" + + def test_missing_directory(self): + mod = _import_localfiles() + with patch('os.path.exists', return_value=False): + ok, msg = mod.test_connection({'music_directory': '/nonexistent'}) + assert not ok + assert 'does not exist' in msg + + def test_not_a_directory(self): + mod = _import_localfiles() + with patch('os.path.exists', return_value=True), \ + patch('os.path.isdir', return_value=False): + ok, msg = mod.test_connection({'music_directory': '/music/file.txt'}) + assert not ok + assert 'not a directory' in msg + + def test_not_readable(self): + mod = _import_localfiles() + with patch('os.path.exists', return_value=True), \ + patch('os.path.isdir', return_value=True), \ + patch('os.access', return_value=False): + ok, msg = mod.test_connection({'music_directory': '/music'}) + assert not ok + assert 'not readable' in msg + + def test_no_audio_files(self): + mod = _import_localfiles() + with patch('os.path.exists', return_value=True), \ + patch('os.path.isdir', return_value=True), 
\ + patch('os.access', return_value=True), \ + patch('os.walk', return_value=[('/music', [], ['readme.txt'])]): + ok, msg = mod.test_connection({'music_directory': '/music'}) + assert not ok + assert 'No audio files' in msg + + def test_success_with_audio_files(self): + mod = _import_localfiles() + with patch('os.path.exists', return_value=True), \ + patch('os.path.isdir', return_value=True), \ + patch('os.access', return_value=True), \ + patch('os.walk', return_value=[('/music', [], ['song.mp3', 'track.flac'])]): + ok, msg = mod.test_connection({'music_directory': '/music'}) + assert ok + assert 'Found audio files' in msg + + +# --------------------------------------------------------------------------- +# Directory Scanning +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestDirectoryScanning: + """Test get_all_songs() directory scanning.""" + + def test_nonexistent_dir_returns_empty(self): + mod = _import_localfiles() + with patch.object(mod, 'get_config', return_value={ + 'music_directory': '/nonexistent', + 'supported_formats': ['.mp3'], + 'scan_subdirectories': True + }), patch('os.path.isdir', return_value=False): + result = mod.get_all_songs() + assert result == [] + + def test_recursive_scan(self): + """Recursive scan finds files in subdirectories.""" + mod = _import_localfiles() + walk_data = [ + ('/music', ['Artist'], ['root.mp3']), + ('/music/Artist', [], ['song.flac']), + ] + with patch.object(mod, 'get_config', return_value={ + 'music_directory': '/music', + 'supported_formats': ['.mp3', '.flac'], + 'scan_subdirectories': True + }), patch('os.path.isdir', return_value=True), \ + patch('os.walk', return_value=walk_data), \ + patch.object(mod, '_format_song', side_effect=lambda fp, bp: { + 'Id': 'hash', 'Name': os.path.basename(fp), 'Path': fp + }): + result = mod.get_all_songs() + assert len(result) == 2 + + def test_flat_scan(self): + """Non-recursive scan only finds files in the root.""" + 
mod = _import_localfiles() + with patch.object(mod, 'get_config', return_value={ + 'music_directory': '/music', + 'supported_formats': ['.mp3'], + 'scan_subdirectories': False + }), patch('os.path.isdir', return_value=True), \ + patch('os.listdir', return_value=['song.mp3', 'notes.txt', 'track.mp3']), \ + patch('os.path.isfile', return_value=True), \ + patch.object(mod, '_format_song', side_effect=lambda fp, bp: { + 'Id': 'hash', 'Name': os.path.basename(fp), 'Path': fp + }): + result = mod.get_all_songs() + assert len(result) == 2 # Only .mp3 files + + def test_unsupported_format_skipped(self): + """Files with unsupported extensions are skipped.""" + mod = _import_localfiles() + walk_data = [('/music', [], ['song.mp3', 'image.jpg', 'doc.pdf'])] + with patch.object(mod, 'get_config', return_value={ + 'music_directory': '/music', + 'supported_formats': ['.mp3'], + 'scan_subdirectories': True + }), patch('os.path.isdir', return_value=True), \ + patch('os.walk', return_value=walk_data), \ + patch.object(mod, '_format_song', side_effect=lambda fp, bp: { + 'Id': 'hash', 'Name': os.path.basename(fp), 'Path': fp + }): + result = mod.get_all_songs() + assert len(result) == 1 + + +# --------------------------------------------------------------------------- +# Download Track +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestDownloadTrack: + """Test download_track() (copy to temp dir).""" + + def test_missing_source_returns_none(self): + mod = _import_localfiles() + result = mod.download_track('/tmp', {'Path': '/nonexistent/file.mp3'}) + assert result is None + + def test_no_path_returns_none(self): + mod = _import_localfiles() + result = mod.download_track('/tmp', {'Id': '123'}) + assert result is None + + def test_successful_copy(self): + mod = _import_localfiles() + with patch('os.path.exists', return_value=True), \ + patch('shutil.copy2') as mock_copy: + # First exists check is for source, second is for dest 
collision + with patch('os.path.exists', side_effect=[True, False]): + result = mod.download_track('/tmp', { + 'Path': '/music/Artist/song.mp3', + 'Name': 'song' + }) + assert result is not None + + +# --------------------------------------------------------------------------- +# Provider Info +# --------------------------------------------------------------------------- + +@pytest.mark.unit +class TestProviderInfo: + """Test provider info metadata.""" + + def test_provider_type(self): + mod = _import_localfiles() + info = mod.get_provider_info() + assert info['type'] == 'localfiles' + + def test_no_play_history(self): + mod = _import_localfiles() + info = mod.get_provider_info() + assert info['supports_play_history'] is False + + def test_config_fields_include_music_directory(self): + mod = _import_localfiles() + info = mod.get_provider_info() + field_names = [f['name'] for f in info['config_fields']] + assert 'music_directory' in field_names + + def test_top_played_returns_empty(self): + mod = _import_localfiles() + assert mod.get_top_played_songs(10) == [] + + def test_last_played_returns_none(self): + mod = _import_localfiles() + assert mod.get_last_played_time('id123') is None diff --git a/tests/unit/test_memory_cleanup.py b/tests/unit/test_memory_cleanup.py index 868c09e4..763d2fce 100644 --- a/tests/unit/test_memory_cleanup.py +++ b/tests/unit/test_memory_cleanup.py @@ -232,26 +232,31 @@ def test_cleanup_onnx_sessions_on_success( ): """Test that ONNX sessions are cleaned up after successful album analysis.""" from tasks.analysis import analyze_album_task - + # Setup mocks mock_get_job.return_value = None mock_get_tracks.return_value = [ {'Id': '1', 'Name': 'Track 1', 'AlbumArtist': 'Artist 1', 'ArtistId': 'artist1'} ] mock_download.return_value = "/tmp/track.mp3" - - # Mock database + + # Mock database - must work as context manager (with get_db() as conn) mock_conn = MagicMock() mock_cur = MagicMock() + mock_conn.__enter__ = Mock(return_value=mock_conn) + 
mock_conn.__exit__ = Mock(return_value=False) mock_conn.cursor.return_value = mock_cur + mock_cur.__enter__ = Mock(return_value=mock_cur) + mock_cur.__exit__ = Mock(return_value=False) mock_cur.fetchall.return_value = [] # No existing tracks + mock_cur.fetchone.return_value = None mock_get_db.return_value = mock_conn - + # Mock ONNX sessions mock_ort.get_available_providers.return_value = ['CPUExecutionProvider'] mock_session = MagicMock() mock_ort.InferenceSession.return_value = mock_session - + # Mock analyze_track to return results mock_analyze.return_value = ( { @@ -269,15 +274,17 @@ def test_cleanup_onnx_sessions_on_success( }, np.random.randn(200) ) - - # Call function - with patch('tasks.clap_analyzer.is_clap_available', return_value=False): - with patch('config.MULAN_ENABLED', False): - result = analyze_album_task("album_123", "Test Album", 5, None) - + + # Call function - mock multi-provider functions added in this branch + with patch('tasks.clap_analyzer.is_clap_available', return_value=False), \ + patch('config.MULAN_ENABLED', False), \ + patch('app_helper.find_existing_analysis_by_file_path', return_value=None), \ + patch('app_helper_artist.upsert_artist_mapping'): + result = analyze_album_task("album_123", "Test Album", 5, None) + # Verify session cleanup was called for all loaded sessions # Should be called 2 times (embedding + prediction; secondary models removed in v4.0.0) assert mock_session_cleanup.call_count >= 2 - + # Verify CUDA cleanup was called assert mock_cuda_cleanup.called diff --git a/tests/unit/test_playlist_ordering.py b/tests/unit/test_playlist_ordering.py index 53773143..45918a15 100644 --- a/tests/unit/test_playlist_ordering.py +++ b/tests/unit/test_playlist_ordering.py @@ -1,123 +1,538 @@ -""" -Tests for tasks/playlist_ordering.py — Greedy nearest-neighbor playlist ordering. 
- -Tests verify: -- Composite distance calculation (tempo, energy, key weighting) -- Circle of Fifths key distance computation -- Greedy nearest-neighbor algorithm -- Energy arc reshaping -- Handling of songs missing from database +"""Unit tests for playlist ordering module + +Tests cover the composite distance calculation, Circle of Fifths key distance, +greedy nearest-neighbor ordering algorithm, energy arc shaping, and edge cases. +All tests run without external services using unittest.mock for database calls. """ import pytest -from unittest.mock import Mock, patch, MagicMock - -from tests.conftest import _import_module, make_dict_row, make_mock_connection +from unittest.mock import patch, MagicMock, Mock -def _load_playlist_ordering(): - """Load playlist_ordering module via importlib to bypass tasks/__init__.py.""" - return _import_module('tasks.playlist_ordering', 'tasks/playlist_ordering.py') +def _import_ordering(): + """Import playlist_ordering directly, bypassing tasks/__init__.py which + pulls in heavyweight deps (pydub, librosa) not needed for these tests.""" + import importlib.util, os, sys + mod_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), '..', '..', 'tasks', 'playlist_ordering.py' + ) + mod_path = os.path.normpath(mod_path) + if 'tasks.playlist_ordering' not in sys.modules: + spec = importlib.util.spec_from_file_location('tasks.playlist_ordering', mod_path) + mod = importlib.util.module_from_spec(spec) + sys.modules['tasks.playlist_ordering'] = mod + spec.loader.exec_module(mod) + mod = sys.modules['tasks.playlist_ordering'] + return ( + mod._key_distance, + mod._composite_distance, + mod.order_playlist, + mod._apply_energy_arc, + mod.CIRCLE_OF_FIFTHS, + ) +@pytest.mark.unit class TestKeyDistance: - """Test _key_distance() function — Circle of Fifths distance.""" - - def test_identical_keys_same_scale(self): - """Same key, same scale → distance = 0.""" - mod = _load_playlist_ordering() - dist = mod._key_distance('C', 'major', 
'C', 'major') - assert dist == 0.0 - - def test_adjacent_keys_without_scale_bonus(self): - """C→G is 1 step / 6 max ≈ 0.167.""" - mod = _load_playlist_ordering() - dist = mod._key_distance('C', None, 'G', None) - assert abs(dist - 1/6) < 0.01 - - def test_missing_key_returns_neutral(self): - """Missing key → return 0.5 (neutral).""" - mod = _load_playlist_ordering() - dist = mod._key_distance(None, None, 'C', None) - assert dist == 0.5 - - def test_unknown_key_returns_neutral(self): - """Unknown key → return 0.5 (neutral).""" - mod = _load_playlist_ordering() - dist = mod._key_distance('C', None, 'XYZ', None) - assert dist == 0.5 - - def test_case_insensitive(self): - """Keys are uppercased → 'c' should match 'C'.""" - mod = _load_playlist_ordering() - dist1 = mod._key_distance('C', None, 'G', None) - dist2 = mod._key_distance('c', None, 'g', None) - assert abs(dist1 - dist2) < 0.01 + def test_same_key_zero_distance(self): + kd, *_ = _import_ordering() + assert kd("C", "major", "C", "major") == 0.0 + + def test_same_key_different_scale_zero(self): + kd, *_ = _import_ordering() + assert kd("G", "major", "G", "minor") == 0.0 + + def test_adjacent_key_c_to_g(self): + kd, *_ = _import_ordering() + assert kd("C", None, "G", None) == pytest.approx(1.0/6.0) + + def test_adjacent_key_c_to_f(self): + kd, *_ = _import_ordering() + assert kd("C", None, "F", None) == pytest.approx(1.0/6.0) + + def test_opposite_key_c_to_fsharp(self): + kd, *_ = _import_ordering() + assert kd("C", None, "F#", None) == pytest.approx(1.0) + + def test_opposite_key_c_to_gb(self): + kd, *_ = _import_ordering() + assert kd("C", None, "Gb", None) == pytest.approx(1.0) + + def test_two_steps_c_to_d(self): + kd, *_ = _import_ordering() + assert kd("C", None, "D", None) == pytest.approx(2.0/6.0) + + def test_three_steps_c_to_a(self): + kd, *_ = _import_ordering() + assert kd("C", None, "A", None) == pytest.approx(0.5) + + def test_four_steps_c_to_e(self): + kd, *_ = _import_ordering() + assert kd("C", 
None, "E", None) == pytest.approx(4.0/6.0) + + def test_five_steps_c_to_b(self): + kd, *_ = _import_ordering() + assert kd("C", None, "B", None) == pytest.approx(5.0/6.0) + + def test_symmetry(self): + kd, *_ = _import_ordering() + assert kd("A", None, "E", None) == kd("E", None, "A", None) + + def test_same_scale_bonus_reduces_distance(self): + kd, *_ = _import_ordering() + d_no = kd("C", None, "D", None) + d_same = kd("C", "major", "D", "major") + assert d_same == pytest.approx(d_no * 0.8) + + def test_different_scale_no_bonus(self): + kd, *_ = _import_ordering() + d_diff = kd("C", "major", "D", "minor") + d_no = kd("C", None, "D", None) + assert d_diff == pytest.approx(d_no) + + def test_missing_key1_returns_neutral(self): + kd, *_ = _import_ordering() + assert kd(None, "major", "C", "major") == 0.5 + assert kd("", "major", "C", "major") == 0.5 + + def test_missing_key2_returns_neutral(self): + kd, *_ = _import_ordering() + assert kd("C", "major", None, "major") == 0.5 + assert kd("C", "major", "", "major") == 0.5 + + def test_both_keys_missing_returns_neutral(self): + kd, *_ = _import_ordering() + assert kd(None, None, None, None) == 0.5 + + def test_unknown_key_name_returns_neutral(self): + kd, *_ = _import_ordering() + assert kd("X", None, "C", None) == 0.5 + assert kd("C", None, "Z", None) == 0.5 + def test_case_insensitive_keys(self): + kd, *_ = _import_ordering() + assert kd("c", None, "g", None) == kd("C", None, "G", None) + def test_sharp_flat_enharmonic_equivalents(self): + kd, *_ = _import_ordering() + assert kd("C#", None, "Db", None) == 0.0 + assert kd("D#", None, "Eb", None) == 0.0 + assert kd("G#", None, "Ab", None) == 0.0 + assert kd("A#", None, "Bb", None) == 0.0 + + def test_scale_comparison_case_insensitive(self): + kd, *_ = _import_ordering() + d1 = kd("C", "major", "D", "major") + d2 = kd("C", "Major", "D", "MAJOR") + assert d1 == pytest.approx(d2) + + +@pytest.mark.unit class TestCompositeDistance: - """Test _composite_distance() function — 
Weighted combination.""" - - def test_identical_songs(self): - """Same song data → distance = 0.""" - mod = _load_playlist_ordering() - song = {'tempo': 120, 'energy': 0.08, 'key': 'C', 'scale': 'major'} - dist = mod._composite_distance(song, song) - assert dist == 0.0 - - def test_tempo_difference(self): - """Different tempos → distance reflects tempo weight (0.35).""" - mod = _load_playlist_ordering() - song1 = {'tempo': 80, 'energy': 0.05, 'key': 'C', 'scale': 'major'} - song2 = {'tempo': 160, 'energy': 0.05, 'key': 'C', 'scale': 'major'} - # Tempo diff: |160-80|/80 = 1.0, capped at 1.0 - # Dist: 0.35*1.0 = 0.35 - dist = mod._composite_distance(song1, song2) - assert abs(dist - 0.35) < 0.01 - - def test_energy_capped_at_one(self): - """Large energy diff > 0.14 → capped at 1.0.""" - mod = _load_playlist_ordering() - song1 = {'tempo': 100, 'energy': 0.01, 'key': 'C', 'scale': None} - song2 = {'tempo': 100, 'energy': 0.15, 'key': 'C', 'scale': None} - dist = mod._composite_distance(song1, song2) - assert abs(dist - 0.35) < 0.01 # energy weight = 0.35 - - def test_missing_values_as_zero(self): - """Missing tempo/energy → treated as 0.""" - mod = _load_playlist_ordering() - song1 = {'tempo': None, 'energy': None, 'key': 'C', 'scale': None} - song2 = {'tempo': 100, 'energy': 0.10, 'key': 'C', 'scale': None} - dist = mod._composite_distance(song1, song2) - assert dist > 0 + def test_identical_songs_zero_distance(self): + _, cd, *_ = _import_ordering() + s = {"tempo": 120, "energy": 0.08, "key": "C", "scale": "major"} + assert cd(s, s) == 0.0 + + def test_tempo_weight_35_percent(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 100, "energy": 0.08, "key": "C", "scale": "major"} + b = {"tempo": 180, "energy": 0.08, "key": "C", "scale": "major"} + assert cd(a, b) == pytest.approx(0.35) + + def test_energy_weight_35_percent(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 120, "energy": 0.01, "key": "C", "scale": "major"} + b = {"tempo": 120, "energy": 0.15, 
"key": "C", "scale": "major"} + assert cd(a, b) == pytest.approx(0.35) + + def test_key_weight_30_percent(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 120, "energy": 0.08, "key": "C", "scale": "major"} + b = {"tempo": 120, "energy": 0.08, "key": "F#", "scale": "minor"} + assert cd(a, b) == pytest.approx(0.30) + + def test_max_distance_all_features_differ(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 80, "energy": 0.01, "key": "C", "scale": "major"} + b = {"tempo": 160, "energy": 0.15, "key": "F#", "scale": "minor"} + assert cd(a, b) == pytest.approx(1.0) + + def test_tempo_normalised_by_80bpm(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 100, "energy": 0, "key": "", "scale": ""} + b = {"tempo": 140, "energy": 0, "key": "", "scale": ""} + assert cd(a, b) == pytest.approx(0.35*0.5 + 0.30*0.5) + + def test_tempo_diff_capped_at_one(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 60, "energy": 0.08, "key": "C", "scale": "major"} + b = {"tempo": 200, "energy": 0.08, "key": "C", "scale": "major"} + assert cd(a, b) == pytest.approx(0.35) + + def test_energy_diff_capped_at_one(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 120, "energy": 0.0, "key": "C", "scale": "major"} + b = {"tempo": 120, "energy": 0.5, "key": "C", "scale": "major"} + assert cd(a, b) == pytest.approx(0.35) + + def test_missing_tempo_treated_as_zero(self): + _, cd, *_ = _import_ordering() + a = {"energy": 0.08, "key": "C", "scale": "major"} + b = {"tempo": 80, "energy": 0.08, "key": "C", "scale": "major"} + assert cd(a, b) == pytest.approx(0.35) + + def test_missing_energy_treated_as_zero(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 120, "key": "C", "scale": "major"} + b = {"tempo": 120, "energy": 0.07, "key": "C", "scale": "major"} + assert cd(a, b) == pytest.approx(0.35*0.5) + + def test_missing_key_gives_neutral(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 120, "energy": 0.08} + b = {"tempo": 120, "energy": 0.08} + assert cd(a, b) 
== pytest.approx(0.30*0.5) + + def test_custom_weights(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 100, "energy": 0.01, "key": "C", "scale": "major"} + b = {"tempo": 180, "energy": 0.15, "key": "F#", "scale": "minor"} + assert cd(a, b, w_tempo=0.5, w_energy=0.3, w_key=0.2) == pytest.approx(1.0) + + def test_symmetry(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 90, "energy": 0.05, "key": "D", "scale": "minor"} + b = {"tempo": 140, "energy": 0.12, "key": "Ab", "scale": "major"} + assert cd(a, b) == cd(b, a) + + def test_partial_distance_contribution(self): + _, cd, *_ = _import_ordering() + a = {"tempo": 120, "energy": 0.08, "key": "C", "scale": "major"} + b = {"tempo": 160, "energy": 0.08, "key": "G", "scale": "major"} + expected = 0.35*0.5 + 0.30*(1.0/6.0*0.8) + assert cd(a, b) == pytest.approx(expected) + + +def _make_mock_db_rows(sd): + rows = [] + for iid, data in sd.items(): + row = dict(data) + row["item_id"] = iid + rows.append(row) + return rows +def _patch_order_playlist(sd): + """Patch DB calls for order_playlist. 
Pre-register mock modules + so the lazy imports inside order_playlist() never trigger the heavy + tasks/__init__.py import chain.""" + import sys + rows = _make_mock_db_rows(sd) + mc = MagicMock() + mc.fetchall.return_value = rows + conn = MagicMock() + conn.cursor.return_value.__enter__ = Mock(return_value=mc) + conn.cursor.return_value.__exit__ = Mock(return_value=None) + # Pre-register lightweight mocks for modules imported inside order_playlist() + if 'tasks.mcp_server' not in sys.modules: + mock_mcp = MagicMock() + sys.modules['tasks.mcp_server'] = mock_mcp + if 'psycopg2' not in sys.modules: + sys.modules['psycopg2'] = MagicMock() + if 'psycopg2.extras' not in sys.modules: + sys.modules['psycopg2.extras'] = MagicMock() + sys.modules['tasks.mcp_server'].get_db_connection = Mock(return_value=conn) + return patch.object(sys.modules['tasks.mcp_server'], 'get_db_connection', return_value=conn) + +@pytest.mark.unit class TestOrderPlaylist: - """Test order_playlist() function — Main greedy algorithm.""" - - def test_single_song_unchanged(self): - """Single song → return unchanged (no DB call needed).""" - mod = _load_playlist_ordering() - result = mod.order_playlist(['only_id']) - assert result == ['only_id'] - - def test_two_songs_unchanged(self): - """Two songs → no reordering (len <= 2, no DB call).""" - mod = _load_playlist_ordering() - result = mod.order_playlist(['id1', 'id2']) - assert result == ['id1', 'id2'] - - def test_empty_input(self): - """Empty input → empty output.""" - mod = _load_playlist_ordering() - result = mod.order_playlist([]) - assert result == [] - - def test_minimum_songs_no_ordering(self): - """3+ songs with len <= 2 orderable → return input unchanged.""" - mod = _load_playlist_ordering() - # This simulates the case where we have 3 songs but fewer than 3 with DB data - # Since the function checks if len(orderable_ids) <= 2 and returns early, - # we verify this behavior by checking the algorithm logic itself. 
- - # The function returns unchanged when there's no enough orderable data - # We can verify this through the underlying algorithm tests above + def test_empty_list_returns_empty(self): + _, _, op, *_ = _import_ordering() + with _patch_order_playlist({}): + assert op([]) == [] + + def test_single_song_returns_unchanged(self): + _, _, op, *_ = _import_ordering() + assert op(["song1"]) == ["song1"] + + def test_two_songs_returns_both(self): + _, _, op, *_ = _import_ordering() + assert op(["song1", "song2"]) == ["song1", "song2"] + + def test_all_input_songs_in_output(self): + _, _, op, *_ = _import_ordering() + sd = {f"s{i}": {"tempo": 80+i*10, "energy": 0.02+i*0.02, "key": "C", "scale": "major"} for i in range(10)} + ids = list(sd.keys()) + with _patch_order_playlist(sd): + result = op(ids) + assert set(result) == set(ids) and len(result) == len(ids) + + def test_no_duplicates_in_output(self): + _, _, op, *_ = _import_ordering() + sd = {f"s{i}": {"tempo": 100+i*5, "energy": 0.05+i*0.01, "key": "G", "scale": "minor"} for i in range(15)} + ids = list(sd.keys()) + with _patch_order_playlist(sd): + result = op(ids) + assert len(result) == len(set(result)) + + def test_starts_from_25th_percentile_energy(self): + _, _, op, *_ = _import_ordering() + sd = {f"s{i}": {"tempo": 120, "energy": 0.02+i*0.01, "key": "C", "scale": "major"} for i in range(8)} + ids = list(sd.keys()) + with _patch_order_playlist(sd): + result = op(ids) + first_e = sd[result[0]]["energy"] + sorted_e = sorted(sd[s]["energy"] for s in ids) + expected_e = sorted_e[len(sorted_e) // 4] + assert first_e == pytest.approx(expected_e) + + def test_adjacent_songs_have_small_distance(self): + _, cd, op, *_ = _import_ordering() + import random + random.seed(42) + kl = ["C","G","D","A","E","B","F#","Db","Ab","Eb"] + sd = {f"s{i}": {"tempo": 80+i*8, "energy": 0.02+i*0.012, "key": kl[i%10], "scale": "major" if i%2==0 else "minor"} for i in range(12)} + ids = list(sd.keys()) + with _patch_order_playlist(sd): + 
ordered = op(ids) + def td(seq): + return sum(cd(sd[seq[i]], sd[seq[i+1]]) for i in range(len(seq)-1)) + od = td(ordered) + rd = [] + for _ in range(20): + s = list(ids) + random.shuffle(s) + rd.append(td(s)) + assert od <= sum(rd)/len(rd) + + def test_unorderable_songs_appended_at_end(self): + _, _, op, *_ = _import_ordering() + sd = {f"s{i}": {"tempo": 100+i*10, "energy": 0.05+i*0.02, "key": "C", "scale": "major"} for i in range(3)} + ids = ["s0", "s1", "s2", "s_missing"] + with _patch_order_playlist(sd): + result = op(ids) + assert result[-1] == "s_missing" + assert set(result) == set(ids) + + def test_no_db_rows_returns_original_order(self): + _, _, op, *_ = _import_ordering() + ids = ["a", "b", "c"] + with _patch_order_playlist({}): + assert op(ids) == ids + + def test_only_two_orderable_returns_original(self): + _, _, op, *_ = _import_ordering() + sd = { + "s0": {"tempo": 100, "energy": 0.05, "key": "C", "scale": "major"}, + "s1": {"tempo": 110, "energy": 0.06, "key": "G", "scale": "major"}, + } + ids = ["s0", "s1", "s_no_data"] + with _patch_order_playlist(sd): + assert op(ids) == ids + + +@pytest.mark.unit +class TestEnergyArc: + def test_energy_arc_false_deterministic(self): + _, _, op, *_ = _import_ordering() + sd = {f"s{i}": {"tempo": 120, "energy": 0.01+i*0.013, "key": "C", "scale": "major"} for i in range(12)} + ids = list(sd.keys()) + with _patch_order_playlist(sd): + r1 = op(ids, energy_arc=False) + with _patch_order_playlist(sd): + r2 = op(ids, energy_arc=False) + assert r1 == r2 + + def test_energy_arc_true_reshapes_10_plus(self): + _, _, op, *_ = _import_ordering() + sd = {f"s{i}": {"tempo": 120, "energy": 0.01+i*0.013, "key": "C", "scale": "major"} for i in range(12)} + ids = list(sd.keys()) + with _patch_order_playlist(sd): + r_arc = op(ids, energy_arc=True) + with _patch_order_playlist(sd): + r_no = op(ids, energy_arc=False) + assert r_arc != r_no + assert set(r_arc) == set(r_no) + + def test_energy_arc_skipped_under_10(self): + _, _, op, *_ = 
_import_ordering() + sd = {f"s{i}": {"tempo": 120, "energy": 0.01+i*0.02, "key": "C", "scale": "major"} for i in range(8)} + ids = list(sd.keys()) + with _patch_order_playlist(sd): + r1 = op(ids, energy_arc=True) + with _patch_order_playlist(sd): + r2 = op(ids, energy_arc=False) + assert r1 == r2 + + def test_apply_energy_arc_peak_in_middle(self): + *_, ea, _ = _import_ordering() + sd = {f"s{i}": {"tempo": 120, "energy": 0.01+i*0.012, "key": "C", "scale": "major"} for i in range(12)} + ids = [f"s{i}" for i in range(12)] + arc = ea(ids, sd) + assert set(arc) == set(ids) and len(arc) == 12 + energies = [sd[s]["energy"] for s in arc] + n = len(energies) + fq = sum(energies[:n//4]) / (n//4) + mid = sum(energies[n//3:2*n//3]) / (2*n//3 - n//3) + lq = sum(energies[3*n//4:]) / (n - 3*n//4) + assert mid > fq and mid > lq + + def test_apply_energy_arc_preserves_all(self): + *_, ea, _ = _import_ordering() + sd = {f"s{i}": {"tempo": 100, "energy": 0.01+i*0.01, "key": "D", "scale": "minor"} for i in range(15)} + ids = list(sd.keys()) + arc = ea(ids, sd) + assert set(arc) == set(ids) and len(arc) == len(ids) + + def test_apply_energy_arc_exact_10(self): + *_, ea, _ = _import_ordering() + sd = {f"s{i}": {"tempo": 120, "energy": 0.01+i*0.014, "key": "C", "scale": "major"} for i in range(10)} + ids = list(sd.keys()) + arc = ea(ids, sd) + assert set(arc) == set(ids) and len(arc) == 10 + + +@pytest.mark.unit +class TestEdgeCases: + def test_songs_with_missing_tempo(self): + _, _, op, *_ = _import_ordering() + sd = { + "s0": {"tempo": None, "energy": 0.05, "key": "C", "scale": "major"}, + "s1": {"tempo": 120, "energy": 0.06, "key": "G", "scale": "major"}, + "s2": {"tempo": None, "energy": 0.07, "key": "D", "scale": "minor"}, + } + with _patch_order_playlist(sd): + assert set(op(list(sd.keys()))) == set(sd.keys()) + + def test_songs_with_missing_energy(self): + _, _, op, *_ = _import_ordering() + sd = { + "s0": {"tempo": 100, "energy": None, "key": "C", "scale": "major"}, + "s1": 
{"tempo": 110, "energy": 0.05, "key": "G", "scale": "major"}, + "s2": {"tempo": 120, "energy": None, "key": "D", "scale": "minor"}, + } + with _patch_order_playlist(sd): + assert set(op(list(sd.keys()))) == set(sd.keys()) + + def test_songs_with_missing_key(self): + _, _, op, *_ = _import_ordering() + sd = { + "s0": {"tempo": 100, "energy": 0.05, "key": None, "scale": None}, + "s1": {"tempo": 110, "energy": 0.06, "key": "", "scale": ""}, + "s2": {"tempo": 120, "energy": 0.07, "key": "C", "scale": "major"}, + } + with _patch_order_playlist(sd): + assert set(op(list(sd.keys()))) == set(sd.keys()) + + def test_songs_with_all_missing_attributes(self): + _, _, op, *_ = _import_ordering() + sd = { + "s0": {"tempo": None, "energy": None, "key": None, "scale": None}, + "s1": {"tempo": None, "energy": None, "key": None, "scale": None}, + "s2": {"tempo": None, "energy": None, "key": None, "scale": None}, + } + with _patch_order_playlist(sd): + assert set(op(list(sd.keys()))) == set(sd.keys()) + + def test_all_songs_same_bpm_energy_key(self): + _, _, op, *_ = _import_ordering() + sd = {f"s{i}": {"tempo": 120, "energy": 0.08, "key": "C", "scale": "major"} for i in range(6)} + ids = list(sd.keys()) + with _patch_order_playlist(sd): + result = op(ids) + assert set(result) == set(ids) and len(result) == 6 + + @pytest.mark.slow + def test_large_playlist_completes(self): + _, _, op, *_ = _import_ordering() + n = 120 + keys = ["C","G","D","A","E","B","F#","Db","Ab","Eb","Bb","F"] + sd = {f"s{i}": {"tempo": 70+(i*7)%130, "energy": 0.01+(i*0.0012)%0.14, "key": keys[i%12], "scale": "major" if i%2==0 else "minor"} for i in range(n)} + ids = list(sd.keys()) + with _patch_order_playlist(sd): + result = op(ids) + assert len(result) == n and set(result) == set(ids) + + @pytest.mark.slow + def test_large_playlist_with_energy_arc(self): + _, _, op, *_ = _import_ordering() + n = 100 + keys = ["C","G","D","A","E","B","F#","Db","Ab","Eb","Bb","F"] + sd = {f"s{i}": {"tempo": 80+(i*5)%100, 
"energy": 0.01+(i*0.0014)%0.14, "key": keys[i%12], "scale": "major" if i%3!=0 else "minor"} for i in range(n)} + ids = list(sd.keys()) + with _patch_order_playlist(sd): + result = op(ids, energy_arc=True) + assert len(result) == n and set(result) == set(ids) + + def test_duplicate_ids_in_input(self): + _, _, op, *_ = _import_ordering() + sd = { + "s0": {"tempo": 100, "energy": 0.05, "key": "C", "scale": "major"}, + "s1": {"tempo": 110, "energy": 0.06, "key": "G", "scale": "major"}, + "s2": {"tempo": 120, "energy": 0.07, "key": "D", "scale": "minor"}, + } + with _patch_order_playlist(sd): + result = op(["s0", "s1", "s2", "s0"]) + assert len(result) >= 3 + + def test_zero_tempo_and_energy_songs(self): + _, _, op, *_ = _import_ordering() + sd = { + "s0": {"tempo": 0, "energy": 0, "key": "C", "scale": "major"}, + "s1": {"tempo": 0, "energy": 0, "key": "G", "scale": "major"}, + "s2": {"tempo": 0, "energy": 0, "key": "D", "scale": "minor"}, + } + with _patch_order_playlist(sd): + assert set(op(list(sd.keys()))) == set(sd.keys()) + + +@pytest.mark.unit +class TestCircleOfFifthsMap: + def test_all_12_chromatic_notes_mapped(self): + *_, cof = _import_ordering() + assert set(cof.values()) == set(range(12)) + + def test_enharmonic_pairs_same_position(self): + *_, cof = _import_ordering() + for a, b in [("F#","GB"),("C#","DB"),("G#","AB"),("D#","EB"),("A#","BB")]: + assert cof[a] == cof[b], f"{a} and {b} should be equal" + + def test_c_is_position_zero(self): + *_, cof = _import_ordering() + assert cof["C"] == 0 + + def test_g_is_position_one(self): + *_, cof = _import_ordering() + assert cof["G"] == 1 + + def test_f_is_position_eleven(self): + *_, cof = _import_ordering() + assert cof["F"] == 11 + + +@pytest.mark.unit +class TestApplyEnergyArcDirect: + def test_low_energy_at_start_and_end(self): + *_, ea, _ = _import_ordering() + sd = {f"s{i}": {"tempo": 120, "energy": float(i)} for i in range(12)} + ids = [f"s{i}" for i in range(12)] + arc = ea(ids, sd) + energies = 
[sd[s]["energy"] for s in arc] + assert energies[0] < 4.0 and energies[-1] < 4.0 + + def test_high_energy_in_middle(self): + *_, ea, _ = _import_ordering() + sd = {f"s{i}": {"tempo": 120, "energy": float(i)} for i in range(15)} + ids = [f"s{i}" for i in range(15)] + arc = ea(ids, sd) + energies = [sd[s]["energy"] for s in arc] + n = len(energies) + mid_sec = energies[n//3:2*n//3] + assert sum(mid_sec)/len(mid_sec) > sum(energies)/len(energies) + + def test_arc_with_identical_energies(self): + *_, ea, _ = _import_ordering() + sd = {f"s{i}": {"tempo": 120, "energy": 0.08} for i in range(12)} + ids = [f"s{i}" for i in range(12)] + arc = ea(ids, sd) + assert set(arc) == set(ids) and len(arc) == 12