diff --git a/backends/advanced/init.py b/backends/advanced/init.py index 4ea037b2..f40f660f 100644 --- a/backends/advanced/init.py +++ b/backends/advanced/init.py @@ -34,21 +34,31 @@ def __init__(self, args=None): self.console = Console() self.config: Dict[str, Any] = {} self.args = args or argparse.Namespace() - self.config_yml_path = Path("../../config/config.yml") # Main config at config/config.yml + self.config_yml_path = Path( + "../../config/config.yml" + ) # Main config at config/config.yml # Check if we're in the right directory if not Path("pyproject.toml").exists() or not Path("src").exists(): - self.console.print("[red][ERROR][/red] Please run this script from the backends/advanced directory") + self.console.print( + "[red][ERROR][/red] Please run this script from the backends/advanced directory" + ) sys.exit(1) # Initialize ConfigManager (single source of truth for config.yml) self.config_manager = ConfigManager(service_path="backends/advanced") - self.console.print(f"[blue][INFO][/blue] Using config.yml at: {self.config_manager.config_yml_path}") + self.console.print( + f"[blue][INFO][/blue] Using config.yml at: {self.config_manager.config_yml_path}" + ) # Verify config.yml exists - fail fast if missing if not self.config_manager.config_yml_path.exists(): - self.console.print(f"[red][ERROR][/red] config.yml not found at {self.config_manager.config_yml_path}") - self.console.print("[red][ERROR][/red] Run wizard.py from project root to create config.yml") + self.console.print( + f"[red][ERROR][/red] config.yml not found at {self.config_manager.config_yml_path}" + ) + self.console.print( + "[red][ERROR][/red] Run wizard.py from project root to create config.yml" + ) sys.exit(1) # Ensure plugins.yml exists (copy from template if missing) @@ -57,11 +67,7 @@ def __init__(self, args=None): def print_header(self, title: str): """Print a colorful header""" self.console.print() - panel = Panel( - Text(title, style="cyan bold"), - style="cyan", - expand=False - ) 
+ panel = Panel(Text(title, style="cyan bold"), style="cyan", expand=False) self.console.print(panel) self.console.print() @@ -84,19 +90,23 @@ def prompt_password(self, prompt: str) -> str: """Prompt for password (delegates to shared utility)""" return util_prompt_password(prompt, min_length=8, allow_generated=True) - def prompt_choice(self, prompt: str, choices: Dict[str, str], default: str = "1") -> str: + def prompt_choice( + self, prompt: str, choices: Dict[str, str], default: str = "1" + ) -> str: """Prompt for a choice from options""" self.console.print(prompt) for key, desc in choices.items(): self.console.print(f" {key}) {desc}") self.console.print() - + while True: try: choice = Prompt.ask("Enter choice", default=default) if choice in choices: return choice - self.console.print(f"[red]Invalid choice. Please select from {list(choices.keys())}[/red]") + self.console.print( + f"[red]Invalid choice. Please select from {list(choices.keys())}[/red]" + ) except EOFError: self.console.print(f"Using default choice: {default}") return default @@ -108,11 +118,19 @@ def _ensure_plugins_yml_exists(self): if not plugins_yml.exists(): if plugins_template.exists(): - self.console.print("[blue][INFO][/blue] plugins.yml not found, creating from template...") + self.console.print( + "[blue][INFO][/blue] plugins.yml not found, creating from template..." 
+ ) shutil.copy2(plugins_template, plugins_yml) - self.console.print(f"[green]✅[/green] Created {plugins_yml} from template") - self.console.print("[yellow][NOTE][/yellow] Edit config/plugins.yml to configure plugins") - self.console.print("[yellow][NOTE][/yellow] Set HA_TOKEN in .env for Home Assistant integration") + self.console.print( + f"[green]✅[/green] Created {plugins_yml} from template" + ) + self.console.print( + "[yellow][NOTE][/yellow] Edit config/plugins.yml to configure plugins" + ) + self.console.print( + "[yellow][NOTE][/yellow] Set HA_TOKEN in .env for Home Assistant integration" + ) else: raise RuntimeError( f"Template file not found: {plugins_template}\n" @@ -128,7 +146,9 @@ def backup_existing_env(self): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") backup_path = f".env.backup.{timestamp}" shutil.copy2(env_path, backup_path) - self.console.print(f"[blue][INFO][/blue] Backed up existing .env file to {backup_path}") + self.console.print( + f"[blue][INFO][/blue] Backed up existing .env file to {backup_path}" + ) def read_existing_env_value(self, key: str) -> str: """Read a value from existing .env file (delegates to shared utility)""" @@ -138,8 +158,14 @@ def mask_api_key(self, key: str, show_chars: int = 5) -> str: """Mask API key (delegates to shared utility)""" return mask_value(key, show_chars) - def prompt_with_existing_masked(self, prompt_text: str, env_key: str, placeholders: list, - is_password: bool = False, default: str = "") -> str: + def prompt_with_existing_masked( + self, + prompt_text: str, + env_key: str, + placeholders: list, + is_password: bool = False, + default: str = "", + ) -> str: """ Prompt for a value, showing masked existing value from .env if present. Delegates to shared utility from setup_utils. 
@@ -161,10 +187,9 @@ def prompt_with_existing_masked(self, prompt_text: str, env_key: str, placeholde env_key=env_key, placeholders=placeholders, is_password=is_password, - default=default + default=default, ) - def setup_authentication(self): """Configure authentication settings""" self.print_section("Authentication Setup") @@ -186,13 +211,17 @@ def setup_authentication(self): ) self.config["ADMIN_PASSWORD"] = password else: - self.config["ADMIN_PASSWORD"] = self.prompt_password("Admin password (min 8 chars)") + self.config["ADMIN_PASSWORD"] = self.prompt_password( + "Admin password (min 8 chars)" + ) # Preserve existing AUTH_SECRET_KEY to avoid invalidating JWTs existing_secret = self.read_existing_env_value("AUTH_SECRET_KEY") if existing_secret: self.config["AUTH_SECRET_KEY"] = existing_secret - self.console.print("[blue][INFO][/blue] Reusing existing AUTH_SECRET_KEY (existing JWT tokens remain valid)") + self.console.print( + "[blue][INFO][/blue] Reusing existing AUTH_SECRET_KEY (existing JWT tokens remain valid)" + ) else: self.config["AUTH_SECRET_KEY"] = secrets.token_hex(32) @@ -201,9 +230,14 @@ def setup_authentication(self): def setup_transcription(self): """Configure transcription provider - updates config.yml and .env""" # Check if transcription provider was provided via command line - if hasattr(self.args, 'transcription_provider') and self.args.transcription_provider: + if ( + hasattr(self.args, "transcription_provider") + and self.args.transcription_provider + ): provider = self.args.transcription_provider - self.console.print(f"[green]✅[/green] Transcription: {provider} (configured via wizard)") + self.console.print( + f"[green]✅[/green] Transcription: {provider} (configured via wizard)" + ) # Map provider to choice if provider == "deepgram": @@ -223,21 +257,27 @@ def setup_transcription(self): else: self.print_section("Speech-to-Text Configuration") - self.console.print("[blue][INFO][/blue] Provider selection is configured in config.yml 
(defaults.stt)") + self.console.print( + "[blue][INFO][/blue] Provider selection is configured in config.yml (defaults.stt)" + ) self.console.print("[blue][INFO][/blue] API keys are stored in .env") self.console.print() # Interactive prompt - is_macos = platform.system() == 'Darwin' + is_macos = platform.system() == "Darwin" if is_macos: parakeet_desc = "Offline (Parakeet ASR - CPU-based, runs locally)" vibevoice_desc = "Offline (VibeVoice - CPU-based, built-in diarization)" else: parakeet_desc = "Offline (Parakeet ASR - GPU recommended, runs locally)" - vibevoice_desc = "Offline (VibeVoice - GPU recommended, built-in diarization)" + vibevoice_desc = ( + "Offline (VibeVoice - GPU recommended, built-in diarization)" + ) - qwen3_desc = "Offline (Qwen3-ASR - GPU required, 52 languages, streaming + batch)" + qwen3_desc = ( + "Offline (Qwen3-ASR - GPU required, 52 languages, streaming + batch)" + ) smallest_desc = "Smallest.ai Pulse (cloud-based, fast, requires API key)" @@ -247,10 +287,12 @@ def setup_transcription(self): "3": vibevoice_desc, "4": qwen3_desc, "5": smallest_desc, - "6": "None (skip transcription setup)" + "6": "None (skip transcription setup)", } - choice = self.prompt_choice("Choose your transcription provider:", choices, "1") + choice = self.prompt_choice( + "Choose your transcription provider:", choices, "1" + ) if choice == "1": self.console.print("[blue][INFO][/blue] Deepgram selected") @@ -260,9 +302,9 @@ def setup_transcription(self): api_key = self.prompt_with_existing_masked( prompt_text="Deepgram API key (leave empty to skip)", env_key="DEEPGRAM_API_KEY", - placeholders=['your_deepgram_api_key_here', 'your-deepgram-key-here'], + placeholders=["your_deepgram_api_key_here", "your-deepgram-key-here"], is_password=True, - default="" + default="", ) if api_key: @@ -272,14 +314,20 @@ def setup_transcription(self): # Update config.yml to use Deepgram self.config_manager.update_config_defaults({"stt": "stt-deepgram"}) - 
self.console.print("[green][SUCCESS][/green] Deepgram configured in config.yml and .env") + self.console.print( + "[green][SUCCESS][/green] Deepgram configured in config.yml and .env" + ) self.console.print("[blue][INFO][/blue] Set defaults.stt: stt-deepgram") else: - self.console.print("[yellow][WARNING][/yellow] No API key provided - transcription will not work") + self.console.print( + "[yellow][WARNING][/yellow] No API key provided - transcription will not work" + ) elif choice == "2": self.console.print("[blue][INFO][/blue] Offline Parakeet ASR selected") - parakeet_url = self.prompt_value("Parakeet ASR URL", "http://host.docker.internal:8767") + parakeet_url = self.prompt_value( + "Parakeet ASR URL (without http:// prefix)", "host.docker.internal:8767" + ) # Write URL to .env for ${PARAKEET_ASR_URL} placeholder in config.yml self.config["PARAKEET_ASR_URL"] = parakeet_url @@ -287,13 +335,24 @@ def setup_transcription(self): # Update config.yml to use Parakeet self.config_manager.update_config_defaults({"stt": "stt-parakeet-batch"}) - self.console.print("[green][SUCCESS][/green] Parakeet configured in config.yml and .env") - self.console.print("[blue][INFO][/blue] Set defaults.stt: stt-parakeet-batch") - self.console.print("[yellow][WARNING][/yellow] Remember to start Parakeet service: cd ../../extras/asr-services && docker compose up nemo-asr") + self.console.print( + "[green][SUCCESS][/green] Parakeet configured in config.yml and .env" + ) + self.console.print( + "[blue][INFO][/blue] Set defaults.stt: stt-parakeet-batch" + ) + self.console.print( + "[yellow][WARNING][/yellow] Remember to start Parakeet service: cd ../../extras/asr-services && docker compose up nemo-asr" + ) elif choice == "3": - self.console.print("[blue][INFO][/blue] Offline VibeVoice ASR selected (built-in speaker diarization)") - vibevoice_url = self.prompt_value("VibeVoice ASR URL", "http://host.docker.internal:8767") + self.console.print( + "[blue][INFO][/blue] Offline VibeVoice ASR 
selected (built-in speaker diarization)" + ) + vibevoice_url = self.prompt_value( + "VibeVoice ASR URL (without http:// prefix)", + "host.docker.internal:8767", + ) # Write URL to .env for ${VIBEVOICE_ASR_URL} placeholder in config.yml self.config["VIBEVOICE_ASR_URL"] = vibevoice_url @@ -301,14 +360,24 @@ def setup_transcription(self): # Update config.yml to use VibeVoice self.config_manager.update_config_defaults({"stt": "stt-vibevoice"}) - self.console.print("[green][SUCCESS][/green] VibeVoice configured in config.yml and .env") + self.console.print( + "[green][SUCCESS][/green] VibeVoice configured in config.yml and .env" + ) self.console.print("[blue][INFO][/blue] Set defaults.stt: stt-vibevoice") - self.console.print("[blue][INFO][/blue] VibeVoice provides built-in speaker diarization - pyannote will be skipped") - self.console.print("[yellow][WARNING][/yellow] Remember to start VibeVoice service: cd ../../extras/asr-services && docker compose up vibevoice-asr") + self.console.print( + "[blue][INFO][/blue] VibeVoice provides built-in speaker diarization - pyannote will be skipped" + ) + self.console.print( + "[yellow][WARNING][/yellow] Remember to start VibeVoice service: cd ../../extras/asr-services && docker compose up vibevoice-asr" + ) elif choice == "4": - self.console.print("[blue][INFO][/blue] Qwen3-ASR selected (52 languages, streaming + batch via vLLM)") - qwen3_url = self.prompt_value("Qwen3-ASR URL", "http://host.docker.internal:8767") + self.console.print( + "[blue][INFO][/blue] Qwen3-ASR selected (52 languages, streaming + batch via vLLM)" + ) + qwen3_url = self.prompt_value( + "Qwen3-ASR URL", "http://host.docker.internal:8767" + ) # Write URL to .env for ${QWEN3_ASR_URL} placeholder in config.yml self.config["QWEN3_ASR_URL"] = qwen3_url.replace("http://", "").rstrip("/") @@ -320,9 +389,13 @@ def setup_transcription(self): # Update config.yml to use Qwen3-ASR self.config_manager.update_config_defaults({"stt": "stt-qwen3-asr"}) - 
self.console.print("[green][SUCCESS][/green] Qwen3-ASR configured in config.yml and .env") + self.console.print( + "[green][SUCCESS][/green] Qwen3-ASR configured in config.yml and .env" + ) self.console.print("[blue][INFO][/blue] Set defaults.stt: stt-qwen3-asr") - self.console.print("[yellow][WARNING][/yellow] Remember to start Qwen3-ASR: cd ../../extras/asr-services && docker compose up qwen3-asr-wrapper qwen3-asr-bridge -d") + self.console.print( + "[yellow][WARNING][/yellow] Remember to start Qwen3-ASR: cd ../../extras/asr-services && docker compose up qwen3-asr-wrapper qwen3-asr-bridge -d" + ) elif choice == "5": self.console.print("[blue][INFO][/blue] Smallest.ai Pulse selected") @@ -332,9 +405,9 @@ def setup_transcription(self): api_key = self.prompt_with_existing_masked( prompt_text="Smallest.ai API key (leave empty to skip)", env_key="SMALLEST_API_KEY", - placeholders=['your_smallest_api_key_here', 'your-smallest-key-here'], + placeholders=["your_smallest_api_key_here", "your-smallest-key-here"], is_password=True, - default="" + default="", ) if api_key: @@ -342,16 +415,21 @@ def setup_transcription(self): self.config["SMALLEST_API_KEY"] = api_key # Update config.yml to use Smallest.ai (batch + streaming) - self.config_manager.update_config_defaults({ - "stt": "stt-smallest", - "stt_stream": "stt-smallest-stream" - }) + self.config_manager.update_config_defaults( + {"stt": "stt-smallest", "stt_stream": "stt-smallest-stream"} + ) - self.console.print("[green][SUCCESS][/green] Smallest.ai configured in config.yml and .env") + self.console.print( + "[green][SUCCESS][/green] Smallest.ai configured in config.yml and .env" + ) self.console.print("[blue][INFO][/blue] Set defaults.stt: stt-smallest") - self.console.print("[blue][INFO][/blue] Set defaults.stt_stream: stt-smallest-stream") + self.console.print( + "[blue][INFO][/blue] Set defaults.stt_stream: stt-smallest-stream" + ) else: - self.console.print("[yellow][WARNING][/yellow] No API key provided - 
transcription will not work") + self.console.print( + "[yellow][WARNING][/yellow] No API key provided - transcription will not work" + ) elif choice == "6": self.console.print("[blue][INFO][/blue] Skipping transcription setup") @@ -362,11 +440,16 @@ def setup_streaming_provider(self): When a different streaming provider is specified, sets defaults.stt_stream and enables always_batch_retranscribe (batch provider was set by setup_transcription). """ - if not hasattr(self.args, 'streaming_provider') or not self.args.streaming_provider: + if ( + not hasattr(self.args, "streaming_provider") + or not self.args.streaming_provider + ): return streaming_provider = self.args.streaming_provider - self.console.print(f"\n[green]✅[/green] Streaming provider: {streaming_provider} (configured via wizard)") + self.console.print( + f"\n[green]✅[/green] Streaming provider: {streaming_provider} (configured via wizard)" + ) # Map streaming provider to stt_stream config value provider_to_stt_stream = { @@ -377,7 +460,9 @@ def setup_streaming_provider(self): stream_stt = provider_to_stt_stream.get(streaming_provider) if not stream_stt: - self.console.print(f"[yellow][WARNING][/yellow] Unknown streaming provider: {streaming_provider}") + self.console.print( + f"[yellow][WARNING][/yellow] Unknown streaming provider: {streaming_provider}" + ) return # Set stt_stream (batch stt was already set by setup_transcription) @@ -385,11 +470,11 @@ def setup_streaming_provider(self): # Enable always_batch_retranscribe full_config = self.config_manager.get_full_config() - if 'backend' not in full_config: - full_config['backend'] = {} - if 'transcription' not in full_config['backend']: - full_config['backend']['transcription'] = {} - full_config['backend']['transcription']['always_batch_retranscribe'] = True + if "backend" not in full_config: + full_config["backend"] = {} + if "transcription" not in full_config["backend"]: + full_config["backend"]["transcription"] = {} + 
full_config["backend"]["transcription"]["always_batch_retranscribe"] = True self.config_manager.save_full_config(full_config) self.console.print(f"[blue][INFO][/blue] Set defaults.stt_stream: {stream_stt}") @@ -397,33 +482,47 @@ def setup_streaming_provider(self): # Prompt for streaming provider env vars if not already set if streaming_provider == "deepgram": - existing_key = read_env_value('.env', 'DEEPGRAM_API_KEY') - if not existing_key or existing_key in ('your_deepgram_api_key_here', 'your-deepgram-key-here'): + existing_key = read_env_value(".env", "DEEPGRAM_API_KEY") + if not existing_key or existing_key in ( + "your_deepgram_api_key_here", + "your-deepgram-key-here", + ): api_key = self.prompt_with_existing_masked( prompt_text="Deepgram API key for streaming", env_key="DEEPGRAM_API_KEY", - placeholders=['your_deepgram_api_key_here', 'your-deepgram-key-here'], + placeholders=[ + "your_deepgram_api_key_here", + "your-deepgram-key-here", + ], is_password=True, - default="" + default="", ) if api_key: self.config["DEEPGRAM_API_KEY"] = api_key elif streaming_provider == "smallest": - existing_key = read_env_value('.env', 'SMALLEST_API_KEY') - if not existing_key or existing_key in ('your_smallest_api_key_here', 'your-smallest-key-here'): + existing_key = read_env_value(".env", "SMALLEST_API_KEY") + if not existing_key or existing_key in ( + "your_smallest_api_key_here", + "your-smallest-key-here", + ): api_key = self.prompt_with_existing_masked( prompt_text="Smallest.ai API key for streaming", env_key="SMALLEST_API_KEY", - placeholders=['your_smallest_api_key_here', 'your-smallest-key-here'], + placeholders=[ + "your_smallest_api_key_here", + "your-smallest-key-here", + ], is_password=True, - default="" + default="", ) if api_key: self.config["SMALLEST_API_KEY"] = api_key elif streaming_provider == "qwen3-asr": - existing_url = read_env_value('.env', 'QWEN3_ASR_STREAM_URL') + existing_url = read_env_value(".env", "QWEN3_ASR_STREAM_URL") if not existing_url: - 
qwen3_url = self.prompt_value("Qwen3-ASR streaming URL", "http://host.docker.internal:8769") + qwen3_url = self.prompt_value( + "Qwen3-ASR streaming URL", "http://host.docker.internal:8769" + ) stream_host = qwen3_url.replace("http://", "").rstrip("/") self.config["QWEN3_ASR_STREAM_URL"] = stream_host @@ -431,51 +530,189 @@ def setup_llm(self): """Configure LLM provider - updates config.yml and .env""" self.print_section("LLM Provider Configuration") - self.console.print("[blue][INFO][/blue] LLM configuration will be saved to config.yml") + self.console.print( + "[blue][INFO][/blue] LLM configuration will be saved to config.yml" + ) self.console.print() choices = { "1": "OpenAI (GPT-4, GPT-3.5 - requires API key)", "2": "Ollama (local models - runs locally)", - "3": "Skip (no memory extraction)" + "3": "OpenAI-Compatible (custom endpoint - Groq, Together AI, LM Studio, etc.)", + "4": "Skip (no memory extraction)", } choice = self.prompt_choice("Which LLM provider will you use?", choices, "1") if choice == "1": self.console.print("[blue][INFO][/blue] OpenAI selected") - self.console.print("Get your API key from: https://platform.openai.com/api-keys") + self.console.print( + "Get your API key from: https://platform.openai.com/api-keys" + ) # Use the new masked prompt function api_key = self.prompt_with_existing_masked( prompt_text="OpenAI API key (leave empty to skip)", env_key="OPENAI_API_KEY", - placeholders=['your_openai_api_key_here', 'your-openai-key-here'], + placeholders=["your_openai_api_key_here", "your-openai-key-here"], is_password=True, - default="" + default="", ) if api_key: self.config["OPENAI_API_KEY"] = api_key # Update config.yml to use OpenAI models - self.config_manager.update_config_defaults({"llm": "openai-llm", "embedding": "openai-embed"}) - self.console.print("[green][SUCCESS][/green] OpenAI configured in config.yml") + self.config_manager.update_config_defaults( + {"llm": "openai-llm", "embedding": "openai-embed"} + ) + self.console.print( + 
"[green][SUCCESS][/green] OpenAI configured in config.yml" + ) self.console.print("[blue][INFO][/blue] Set defaults.llm: openai-llm") - self.console.print("[blue][INFO][/blue] Set defaults.embedding: openai-embed") + self.console.print( + "[blue][INFO][/blue] Set defaults.embedding: openai-embed" + ) else: - self.console.print("[yellow][WARNING][/yellow] No API key provided - memory extraction will not work") + self.console.print( + "[yellow][WARNING][/yellow] No API key provided - memory extraction will not work" + ) elif choice == "2": self.console.print("[blue][INFO][/blue] Ollama selected") # Update config.yml to use Ollama models - self.config_manager.update_config_defaults({"llm": "local-llm", "embedding": "local-embed"}) - self.console.print("[green][SUCCESS][/green] Ollama configured in config.yml") + self.config_manager.update_config_defaults( + {"llm": "local-llm", "embedding": "local-embed"} + ) + self.console.print( + "[green][SUCCESS][/green] Ollama configured in config.yml" + ) self.console.print("[blue][INFO][/blue] Set defaults.llm: local-llm") - self.console.print("[blue][INFO][/blue] Set defaults.embedding: local-embed") - self.console.print("[yellow][WARNING][/yellow] Make sure Ollama is running and models are pulled") + self.console.print( + "[blue][INFO][/blue] Set defaults.embedding: local-embed" + ) + self.console.print( + "[yellow][WARNING][/yellow] Make sure Ollama is running and models are pulled" + ) elif choice == "3": - self.console.print("[blue][INFO][/blue] Skipping LLM setup - memory extraction disabled") + self.console.print( + "[blue][INFO][/blue] OpenAI-Compatible custom endpoint selected" + ) + self.console.print( + "This works with any provider that exposes an OpenAI-compatible API" + ) + self.console.print("(e.g., Groq, Together AI, LM Studio, vLLM, etc.)") + self.console.print() + + # Prompt for base URL (required) + base_url = self.prompt_value( + "API Base URL (e.g., https://api.groq.com/openai/v1)", "" + ) + if not 
base_url: + self.console.print( + "[yellow][WARNING][/yellow] No base URL provided - skipping custom LLM setup" + ) + else: + # Prompt for API key + api_key = self.prompt_with_existing_masked( + prompt_text="API Key (leave empty if not required)", + env_key="CUSTOM_LLM_API_KEY", + placeholders=["your_custom_llm_api_key_here"], + is_password=True, + default="", + ) + if api_key: + self.config["CUSTOM_LLM_API_KEY"] = api_key + + # Prompt for model name (required) + model_name = self.prompt_value( + "LLM Model name (e.g., llama-3.1-70b-versatile)", "" + ) + if not model_name: + self.console.print( + "[yellow][WARNING][/yellow] No model name provided - skipping custom LLM setup" + ) + else: + # Create LLM model entry + llm_model = { + "name": "custom-llm", + "description": "Custom OpenAI-compatible LLM", + "model_type": "llm", + "model_provider": "openai", + "api_family": "openai", + "model_name": model_name, + "model_url": base_url, + "api_key": "${oc.env:CUSTOM_LLM_API_KEY,''}", + "model_params": {"temperature": 0.2, "max_tokens": 2000}, + "model_output": "json", + } + self.config_manager.add_or_update_model(llm_model) + + # Prompt for optional embedding model + embedding_model_name = self.prompt_value( + "Embedding model name (leave empty to use Ollama local-embed)", + "", + ) + + if embedding_model_name: + embed_dim_str = self.prompt_value( + "Embedding dimensions (e.g. 
1536 for text-embedding-3-small, 3072 for text-embedding-3-large)", + "1536", + ) + try: + embedding_dimensions = int(embed_dim_str) + except ValueError: + self.console.print( + f"[yellow][WARNING][/yellow] Invalid dimensions '{embed_dim_str}', using default 1536" + ) + embedding_dimensions = 1536 + + embed_model = { + "name": "custom-embed", + "description": "Custom OpenAI-compatible embeddings", + "model_type": "embedding", + "model_provider": "openai", + "api_family": "openai", + "model_name": embedding_model_name, + "model_url": base_url, + "api_key": "${oc.env:CUSTOM_LLM_API_KEY,''}", + "embedding_dimensions": embedding_dimensions, + "model_output": "vector", + } + self.config_manager.add_or_update_model(embed_model) + self.config_manager.update_config_defaults( + {"llm": "custom-llm", "embedding": "custom-embed"} + ) + self.console.print( + "[green][SUCCESS][/green] Custom LLM and embedding configured in config.yml" + ) + self.console.print( + "[blue][INFO][/blue] Set defaults.llm: custom-llm" + ) + self.console.print( + "[blue][INFO][/blue] Set defaults.embedding: custom-embed" + ) + else: + self.config_manager.update_config_defaults( + {"llm": "custom-llm", "embedding": "local-embed"} + ) + self.console.print( + "[green][SUCCESS][/green] Custom LLM configured in config.yml" + ) + self.console.print( + "[blue][INFO][/blue] Set defaults.llm: custom-llm" + ) + self.console.print( + "[blue][INFO][/blue] Set defaults.embedding: local-embed (Ollama)" + ) + self.console.print( + "[yellow][WARNING][/yellow] Make sure Ollama is running for embeddings" + ) + + elif choice == "4": + self.console.print( + "[blue][INFO][/blue] Skipping LLM setup - memory extraction disabled" + ) # Disable memory extraction in config.yml self.config_manager.update_memory_config({"extraction": {"enabled": False}}) @@ -491,80 +728,115 @@ def setup_memory(self): choice = self.prompt_choice("Choose your memory storage backend:", choices, "1") if choice == "1": -
self.console.print("[blue][INFO][/blue] Chronicle Native memory provider selected") + self.console.print( + "[blue][INFO][/blue] Chronicle Native memory provider selected" + ) qdrant_url = self.prompt_value("Qdrant URL", "qdrant") self.config["QDRANT_BASE_URL"] = qdrant_url # Update config.yml (also updates .env automatically) self.config_manager.update_memory_config({"provider": "chronicle"}) - self.console.print("[green][SUCCESS][/green] Chronicle memory provider configured in config.yml and .env") + self.console.print( + "[green][SUCCESS][/green] Chronicle memory provider configured in config.yml and .env" + ) elif choice == "2": self.console.print("[blue][INFO][/blue] OpenMemory MCP selected") - mcp_url = self.prompt_value("OpenMemory MCP server URL", "http://host.docker.internal:8765") + mcp_url = self.prompt_value( + "OpenMemory MCP server URL", "http://host.docker.internal:8765" + ) client_name = self.prompt_value("OpenMemory client name", "chronicle") user_id = self.prompt_value("OpenMemory user ID", "openmemory") timeout = self.prompt_value("OpenMemory timeout (seconds)", "30") # Update config.yml with OpenMemory MCP settings (also updates .env automatically) - self.config_manager.update_memory_config({ - "provider": "openmemory_mcp", - "openmemory_mcp": { - "server_url": mcp_url, - "client_name": client_name, - "user_id": user_id, - "timeout": int(timeout) + self.config_manager.update_memory_config( + { + "provider": "openmemory_mcp", + "openmemory_mcp": { + "server_url": mcp_url, + "client_name": client_name, + "user_id": user_id, + "timeout": int(timeout), + }, } - }) - self.console.print("[green][SUCCESS][/green] OpenMemory MCP configured in config.yml and .env") - self.console.print("[yellow][WARNING][/yellow] Remember to start OpenMemory: cd ../../extras/openmemory-mcp && docker compose up -d") + ) + self.console.print( + "[green][SUCCESS][/green] OpenMemory MCP configured in config.yml and .env" + ) + self.console.print( + 
"[yellow][WARNING][/yellow] Remember to start OpenMemory: cd ../../extras/openmemory-mcp && docker compose up -d" + ) def setup_optional_services(self): """Configure optional services""" # Check if speaker service URL provided via args - has_speaker_arg = hasattr(self.args, 'speaker_service_url') and self.args.speaker_service_url - has_asr_arg = hasattr(self.args, 'parakeet_asr_url') and self.args.parakeet_asr_url + has_speaker_arg = ( + hasattr(self.args, "speaker_service_url") and self.args.speaker_service_url + ) + has_asr_arg = ( + hasattr(self.args, "parakeet_asr_url") and self.args.parakeet_asr_url + ) if has_speaker_arg: self.config["SPEAKER_SERVICE_URL"] = self.args.speaker_service_url - self.console.print(f"[green]✅[/green] Speaker Recognition: {self.args.speaker_service_url} (configured via wizard)") + self.console.print( + f"[green]✅[/green] Speaker Recognition: {self.args.speaker_service_url} (configured via wizard)" + ) if has_asr_arg: self.config["PARAKEET_ASR_URL"] = self.args.parakeet_asr_url - self.console.print(f"[green]✅[/green] Parakeet ASR: {self.args.parakeet_asr_url} (configured via wizard)") + self.console.print( + f"[green]✅[/green] Parakeet ASR: {self.args.parakeet_asr_url} (configured via wizard)" + ) # Only show interactive section if not all configured via args if not has_speaker_arg: try: - enable_speaker = Confirm.ask("Enable Speaker Recognition?", default=False) + enable_speaker = Confirm.ask( + "Enable Speaker Recognition?", default=False + ) except EOFError: self.console.print("Using default: No") enable_speaker = False - + if enable_speaker: - speaker_url = self.prompt_value("Speaker Recognition service URL", "http://host.docker.internal:8001") + speaker_url = self.prompt_value( + "Speaker Recognition service URL", + "http://host.docker.internal:8001", + ) self.config["SPEAKER_SERVICE_URL"] = speaker_url - self.console.print("[green][SUCCESS][/green] Speaker Recognition configured") - self.console.print("[blue][INFO][/blue] Start 
with: cd ../../extras/speaker-recognition && docker compose up -d") - + self.console.print( + "[green][SUCCESS][/green] Speaker Recognition configured" + ) + self.console.print( + "[blue][INFO][/blue] Start with: cd ../../extras/speaker-recognition && docker compose up -d" + ) + # Check if Tailscale auth key provided via args - if hasattr(self.args, 'ts_authkey') and self.args.ts_authkey: + if hasattr(self.args, "ts_authkey") and self.args.ts_authkey: self.config["TS_AUTHKEY"] = self.args.ts_authkey - self.console.print(f"[green][SUCCESS][/green] Tailscale auth key configured (Docker integration enabled)") + self.console.print( + f"[green][SUCCESS][/green] Tailscale auth key configured (Docker integration enabled)" + ) def setup_neo4j(self): """Configure Neo4j credentials (always required - used by Knowledge Graph)""" - neo4j_password = getattr(self.args, 'neo4j_password', None) + neo4j_password = getattr(self.args, "neo4j_password", None) if neo4j_password: - self.console.print(f"[green]✅[/green] Neo4j: password configured via wizard") + self.console.print( + f"[green]✅[/green] Neo4j: password configured via wizard" + ) else: # Interactive prompt (standalone init.py run) self.console.print() self.console.print("[bold cyan]Neo4j Configuration[/bold cyan]") - self.console.print("Neo4j is used for Knowledge Graph (entity/relationship extraction)") + self.console.print( + "Neo4j is used for Knowledge Graph (entity/relationship extraction)" + ) self.console.print() neo4j_password = self.prompt_password("Neo4j password (min 8 chars)") @@ -575,49 +847,54 @@ def setup_neo4j(self): def setup_obsidian(self): """Configure Obsidian integration (optional feature flag only - Neo4j credentials handled by setup_neo4j)""" - if hasattr(self.args, 'enable_obsidian') and self.args.enable_obsidian: + if hasattr(self.args, "enable_obsidian") and self.args.enable_obsidian: enable_obsidian = True - self.console.print(f"[green]✅[/green] Obsidian: enabled (configured via wizard)") + 
self.console.print( + f"[green]✅[/green] Obsidian: enabled (configured via wizard)" + ) else: # Interactive prompt (fallback) self.console.print() self.console.print("[bold cyan]Obsidian Integration (Optional)[/bold cyan]") - self.console.print("Enable graph-based knowledge management for Obsidian vault notes") + self.console.print( + "Enable graph-based knowledge management for Obsidian vault notes" + ) self.console.print() try: - enable_obsidian = Confirm.ask("Enable Obsidian integration?", default=False) + enable_obsidian = Confirm.ask( + "Enable Obsidian integration?", default=False + ) except EOFError: self.console.print("Using default: No") enable_obsidian = False if enable_obsidian: - self.config_manager.update_memory_config({ - "obsidian": { - "enabled": True, - "neo4j_host": "neo4j", - "timeout": 30 - } - }) + self.config_manager.update_memory_config( + {"obsidian": {"enabled": True, "neo4j_host": "neo4j", "timeout": 30}} + ) self.console.print("[green][SUCCESS][/green] Obsidian integration enabled") else: - self.config_manager.update_memory_config({ - "obsidian": { - "enabled": False, - "neo4j_host": "neo4j", - "timeout": 30 - } - }) + self.config_manager.update_memory_config( + {"obsidian": {"enabled": False, "neo4j_host": "neo4j", "timeout": 30}} + ) self.console.print("[blue][INFO][/blue] Obsidian integration disabled") def setup_knowledge_graph(self): """Configure Knowledge Graph (Neo4j-based entity/relationship extraction - enabled by default)""" - if hasattr(self.args, 'enable_knowledge_graph') and self.args.enable_knowledge_graph: + if ( + hasattr(self.args, "enable_knowledge_graph") + and self.args.enable_knowledge_graph + ): enable_kg = True else: self.console.print() - self.console.print("[bold cyan]Knowledge Graph (Entity Extraction)[/bold cyan]") - self.console.print("Extract people, places, organizations, events, and tasks from conversations") + self.console.print( + "[bold cyan]Knowledge Graph (Entity Extraction)[/bold cyan]" + ) + 
self.console.print( + "Extract people, places, organizations, events, and tasks from conversations" + ) self.console.print() try: @@ -627,56 +904,77 @@ def setup_knowledge_graph(self): enable_kg = True if enable_kg: - self.config_manager.update_memory_config({ - "knowledge_graph": { - "enabled": True, - "neo4j_host": "neo4j", - "timeout": 30 + self.config_manager.update_memory_config( + { + "knowledge_graph": { + "enabled": True, + "neo4j_host": "neo4j", + "timeout": 30, + } } - }) + ) self.console.print("[green][SUCCESS][/green] Knowledge Graph enabled") - self.console.print("[blue][INFO][/blue] Entities and relationships will be extracted from conversations") + self.console.print( + "[blue][INFO][/blue] Entities and relationships will be extracted from conversations" + ) else: - self.config_manager.update_memory_config({ - "knowledge_graph": { - "enabled": False, - "neo4j_host": "neo4j", - "timeout": 30 + self.config_manager.update_memory_config( + { + "knowledge_graph": { + "enabled": False, + "neo4j_host": "neo4j", + "timeout": 30, + } } - }) + ) self.console.print("[blue][INFO][/blue] Knowledge Graph disabled") def setup_langfuse(self): """Configure LangFuse observability and prompt management""" self.console.print() - self.console.print("[bold cyan]LangFuse Observability & Prompt Management[/bold cyan]") + self.console.print( + "[bold cyan]LangFuse Observability & Prompt Management[/bold cyan]" + ) # Check if keys were passed from wizard (langfuse init already ran) - langfuse_pub = getattr(self.args, 'langfuse_public_key', None) - langfuse_sec = getattr(self.args, 'langfuse_secret_key', None) + langfuse_pub = getattr(self.args, "langfuse_public_key", None) + langfuse_sec = getattr(self.args, "langfuse_secret_key", None) if langfuse_pub and langfuse_sec: # Auto-configure from wizard — no prompts needed - langfuse_host = getattr(self.args, 'langfuse_host', None) or "http://langfuse-web:3000" + langfuse_host = ( + getattr(self.args, "langfuse_host", None) or 
"http://langfuse-web:3000" + ) self.config["LANGFUSE_HOST"] = langfuse_host self.config["LANGFUSE_PUBLIC_KEY"] = langfuse_pub self.config["LANGFUSE_SECRET_KEY"] = langfuse_sec self.config["LANGFUSE_BASE_URL"] = langfuse_host # Derive browser-accessible URL for deep-links - public_url = getattr(self.args, 'langfuse_public_url', None) or "http://localhost:3002" + public_url = ( + getattr(self.args, "langfuse_public_url", None) + or "http://localhost:3002" + ) self._save_langfuse_public_url(public_url) source = "external" if "langfuse-web" not in langfuse_host else "local" - self.console.print(f"[green][SUCCESS][/green] LangFuse auto-configured ({source})") + self.console.print( + f"[green][SUCCESS][/green] LangFuse auto-configured ({source})" + ) self.console.print(f"[blue][INFO][/blue] Host: {langfuse_host}") self.console.print(f"[blue][INFO][/blue] Public URL: {public_url}") - self.console.print(f"[blue][INFO][/blue] Public key: {self.mask_api_key(langfuse_pub)}") + self.console.print( + f"[blue][INFO][/blue] Public key: {self.mask_api_key(langfuse_pub)}" + ) return # Manual configuration (standalone init.py run) - self.console.print("Enable LLM tracing, observability, and prompt management with LangFuse") - self.console.print("Self-host: cd ../../extras/langfuse && docker compose up -d") + self.console.print( + "Enable LLM tracing, observability, and prompt management with LangFuse" + ) + self.console.print( + "Self-host: cd ../../extras/langfuse && docker compose up -d" + ) self.console.print() try: @@ -748,52 +1046,68 @@ def setup_network(self): def setup_https(self): """Configure HTTPS settings for microphone access""" # Check if HTTPS configuration provided via command line - if hasattr(self.args, 'enable_https') and self.args.enable_https: + if hasattr(self.args, "enable_https") and self.args.enable_https: enable_https = True - server_ip = getattr(self.args, 'server_ip', 'localhost') - self.console.print(f"[green]✅[/green] HTTPS: {server_ip} (configured via 
wizard)") + server_ip = getattr(self.args, "server_ip", "localhost") + self.console.print( + f"[green]✅[/green] HTTPS: {server_ip} (configured via wizard)" + ) else: # Interactive configuration self.print_section("HTTPS Configuration (Optional)") try: - enable_https = Confirm.ask("Enable HTTPS for microphone access?", default=False) + enable_https = Confirm.ask( + "Enable HTTPS for microphone access?", default=False + ) except EOFError: self.console.print("Using default: No") enable_https = False if enable_https: - self.console.print("[blue][INFO][/blue] HTTPS enables microphone access in browsers") + self.console.print( + "[blue][INFO][/blue] HTTPS enables microphone access in browsers" + ) # Try to auto-detect Tailscale address ts_dns, ts_ip = detect_tailscale_info() if ts_dns: - self.console.print(f"[green][AUTO-DETECTED][/green] Tailscale DNS: {ts_dns}") + self.console.print( + f"[green][AUTO-DETECTED][/green] Tailscale DNS: {ts_dns}" + ) if ts_ip: - self.console.print(f"[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}") + self.console.print( + f"[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}" + ) default_address = ts_dns elif ts_ip: - self.console.print(f"[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}") + self.console.print( + f"[green][AUTO-DETECTED][/green] Tailscale IP: {ts_ip}" + ) default_address = ts_ip else: self.console.print("[blue][INFO][/blue] Tailscale not detected") - self.console.print("[blue][INFO][/blue] To find your Tailscale address: tailscale status --json | jq -r '.Self.DNSName'") + self.console.print( + "[blue][INFO][/blue] To find your Tailscale address: tailscale status --json | jq -r '.Self.DNSName'" + ) default_address = "localhost" - self.console.print("[blue][INFO][/blue] For local-only access, use 'localhost'") + self.console.print( + "[blue][INFO][/blue] For local-only access, use 'localhost'" + ) # Use the new masked prompt function (not masked for IP, but shows existing) server_ip = 
self.prompt_with_existing_masked( prompt_text="Server IP/Domain for SSL certificate", env_key="SERVER_IP", - placeholders=['localhost', 'your-server-ip-here'], + placeholders=["localhost", "your-server-ip-here"], is_password=False, - default=default_address + default=default_address, ) - + if enable_https: - + # Generate SSL certificates self.console.print("[blue][INFO][/blue] Generating SSL certificates...") # Use path relative to this script's directory @@ -802,17 +1116,32 @@ def setup_https(self): if ssl_script.exists(): try: # Run from the backend directory so paths work correctly - subprocess.run([str(ssl_script), server_ip], check=True, cwd=str(script_dir), timeout=180) - self.console.print("[green][SUCCESS][/green] SSL certificates generated") + subprocess.run( + [str(ssl_script), server_ip], + check=True, + cwd=str(script_dir), + timeout=180, + ) + self.console.print( + "[green][SUCCESS][/green] SSL certificates generated" + ) except subprocess.TimeoutExpired: - self.console.print("[yellow][WARNING][/yellow] SSL certificate generation timed out after 3 minutes") + self.console.print( + "[yellow][WARNING][/yellow] SSL certificate generation timed out after 3 minutes" + ) except subprocess.CalledProcessError: - self.console.print("[yellow][WARNING][/yellow] SSL certificate generation failed") + self.console.print( + "[yellow][WARNING][/yellow] SSL certificate generation failed" + ) else: - self.console.print(f"[yellow][WARNING][/warning] SSL script not found at {ssl_script}") + self.console.print( + f"[yellow][WARNING][/warning] SSL script not found at {ssl_script}" + ) # Generate Caddyfile from template - self.console.print("[blue][INFO][/blue] Creating Caddyfile configuration...") + self.console.print( + "[blue][INFO][/blue] Creating Caddyfile configuration..." 
+ ) caddyfile_template = script_dir / "Caddyfile.template" caddyfile_path = script_dir / "Caddyfile" @@ -820,32 +1149,50 @@ def setup_https(self): try: # Check if Caddyfile exists as a directory (common issue) if caddyfile_path.exists() and caddyfile_path.is_dir(): - self.console.print("[red]❌ ERROR: 'Caddyfile' exists as a directory![/red]") - self.console.print("[yellow] Please remove it manually:[/yellow]") - self.console.print(f"[yellow] rm -rf {caddyfile_path}[/yellow]") - self.console.print("[red] HTTPS will NOT work without a proper Caddyfile![/red]") + self.console.print( + "[red]❌ ERROR: 'Caddyfile' exists as a directory![/red]" + ) + self.console.print( + "[yellow] Please remove it manually:[/yellow]" + ) + self.console.print( + f"[yellow] rm -rf {caddyfile_path}[/yellow]" + ) + self.console.print( + "[red] HTTPS will NOT work without a proper Caddyfile![/red]" + ) self.config["HTTPS_ENABLED"] = "false" else: - with open(caddyfile_template, 'r') as f: + with open(caddyfile_template, "r") as f: caddyfile_content = f.read() # Replace TAILSCALE_IP with server_ip - caddyfile_content = caddyfile_content.replace('TAILSCALE_IP', server_ip) + caddyfile_content = caddyfile_content.replace( + "TAILSCALE_IP", server_ip + ) - with open(caddyfile_path, 'w') as f: + with open(caddyfile_path, "w") as f: f.write(caddyfile_content) - self.console.print(f"[green][SUCCESS][/green] Caddyfile created for: {server_ip}") + self.console.print( + f"[green][SUCCESS][/green] Caddyfile created for: {server_ip}" + ) self.config["HTTPS_ENABLED"] = "true" self.config["SERVER_IP"] = server_ip except Exception as e: - self.console.print(f"[red]❌ ERROR: Caddyfile generation failed: {e}[/red]") - self.console.print("[red] HTTPS will NOT work without a proper Caddyfile![/red]") + self.console.print( + f"[red]❌ ERROR: Caddyfile generation failed: {e}[/red]" + ) + self.console.print( + "[red] HTTPS will NOT work without a proper Caddyfile![/red]" + ) self.config["HTTPS_ENABLED"] = "false" 
else: self.console.print("[red]❌ ERROR: Caddyfile.template not found[/red]") - self.console.print("[red] HTTPS will NOT work without a proper Caddyfile![/red]") + self.console.print( + "[red] HTTPS will NOT work without a proper Caddyfile![/red]" + ) self.config["HTTPS_ENABLED"] = "false" else: self.config["HTTPS_ENABLED"] = "false" @@ -863,7 +1210,9 @@ def generate_env_file(self): shutil.copy2(env_template, env_path) self.console.print("[blue][INFO][/blue] Copied .env.template to .env") else: - self.console.print("[yellow][WARNING][/yellow] .env.template not found, creating new .env") + self.console.print( + "[yellow][WARNING][/yellow] .env.template not found, creating new .env" + ) env_path.touch(mode=0o600) # Update configured values using set_key @@ -875,24 +1224,35 @@ def generate_env_file(self): # Ensure secure permissions os.chmod(env_path, 0o600) - self.console.print("[green][SUCCESS][/green] .env file configured successfully with secure permissions") + self.console.print( + "[green][SUCCESS][/green] .env file configured successfully with secure permissions" + ) # Note: config.yml is automatically saved by ConfigManager when updates are made - self.console.print("[blue][INFO][/blue] Configuration saved to config.yml and .env (via ConfigManager)") + self.console.print( + "[blue][INFO][/blue] Configuration saved to config.yml and .env (via ConfigManager)" + ) def copy_config_templates(self): """Copy other configuration files""" - if not Path("diarization_config.json").exists() and Path("diarization_config.json.template").exists(): + if ( + not Path("diarization_config.json").exists() + and Path("diarization_config.json.template").exists() + ): shutil.copy2("diarization_config.json.template", "diarization_config.json") - self.console.print("[green][SUCCESS][/green] diarization_config.json created") + self.console.print( + "[green][SUCCESS][/green] diarization_config.json created" + ) def show_summary(self): """Show configuration summary""" 
self.print_section("Configuration Summary") self.console.print() - self.console.print(f"✅ Admin Account: {self.config.get('ADMIN_EMAIL', 'Not configured')}") + self.console.print( + f"✅ Admin Account: {self.config.get('ADMIN_EMAIL', 'Not configured')}" + ) # Get current config from ConfigManager (single source of truth) config_yml = self.config_manager.get_full_config() @@ -901,10 +1261,16 @@ def show_summary(self): stt_default = config_yml.get("defaults", {}).get("stt", "not set") stt_model = next( (m for m in config_yml.get("models", []) if m.get("name") == stt_default), - None + None, + ) + stt_provider = ( + stt_model.get("model_provider", "unknown") + if stt_model + else "not configured" + ) + self.console.print( + f"✅ Transcription: {stt_provider} ({stt_default}) - config.yml" ) - stt_provider = stt_model.get("model_provider", "unknown") if stt_model else "not configured" - self.console.print(f"✅ Transcription: {stt_provider} ({stt_default}) - config.yml") # Show LLM config from config.yml llm_default = config_yml.get("defaults", {}).get("llm", "not set") @@ -929,13 +1295,13 @@ def show_summary(self): self.console.print(f"✅ Knowledge Graph: Enabled ({neo4j_host})") # Auto-determine URLs based on HTTPS configuration - if self.config.get('HTTPS_ENABLED') == 'true': - server_ip = self.config.get('SERVER_IP', 'localhost') + if self.config.get("HTTPS_ENABLED") == "true": + server_ip = self.config.get("SERVER_IP", "localhost") self.console.print(f"✅ Backend URL: https://{server_ip}/") self.console.print(f"✅ Dashboard URL: https://{server_ip}/") else: - backend_port = self.config.get('BACKEND_PUBLIC_PORT', '8000') - webui_port = self.config.get('WEBUI_PORT', '5173') + backend_port = self.config.get("BACKEND_PUBLIC_PORT", "8000") + webui_port = self.config.get("WEBUI_PORT", "5173") self.console.print(f"✅ Backend URL: http://localhost:{backend_port}") self.console.print(f"✅ Dashboard URL: http://localhost:{webui_port}") @@ -950,40 +1316,52 @@ def 
show_next_steps(self): self.console.print("1. Start the main services:") self.console.print(" [cyan]docker compose up --build -d[/cyan]") self.console.print() - + # Auto-determine URLs for next steps - if self.config.get('HTTPS_ENABLED') == 'true': - server_ip = self.config.get('SERVER_IP', 'localhost') + if self.config.get("HTTPS_ENABLED") == "true": + server_ip = self.config.get("SERVER_IP", "localhost") self.console.print("2. Access the dashboard:") self.console.print(f" [cyan]https://{server_ip}/[/cyan]") self.console.print() self.console.print("3. Check service health:") self.console.print(f" [cyan]curl -k https://{server_ip}/health[/cyan]") else: - webui_port = self.config.get('WEBUI_PORT', '5173') - backend_port = self.config.get('BACKEND_PUBLIC_PORT', '8000') + webui_port = self.config.get("WEBUI_PORT", "5173") + backend_port = self.config.get("BACKEND_PUBLIC_PORT", "8000") self.console.print("2. Access the dashboard:") self.console.print(f" [cyan]http://localhost:{webui_port}[/cyan]") self.console.print() self.console.print("3. Check service health:") - self.console.print(f" [cyan]curl http://localhost:{backend_port}/health[/cyan]") + self.console.print( + f" [cyan]curl http://localhost:{backend_port}/health[/cyan]" + ) if self.config.get("MEMORY_PROVIDER") == "openmemory_mcp": self.console.print() self.console.print("4. Start OpenMemory MCP:") - self.console.print(" [cyan]cd ../../extras/openmemory-mcp && docker compose up -d[/cyan]") + self.console.print( + " [cyan]cd ../../extras/openmemory-mcp && docker compose up -d[/cyan]" + ) if self.config.get("TRANSCRIPTION_PROVIDER") == "offline": self.console.print() self.console.print("5. 
Start Parakeet ASR:") - self.console.print(" [cyan]cd ../../extras/asr-services && docker compose up parakeet -d[/cyan]") + self.console.print( + " [cyan]cd ../../extras/asr-services && docker compose up parakeet -d[/cyan]" + ) def run(self): """Run the complete setup process""" self.print_header("🚀 Chronicle Interactive Setup") - self.console.print("This wizard will help you configure Chronicle with all necessary services.") - self.console.print("[dim]Safe to run again — it backs up your config and preserves previous values.[/dim]") - self.console.print("[dim]When unsure, just press Enter — the defaults will work.[/dim]") + self.console.print( + "This wizard will help you configure Chronicle with all necessary services." + ) + self.console.print( + "[dim]Safe to run again — it backs up your config and preserves previous values.[/dim]" + ) + self.console.print( + "[dim]When unsure, just press Enter — the defaults will work.[/dim]" + ) self.console.print() try: @@ -1018,7 +1396,9 @@ def run(self): self.console.print() self.console.print("📝 [bold]Configuration files updated:[/bold]") self.console.print(f" • .env - API keys and environment variables") - self.console.print(f" • ../../config/config.yml - Model and memory provider configuration") + self.console.print( + f" • ../../config/config.yml - Model and memory provider configuration" + ) self.console.print() self.console.print("For detailed documentation, see:") self.console.print(" • Docs/quickstart.md") @@ -1037,39 +1417,68 @@ def run(self): def main(): """Main entry point""" parser = argparse.ArgumentParser(description="Chronicle Advanced Backend Setup") - parser.add_argument("--speaker-service-url", - help="Speaker Recognition service URL (default: prompt user)") - parser.add_argument("--parakeet-asr-url", - help="Parakeet ASR service URL (default: prompt user)") - parser.add_argument("--transcription-provider", - choices=["deepgram", "parakeet", "vibevoice", "qwen3-asr", "smallest", "none"], - 
help="Transcription provider (default: prompt user)") - parser.add_argument("--enable-https", action="store_true", - help="Enable HTTPS configuration (default: prompt user)") - parser.add_argument("--server-ip", - help="Server IP/domain for SSL certificate (default: prompt user)") - parser.add_argument("--enable-obsidian", action="store_true", - help="Enable Obsidian/Neo4j integration (default: prompt user)") - parser.add_argument("--enable-knowledge-graph", action="store_true", - help="Enable Knowledge Graph entity extraction (default: prompt user)") - parser.add_argument("--neo4j-password", - help="Neo4j password (default: prompt user)") - parser.add_argument("--ts-authkey", - help="Tailscale auth key for Docker integration (default: prompt user)") - parser.add_argument("--langfuse-public-key", - help="LangFuse project public key (from langfuse init or external)") - parser.add_argument("--langfuse-secret-key", - help="LangFuse project secret key (from langfuse init or external)") - parser.add_argument("--langfuse-host", - help="LangFuse host URL (default: http://langfuse-web:3000 for local)") - parser.add_argument("--langfuse-public-url", - help="LangFuse browser-accessible URL for deep-links (default: http://localhost:3002)") - parser.add_argument("--streaming-provider", - choices=["deepgram", "smallest", "qwen3-asr"], - help="Streaming provider when different from batch (enables batch re-transcription)") + parser.add_argument( + "--speaker-service-url", + help="Speaker Recognition service URL (default: prompt user)", + ) + parser.add_argument( + "--parakeet-asr-url", help="Parakeet ASR service URL (default: prompt user)" + ) + parser.add_argument( + "--transcription-provider", + choices=["deepgram", "parakeet", "vibevoice", "qwen3-asr", "smallest", "none"], + help="Transcription provider (default: prompt user)", + ) + parser.add_argument( + "--enable-https", + action="store_true", + help="Enable HTTPS configuration (default: prompt user)", + ) + 
parser.add_argument( + "--server-ip", + help="Server IP/domain for SSL certificate (default: prompt user)", + ) + parser.add_argument( + "--enable-obsidian", + action="store_true", + help="Enable Obsidian/Neo4j integration (default: prompt user)", + ) + parser.add_argument( + "--enable-knowledge-graph", + action="store_true", + help="Enable Knowledge Graph entity extraction (default: prompt user)", + ) + parser.add_argument( + "--neo4j-password", help="Neo4j password (default: prompt user)" + ) + parser.add_argument( + "--ts-authkey", + help="Tailscale auth key for Docker integration (default: prompt user)", + ) + parser.add_argument( + "--langfuse-public-key", + help="LangFuse project public key (from langfuse init or external)", + ) + parser.add_argument( + "--langfuse-secret-key", + help="LangFuse project secret key (from langfuse init or external)", + ) + parser.add_argument( + "--langfuse-host", + help="LangFuse host URL (default: http://langfuse-web:3000 for local)", + ) + parser.add_argument( + "--langfuse-public-url", + help="LangFuse browser-accessible URL for deep-links (default: http://localhost:3002)", + ) + parser.add_argument( + "--streaming-provider", + choices=["deepgram", "smallest", "qwen3-asr"], + help="Streaming provider when different from batch (enables batch re-transcription)", + ) args = parser.parse_args() - + setup = ChronicleSetup(args) setup.run() diff --git a/config_manager.py b/config_manager.py index 1c5079a2..6d85bba7 100644 --- a/config_manager.py +++ b/config_manager.py @@ -40,7 +40,9 @@ class ConfigManager: """Manages Chronicle configuration across config.yml and .env files.""" - def __init__(self, service_path: Optional[str] = None, repo_root: Optional[Path] = None): + def __init__( + self, service_path: Optional[str] = None, repo_root: Optional[Path] = None + ): """ Initialize ConfigManager. 
@@ -63,8 +65,10 @@ def __init__(self, service_path: Optional[str] = None, repo_root: Optional[Path] self.config_yml_path = self.repo_root / "config" / "config.yml" self.env_path = self.service_path / ".env" if self.service_path else None - logger.debug(f"ConfigManager initialized: repo_root={self.repo_root}, " - f"service_path={self.service_path}, config_yml={self.config_yml_path}") + logger.debug( + f"ConfigManager initialized: repo_root={self.repo_root}, " + f"service_path={self.service_path}, config_yml={self.config_yml_path}" + ) def _find_repo_root(self) -> Path: """Find repository root using __file__ location (config_manager.py is always at repo root).""" @@ -99,7 +103,7 @@ def _load_config_yml(self) -> Dict[str, Any]: ) try: - with open(self.config_yml_path, 'r') as f: + with open(self.config_yml_path, "r") as f: config = _yaml.load(f) if config is None: raise RuntimeError( @@ -120,12 +124,14 @@ def _save_config_yml(self, config: Dict[str, Any]): # Create backup if self.config_yml_path.exists(): timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - backup_path = self.config_yml_path.parent / f"config.yml.backup.{timestamp}" + backup_path = ( + self.config_yml_path.parent / f"config.yml.backup.{timestamp}" + ) shutil.copy2(self.config_yml_path, backup_path) logger.info(f"Backed up config.yml to {backup_path.name}") # Write updated config - with open(self.config_yml_path, 'w') as f: + with open(self.config_yml_path, "w") as f: _yaml.dump(config, f) logger.info(f"Saved config.yml to {self.config_yml_path}") @@ -146,7 +152,7 @@ def _update_env_file(self, key: str, value: str): try: # Read current .env - with open(self.env_path, 'r') as f: + with open(self.env_path, "r") as f: lines = f.readlines() # Update or add line @@ -162,7 +168,9 @@ def _update_env_file(self, key: str, value: str): # If key wasn't found, add it if not key_found: - updated_lines.append(f"\n# Auto-updated by ConfigManager\n{key}={value}\n") + updated_lines.append( + f"\n# Auto-updated by 
ConfigManager\n{key}={value}\n" + ) # Create backup backup_path = f"{self.env_path}.bak" @@ -170,7 +178,7 @@ def _update_env_file(self, key: str, value: str): logger.debug(f"Backed up .env to {backup_path}") # Write updated file - with open(self.env_path, 'w') as f: + with open(self.env_path, "w") as f: f.writelines(updated_lines) # Update environment variable for current process @@ -248,7 +256,7 @@ def set_memory_provider(self, provider: str) -> Dict[str, Any]: "config_yml_path": str(self.config_yml_path), "env_path": str(self.env_path) if self.env_path else None, "requires_restart": True, - "status": "success" + "status": "success", } def get_memory_config(self) -> Dict[str, Any]: @@ -326,6 +334,25 @@ def update_config_defaults(self, updates: Dict[str, str]): self._save_config_yml(config) + def add_or_update_model(self, model_def: Dict[str, Any]): + """ + Add or update a model in the models list by name. + + Args: + model_def: Model definition dict with at least a 'name' key. + """ + config = self._load_config_yml() + if "models" not in config: + config["models"] = [] + # Update existing or append + for i, m in enumerate(config["models"]): + if m.get("name") == model_def["name"]: + config["models"][i] = model_def + break + else: + config["models"].append(model_def) + self._save_config_yml(config) + def get_full_config(self) -> Dict[str, Any]: """ Get complete config.yml as dictionary. diff --git a/extras/asr-services/docker-compose.yml b/extras/asr-services/docker-compose.yml index 84539124..fea49372 100644 --- a/extras/asr-services/docker-compose.yml +++ b/extras/asr-services/docker-compose.yml @@ -90,6 +90,8 @@ services: build: context: . 
dockerfile: providers/vibevoice/Dockerfile + args: + PYTORCH_CUDA_VERSION: ${PYTORCH_CUDA_VERSION:-cu126} image: chronicle-asr-vibevoice:latest ports: - "${ASR_PORT:-8767}:8765" @@ -119,6 +121,9 @@ services: # LoRA adapter: path to pre-trained adapter to auto-load on startup (optional) - LORA_ADAPTER_PATH=${LORA_ADAPTER_PATH:-} # Batching config: managed via config/defaults.yml (asr_services.vibevoice) + dns: + - 8.8.8.8 + - 8.8.4.4 restart: unless-stopped # ============================================================================ diff --git a/extras/asr-services/providers/vibevoice/Dockerfile b/extras/asr-services/providers/vibevoice/Dockerfile index 93ed36af..89002c59 100644 --- a/extras/asr-services/providers/vibevoice/Dockerfile +++ b/extras/asr-services/providers/vibevoice/Dockerfile @@ -8,6 +8,8 @@ ######################### builder ################################# FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim AS builder +ARG PYTORCH_CUDA_VERSION=cu126 + WORKDIR /app # Install system dependencies for building @@ -21,7 +23,7 @@ ENV UV_LINK_MODE=copy # Dependency manifest first for cache-friendly installs COPY pyproject.toml uv.lock ./ RUN --mount=type=cache,target=/root/.cache/uv \ - uv sync --frozen --no-install-project --group vibevoice + uv sync --frozen --no-install-project --group vibevoice --extra ${PYTORCH_CUDA_VERSION} ######################### runtime ################################# FROM python:3.12-slim-bookworm AS runtime diff --git a/extras/asr-services/tests/test_cuda_version_config.py b/extras/asr-services/tests/test_cuda_version_config.py new file mode 100644 index 00000000..51acf439 --- /dev/null +++ b/extras/asr-services/tests/test_cuda_version_config.py @@ -0,0 +1,236 @@ +""" +Unit tests for CUDA version configuration in ASR service Dockerfiles. + +Tests the configurable PYTORCH_CUDA_VERSION build arg that allows selecting +different CUDA versions (cu121, cu126, cu128) for different GPU architectures. 
+""" + +import os +import re +from pathlib import Path + +import pytest + + +class TestDockerfileCUDASupport: + """Test that Dockerfiles support configurable CUDA versions.""" + + @pytest.fixture + def vibevoice_dockerfile_path(self): + """Path to VibeVoice Dockerfile.""" + return Path(__file__).parent.parent / "providers" / "vibevoice" / "Dockerfile" + + @pytest.fixture + def nemo_dockerfile_path(self): + """Path to NeMo Dockerfile.""" + return Path(__file__).parent.parent / "providers" / "nemo" / "Dockerfile" + + @pytest.fixture + def docker_compose_path(self): + """Path to docker-compose.yml.""" + return Path(__file__).parent.parent / "docker-compose.yml" + + def test_vibevoice_dockerfile_has_cuda_arg(self, vibevoice_dockerfile_path): + """Test that VibeVoice Dockerfile declares PYTORCH_CUDA_VERSION arg.""" + content = vibevoice_dockerfile_path.read_text() + + # Should have ARG declaration + assert re.search( + r"ARG\s+PYTORCH_CUDA_VERSION", content + ), "Dockerfile must declare PYTORCH_CUDA_VERSION build arg" + + # Should have default value + arg_match = re.search(r"ARG\s+PYTORCH_CUDA_VERSION=(\w+)", content) + assert arg_match, "PYTORCH_CUDA_VERSION should have default value" + default_version = arg_match.group(1) + assert default_version in [ + "cu121", + "cu126", + "cu128", + ], f"Default CUDA version {default_version} should be cu121, cu126, or cu128" + + def test_vibevoice_dockerfile_uses_cuda_arg_in_uv_sync( + self, vibevoice_dockerfile_path + ): + """Test that VibeVoice Dockerfile uses CUDA arg in uv sync command.""" + content = vibevoice_dockerfile_path.read_text() + + # Should use --extra ${PYTORCH_CUDA_VERSION} + assert re.search( + r"uv\s+sync.*--extra\s+\$\{PYTORCH_CUDA_VERSION\}", content + ), "uv sync command must include --extra ${PYTORCH_CUDA_VERSION}" + + def test_nemo_dockerfile_has_cuda_support(self, nemo_dockerfile_path): + """Test that NeMo Dockerfile (reference implementation) has CUDA support.""" + content = 
nemo_dockerfile_path.read_text() + + assert re.search( + r"ARG\s+PYTORCH_CUDA_VERSION", content + ), "NeMo Dockerfile should have PYTORCH_CUDA_VERSION arg" + + assert re.search( + r"uv\s+sync.*--extra\s+\$\{PYTORCH_CUDA_VERSION\}", content + ), "NeMo Dockerfile should use CUDA version in uv sync" + + def test_docker_compose_passes_cuda_arg_to_vibevoice(self, docker_compose_path): + """Test that docker-compose.yml passes PYTORCH_CUDA_VERSION to vibevoice service.""" + content = docker_compose_path.read_text() + + # Find vibevoice-asr service section + vibevoice_section = re.search( + r"vibevoice-asr:.*?(?=^\S|\Z)", content, re.MULTILINE | re.DOTALL + ) + assert vibevoice_section, "docker-compose.yml must have vibevoice-asr service" + + section_text = vibevoice_section.group(0) + + # Should have build args section + assert re.search( + r"args:", section_text + ), "vibevoice-asr service should have build args section" + + # Should pass PYTORCH_CUDA_VERSION + assert re.search( + r"PYTORCH_CUDA_VERSION:\s*\$\{PYTORCH_CUDA_VERSION:-cu126\}", section_text + ), "vibevoice-asr should pass PYTORCH_CUDA_VERSION build arg with cu126 default" + + def test_docker_compose_cuda_arg_consistency(self, docker_compose_path): + """Test that all GPU-enabled services use consistent CUDA version pattern.""" + content = docker_compose_path.read_text() + + # Services that should have CUDA support + gpu_services = ["vibevoice-asr", "nemo-asr", "parakeet-asr"] + + for service_name in gpu_services: + service_match = re.search( + rf"{service_name}:.*?(?=^\S|\Z)", content, re.MULTILINE | re.DOTALL + ) + + if service_match: + service_text = service_match.group(0) + + # Check if service has GPU resources + if "devices:" in service_text and "nvidia" in service_text: + # Should have PYTORCH_CUDA_VERSION arg + assert re.search( + r"PYTORCH_CUDA_VERSION:\s*\$\{PYTORCH_CUDA_VERSION:-cu\d+\}", + service_text, + ), f"{service_name} with GPU should have PYTORCH_CUDA_VERSION build arg" + + +class 
TestCUDAVersionEnvironmentVariable: + """Test CUDA version environment variable handling.""" + + def test_cuda_version_env_var_format(self): + """Test that CUDA version environment variables follow correct format.""" + valid_versions = ["cu121", "cu126", "cu128"] + + for version in valid_versions: + assert re.match( + r"^cu\d{3}$", version + ), f"{version} should match pattern cu### (e.g., cu121, cu126)" + + def test_cuda_version_from_env(self): + """Test reading CUDA version from environment.""" + test_version = "cu128" + + with pytest.MonkeyPatch.context() as mp: + mp.setenv("PYTORCH_CUDA_VERSION", test_version) + cuda_version = os.getenv("PYTORCH_CUDA_VERSION") + + assert cuda_version == test_version + assert cuda_version in ["cu121", "cu126", "cu128"] + + def test_cuda_version_default_fallback(self): + """Test that default CUDA version is used when env var not set.""" + with pytest.MonkeyPatch.context() as mp: + mp.delenv("PYTORCH_CUDA_VERSION", raising=False) + + # Simulate docker-compose default: ${PYTORCH_CUDA_VERSION:-cu126} + cuda_version = os.getenv("PYTORCH_CUDA_VERSION", "cu126") + + assert cuda_version == "cu126" + + +class TestGPUArchitectureCUDAMapping: + """Test that GPU architectures map to correct CUDA versions.""" + + def test_rtx_5090_requires_cu128(self): + """ + Test that RTX 5090 (sm_120) requires CUDA 12.8+. + + RTX 5090 has CUDA capability 12.0 (sm_120) which requires + PyTorch built with CUDA 12.8 or higher. 
+ """ + gpu_arch = "sm_120" # RTX 5090 + required_cuda = "cu128" + + # Map GPU architecture to minimum CUDA version + arch_to_cuda = { + "sm_120": "cu128", # RTX 5090, RTX 50 series + "sm_90": "cu126", # RTX 4090, H100 + "sm_89": "cu121", # RTX 4090 + "sm_86": "cu121", # RTX 3090, A6000 + } + + assert ( + arch_to_cuda.get(gpu_arch) == required_cuda + ), f"GPU architecture {gpu_arch} requires CUDA version {required_cuda}" + + # Architectures supported by each CUDA version (minimum cu version that supports them) + # Used as authoritative reference for architecture-to-CUDA mapping tests. + CUDA_ARCH_SUPPORT = { + "cu121": {"sm_75", "sm_80", "sm_86", "sm_89"}, + "cu126": {"sm_75", "sm_80", "sm_86", "sm_89", "sm_90"}, + "cu128": {"sm_75", "sm_80", "sm_86", "sm_89", "sm_90", "sm_120"}, + } + + def test_older_gpus_work_with_cu121(self): + """Test that older GPUs (sm_86, sm_80) work with cu121.""" + older_archs = ["sm_86", "sm_80", "sm_75"] # RTX 3090, A100, RTX 2080 + cu121_supported = self.CUDA_ARCH_SUPPORT["cu121"] + + for arch in older_archs: + assert arch in cu121_supported, f"{arch} should be supported by CUDA 12.1" + + +class TestPyProjectCUDAExtras: + """Test that pyproject.toml defines CUDA version extras correctly.""" + + @pytest.fixture + def pyproject_path(self): + """Path to pyproject.toml.""" + return Path(__file__).parent.parent / "pyproject.toml" + + def test_pyproject_has_cuda_extras(self, pyproject_path): + """Test that pyproject.toml defines cu121, cu126, cu128 extras.""" + if not pyproject_path.exists(): + pytest.skip("pyproject.toml not found") + + content = pyproject_path.read_text() + + # Should have [project.optional-dependencies] or [tool.uv] with extras + cuda_versions = ["cu121", "cu126", "cu128"] + + for version in cuda_versions: + # Look for the CUDA version as an extra + assert re.search( + rf'["\']?{version}["\']?\s*=', content + ), f"pyproject.toml should define {version} extra" + + def test_pyproject_cuda_extras_have_pytorch(self, 
pyproject_path): + """Test that CUDA extras include torch/torchaudio dependencies.""" + if not pyproject_path.exists(): + pytest.skip("pyproject.toml not found") + + content = pyproject_path.read_text() + + # Each CUDA extra should reference torch with the appropriate index + # e.g., { extra = "cu128" } or { index = "pytorch-cu128" } + assert re.search(r'extra\s*=\s*["\']cu\d{3}["\']', content) or re.search( + r'index\s*=\s*["\']pytorch-cu\d{3}["\']', content + ), "CUDA extras should reference PyTorch with CUDA version" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/configuration/test_llm_custom_provider.robot b/tests/configuration/test_llm_custom_provider.robot new file mode 100644 index 00000000..fa9a09c3 --- /dev/null +++ b/tests/configuration/test_llm_custom_provider.robot @@ -0,0 +1,258 @@ +*** Settings *** +Documentation Tests for LLM Custom Provider Setup (ConfigManager) +Library OperatingSystem +Library Collections +Library String +Library ../libs/ConfigTestHelper.py + +*** Keywords *** +Setup Temp Config + [Documentation] Creates a temporary configuration environment + ${random_suffix}= Generate Random String 8 [NUMBERS] + ${temp_path}= Join Path ${OUTPUT DIR} temp_config_${random_suffix} + Create Directory ${temp_path} + + # Create initial default config content + ${defaults}= Create Dictionary llm=openai-llm embedding=openai-embed stt=stt-deepgram + ${model1_params}= Create Dictionary temperature=${0.2} max_tokens=${2000} + ${model1}= Create Dictionary + ... name=openai-llm + ... description=OpenAI GPT-4o-mini + ... model_type=llm + ... model_provider=openai + ... api_family=openai + ... model_name=gpt-4o-mini + ... model_url=https://api.openai.com/v1 + ... api_key=\${oc.env:OPENAI_API_KEY,''} + ... model_params=${model1_params} + ... model_output=json + + ${model2}= Create Dictionary + ... name=local-embed + ... description=Local embeddings via Ollama + ... model_type=embedding + ... model_provider=ollama + ... 
api_family=openai + ... model_name=nomic-embed-text:latest + ... model_url=http://localhost:11434/v1 + ... api_key=\${oc.env:OPENAI_API_KEY,ollama} + ... embedding_dimensions=${768} + ... model_output=vector + + ${models}= Create List ${model1} ${model2} + ${memory}= Create Dictionary provider=chronicle + ${config}= Create Dictionary defaults=${defaults} models=${models} memory=${memory} + + Create Temp Config Structure ${temp_path} ${config} + Set Test Variable ${TEMP_PATH} ${temp_path} + +Cleanup Temp Config + Remove Directory ${TEMP_PATH} recursive=True + +*** Test Cases *** +Add New Model To Config + [Documentation] add_or_update_model() should append a new model when name doesn't exist. + [Setup] Setup Temp Config + [Teardown] Cleanup Temp Config + + ${params}= Create Dictionary temperature=${0.2} max_tokens=${2000} + ${new_model}= Create Dictionary + ... name=custom-llm + ... description=Custom OpenAI-compatible LLM + ... model_type=llm + ... model_provider=openai + ... api_family=openai + ... model_name=llama-3.1-70b-versatile + ... model_url=https://api.groq.com/openai/v1 + ... api_key=\${oc.env:CUSTOM_LLM_API_KEY,''} + ... model_params=${params} + ... model_output=json + + ${cm}= Get Config Manager Instance ${TEMP_PATH} + Add Model To Config Manager ${cm} ${new_model} + + ${config}= Call Method ${cm} get_full_config + ${models}= Get From Dictionary ${config} models + + ${target_model}= Set Variable ${None} + FOR ${m} IN @{models} + Run Keyword If '${m["name"]}' == 'custom-llm' Set Test Variable \${target_model} ${m} + END + + Should Not Be Equal ${target_model} ${None} + Should Be Equal ${target_model["model_name"]} llama-3.1-70b-versatile + Should Be Equal ${target_model["model_url"]} https://api.groq.com/openai/v1 + Should Be Equal ${target_model["model_type"]} llm + +Update Existing Model + [Documentation] add_or_update_model() should replace an existing model with the same name. 
+ [Setup] Setup Temp Config + [Teardown] Cleanup Temp Config + + ${cm}= Get Config Manager Instance ${TEMP_PATH} + + # First add + ${model_v1}= Create Dictionary name=custom-llm model_type=llm model_name=model-v1 model_url=https://example.com/v1 + Add Model To Config Manager ${cm} ${model_v1} + + # Then update + ${model_v2}= Create Dictionary name=custom-llm model_type=llm model_name=model-v2 model_url=https://example.com/v2 + Add Model To Config Manager ${cm} ${model_v2} + + ${config}= Call Method ${cm} get_full_config + ${models}= Get From Dictionary ${config} models + + ${count}= Set Variable 0 + ${target_model}= Set Variable ${None} + FOR ${m} IN @{models} + IF '${m["name"]}' == 'custom-llm' + Set Test Variable ${target_model} ${m} + ${count}= Evaluate ${count} + 1 + END + END + + Should Be Equal As Integers ${count} 1 + Should Be Equal ${target_model["model_name"]} model-v2 + Should Be Equal ${target_model["model_url"]} https://example.com/v2 + +Add Model To Empty Models List + [Documentation] add_or_update_model() should create models list if it doesn't exist. + [Setup] Setup Temp Config + [Teardown] Cleanup Temp Config + + # Overwrite config with empty models + ${defaults}= Create Dictionary llm=openai-llm + ${empty_config}= Create Dictionary defaults=${defaults} + Create Temp Config Structure ${TEMP_PATH} ${empty_config} + + ${cm}= Get Config Manager Instance ${TEMP_PATH} + ${test_model}= Create Dictionary name=test-model model_type=llm + Add Model To Config Manager ${cm} ${test_model} + + ${config}= Call Method ${cm} get_full_config + Dictionary Should Contain Key ${config} models + ${models}= Get From Dictionary ${config} models + Length Should Be ${models} 1 + Should Be Equal ${models[0]["name"]} test-model + +Custom LLM And Embedding Model Added + [Documentation] Both LLM and embedding models should be created when embedding model is provided. 
+ [Setup] Setup Temp Config + [Teardown] Cleanup Temp Config + + ${cm}= Get Config Manager Instance ${TEMP_PATH} + + ${params}= Create Dictionary temperature=${0.2} max_tokens=${2000} + ${llm_model}= Create Dictionary + ... name=custom-llm + ... model_type=llm + ... model_provider=openai + ... api_family=openai + ... model_name=llama-3.1-70b-versatile + ... model_url=https://api.groq.com/openai/v1 + ... api_key=\${oc.env:CUSTOM_LLM_API_KEY,''} + ... model_params=${params} + ... model_output=json + + ${embed_model}= Create Dictionary + ... name=custom-embed + ... description=Custom OpenAI-compatible embeddings + ... model_type=embedding + ... model_provider=openai + ... api_family=openai + ... model_name=text-embedding-3-small + ... model_url=https://api.groq.com/openai/v1 + ... api_key=\${oc.env:CUSTOM_LLM_API_KEY,''} + ... embedding_dimensions=${1536} + ... model_output=vector + + Add Model To Config Manager ${cm} ${llm_model} + Add Model To Config Manager ${cm} ${embed_model} + + ${config}= Call Method ${cm} get_full_config + ${models}= Get From Dictionary ${config} models + ${model_names}= Create List + FOR ${m} IN @{models} + Append To List ${model_names} ${m["name"]} + END + + List Should Contain Value ${model_names} custom-llm + List Should Contain Value ${model_names} custom-embed + + ${target_embed}= Set Variable ${None} + FOR ${m} IN @{models} + Run Keyword If '${m["name"]}' == 'custom-embed' Set Test Variable \${target_embed} ${m} + END + + Should Be Equal ${target_embed["model_type"]} embedding + Should Be Equal ${target_embed["model_name"]} text-embedding-3-small + Should Be Equal As Integers ${target_embed["embedding_dimensions"]} 1536 + +Custom LLM Without Embedding Falls Back To Local + [Documentation] defaults.embedding should be local-embed when no custom embedding is provided. + [Setup] Setup Temp Config + [Teardown] Cleanup Temp Config + + ${cm}= Get Config Manager Instance ${TEMP_PATH} + + ${llm_model}= Create Dictionary + ... 
name=custom-llm + ... model_type=llm + ... model_name=some-model + ... model_url=https://api.example.com/v1 + + Add Model To Config Manager ${cm} ${llm_model} + ${defaults_update}= Create Dictionary llm=custom-llm embedding=local-embed + Update Defaults In Config Manager ${cm} ${defaults_update} + + ${defaults}= Call Method ${cm} get_config_defaults + Should Be Equal ${defaults["llm"]} custom-llm + Should Be Equal ${defaults["embedding"]} local-embed + +Custom LLM Updates Defaults With Embedding + [Documentation] defaults.llm and defaults.embedding should be updated correctly with custom embed. + [Setup] Setup Temp Config + [Teardown] Cleanup Temp Config + + ${cm}= Get Config Manager Instance ${TEMP_PATH} + + ${defaults_update}= Create Dictionary llm=custom-llm embedding=custom-embed + Update Defaults In Config Manager ${cm} ${defaults_update} + + ${defaults}= Call Method ${cm} get_config_defaults + Should Be Equal ${defaults["llm"]} custom-llm + Should Be Equal ${defaults["embedding"]} custom-embed + +Existing Models Preserved After Adding Custom + [Documentation] Adding a custom model should not remove existing models. + [Setup] Setup Temp Config + [Teardown] Cleanup Temp Config + + ${cm}= Get Config Manager Instance ${TEMP_PATH} + ${config_before}= Call Method ${cm} get_full_config + ${models_before}= Get From Dictionary ${config_before} models + ${original_count}= Get Length ${models_before} + + ${new_model}= Create Dictionary + ... name=custom-llm + ... model_type=llm + ... model_name=test-model + ... 
model_url=https://example.com/v1 + + Add Model To Config Manager ${cm} ${new_model} + + ${config_after}= Call Method ${cm} get_full_config + ${models_after}= Get From Dictionary ${config_after} models + ${new_count}= Get Length ${models_after} + ${expected_count}= Evaluate ${original_count} + 1 + + Should Be Equal As Integers ${new_count} ${expected_count} + + ${model_names}= Create List + FOR ${m} IN @{models_after} + Append To List ${model_names} ${m["name"]} + END + + List Should Contain Value ${model_names} openai-llm + List Should Contain Value ${model_names} local-embed + List Should Contain Value ${model_names} custom-llm \ No newline at end of file diff --git a/tests/configuration/test_transcription_url.robot b/tests/configuration/test_transcription_url.robot new file mode 100644 index 00000000..e0ba40e8 --- /dev/null +++ b/tests/configuration/test_transcription_url.robot @@ -0,0 +1,126 @@ +*** Settings *** +Documentation Tests for Transcription Service URL Configuration +Library Collections +Library ../libs/ConfigTestHelper.py + +*** Test Cases *** +Vibevoice Url Without Http Prefix + [Documentation] Test that VIBEVOICE_ASR_URL without http:// prefix works correctly. + ${config_template}= Create Dictionary model_url=http://\${oc.env:VIBEVOICE_ASR_URL,host.docker.internal:8767} + ${env_vars}= Create Dictionary VIBEVOICE_ASR_URL=host.docker.internal:8767 + + ${resolved}= Resolve Omega Config ${config_template} ${env_vars} + Should Be Equal ${resolved["model_url"]} http://host.docker.internal:8767 + Should Not Contain ${resolved["model_url"]} http://http:// + +Vibevoice Url With Http Prefix Causes Double Prefix + [Documentation] Test that VIBEVOICE_ASR_URL WITH http:// causes double prefix (bug scenario). 
+ ${config_template}= Create Dictionary model_url=http://\${oc.env:VIBEVOICE_ASR_URL,host.docker.internal:8767} + ${env_vars}= Create Dictionary VIBEVOICE_ASR_URL=http://host.docker.internal:8767 + + ${resolved}= Resolve Omega Config ${config_template} ${env_vars} + Should Be Equal ${resolved["model_url"]} http://http://host.docker.internal:8767 + Should Contain ${resolved["model_url"]} http://http:// + +Vibevoice Url Default Fallback + [Documentation] Test that default fallback works when VIBEVOICE_ASR_URL is not set. + ${config_template}= Create Dictionary model_url=http://\${oc.env:VIBEVOICE_ASR_URL,host.docker.internal:8767} + ${env_vars}= Create Dictionary + + ${resolved}= Resolve Omega Config ${config_template} ${env_vars} + Should Be Equal ${resolved["model_url"]} http://host.docker.internal:8767 + +Parakeet Url Configuration + [Documentation] Test that PARAKEET_ASR_URL follows same pattern. + ${config_template}= Create Dictionary model_url=http://\${oc.env:PARAKEET_ASR_URL,172.17.0.1:8767} + ${env_vars}= Create Dictionary PARAKEET_ASR_URL=host.docker.internal:8767 + + ${resolved}= Resolve Omega Config ${config_template} ${env_vars} + Should Be Equal ${resolved["model_url"]} http://host.docker.internal:8767 + Should Not Contain ${resolved["model_url"]} http://http:// + +Url Parsing Removes Double Slashes + [Documentation] Test that URL with double http:// causes connection failures (simulated by parsing check). + + # Valid URL + ${valid_url}= Set Variable http://host.docker.internal:8767/transcribe + ${parsed_valid}= Check Url Parsing ${valid_url} + Should Be Equal ${parsed_valid["scheme"]} http + Should Be Equal ${parsed_valid["netloc"]} host.docker.internal:8767 + + # Invalid URL + ${invalid_url}= Set Variable http://http://host.docker.internal:8767/transcribe + ${parsed_invalid}= Check Url Parsing ${invalid_url} + Should Be Equal ${parsed_invalid["scheme"]} http + # In python urlparse, 'http:' becomes the netloc for 'http://http://...' 
+ Should Be Equal ${parsed_invalid["netloc"]} http: + Should Not Be Equal ${parsed_invalid["netloc"]} host.docker.internal:8767 + +Use Provider Segments Default False + [Documentation] Test that use_provider_segments defaults to false. + ${transcription}= Create Dictionary + ${backend}= Create Dictionary transcription=${transcription} + ${config_template}= Create Dictionary backend=${backend} + ${env_vars}= Create Dictionary + + ${resolved}= Resolve Omega Config ${config_template} ${env_vars} + ${val}= Evaluate $resolved.get('backend', {}).get('transcription', {}).get('use_provider_segments', False) + Should Be Equal ${val} ${FALSE} + +Use Provider Segments Explicit True + [Documentation] Test that use_provider_segments can be enabled. + ${transcription}= Create Dictionary use_provider_segments=${TRUE} + ${backend}= Create Dictionary transcription=${transcription} + ${config_template}= Create Dictionary backend=${backend} + ${env_vars}= Create Dictionary + + ${resolved}= Resolve Omega Config ${config_template} ${env_vars} + ${val}= Evaluate $resolved['backend']['transcription']['use_provider_segments'] + Should Be Equal ${val} ${TRUE} + +Vibevoice Should Use Provider Segments + [Documentation] Test that VibeVoice provider should have use_provider_segments=true since it provides diarized segments. + # Logic simulation + ${vibevoice_capabilities}= Create List segments diarization + ${has_diarization}= Evaluate "diarization" in $vibevoice_capabilities + ${has_segments}= Evaluate "segments" in $vibevoice_capabilities + ${should_use_segments}= Evaluate $has_diarization and $has_segments + Should Be Equal ${should_use_segments} ${TRUE} + +Model Registry Url Resolution With Env Var + [Documentation] Test that model URLs resolve correctly from environment. + ${model_def}= Create Dictionary + ... name=stt-vibevoice + ... model_type=stt + ... model_provider=vibevoice + ... 
model_url=http://\${oc.env:VIBEVOICE_ASR_URL,host.docker.internal:8767} + + ${models}= Create List ${model_def} + ${defaults}= Create Dictionary stt=stt-vibevoice + ${config_template}= Create Dictionary defaults=${defaults} models=${models} + + ${env_vars}= Create Dictionary VIBEVOICE_ASR_URL=host.docker.internal:8767 + + ${resolved}= Resolve Omega Config ${config_template} ${env_vars} + ${resolved_models}= Get From Dictionary ${resolved} models + Should Be Equal ${resolved_models[0]["model_url"]} http://host.docker.internal:8767 + +Multiple Asr Providers Url Resolution + [Documentation] Test that multiple ASR providers can use different URL patterns. + ${m1}= Create Dictionary name=stt-vibevoice model_url=http://\${oc.env:VIBEVOICE_ASR_URL,host.docker.internal:8767} + ${m2}= Create Dictionary name=stt-parakeet model_url=http://\${oc.env:PARAKEET_ASR_URL,172.17.0.1:8767} + ${m3}= Create Dictionary name=stt-deepgram model_url=https://api.deepgram.com/v1 + + ${models}= Create List ${m1} ${m2} ${m3} + ${config_template}= Create Dictionary models=${models} + + ${env_vars}= Create Dictionary + ... VIBEVOICE_ASR_URL=host.docker.internal:8767 + ... 
PARAKEET_ASR_URL=localhost:8080 + + ${resolved}= Resolve Omega Config ${config_template} ${env_vars} + ${resolved_models}= Get From Dictionary ${resolved} models + + Should Be Equal ${resolved_models[0]["model_url"]} http://host.docker.internal:8767 + Should Be Equal ${resolved_models[1]["model_url"]} http://localhost:8080 + Should Be Equal ${resolved_models[2]["model_url"]} https://api.deepgram.com/v1 diff --git a/tests/libs/ConfigTestHelper.py b/tests/libs/ConfigTestHelper.py new file mode 100644 index 00000000..6fbdcab4 --- /dev/null +++ b/tests/libs/ConfigTestHelper.py @@ -0,0 +1,73 @@ +import os +import sys +import yaml +from pathlib import Path +from typing import Dict, Any, Optional, List +from omegaconf import OmegaConf +from unittest.mock import patch + +# Add repo root to path to import config_manager +sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent)) +from config_manager import ConfigManager + +class ConfigTestHelper: +    """Helper library for testing configuration logic.""" + +    def _to_dict(self, obj: Any) -> Any: +        """Recursively converts Robot Framework DotDict to standard dict.""" +        if isinstance(obj, dict): +            return {k: self._to_dict(v) for k, v in obj.items()} +        if isinstance(obj, list): +            return [self._to_dict(v) for v in obj] +        return obj + +    def resolve_omega_config(self, config_template: Dict[str, Any], env_vars: Dict[str, str]) -> Dict[str, Any]: +        """ +        Resolves an OmegaConf configuration template with provided environment variables. +        """ +        config_template = self._to_dict(config_template) +        # We need to ensure values are strings for os.environ +        str_env_vars = {k: str(v) for k, v in env_vars.items()} + +        with patch.dict(os.environ, str_env_vars, clear=True): +            conf = OmegaConf.create(config_template) +            resolved = OmegaConf.to_container(conf, resolve=True) +        return resolved + +    def check_url_parsing(self, url: str) -> Dict[str, Any]: +        """ +        Parses a URL and returns its components to verify correct parsing. 
+ """ + from urllib.parse import urlparse + parsed = urlparse(url) + return { + "scheme": parsed.scheme, + "netloc": parsed.netloc, + "path": parsed.path + } + + def create_temp_config_structure(self, base_path: str, content: Dict[str, Any]) -> str: + """ + Creates the config folder structure and config.yml within the given base path. + """ + content = self._to_dict(content) + path = Path(base_path) / "config" + path.mkdir(parents=True, exist_ok=True) + config_file = path / "config.yml" + with open(config_file, "w") as f: + yaml.dump(content, f, default_flow_style=False, sort_keys=False) + return str(base_path) + + def get_config_manager_instance(self, repo_root: str) -> ConfigManager: + """Returns a ConfigManager instance configured with the given repo_root.""" + return ConfigManager(service_path=None, repo_root=Path(repo_root)) + + def add_model_to_config_manager(self, cm: ConfigManager, model_def: Dict[str, Any]): + """Wrapper for add_or_update_model that converts arguments.""" + model_def = self._to_dict(model_def) + cm.add_or_update_model(model_def) + + def update_defaults_in_config_manager(self, cm: ConfigManager, updates: Dict[str, str]): + """Wrapper for update_config_defaults that converts arguments.""" + updates = self._to_dict(updates) + cm.update_config_defaults(updates) \ No newline at end of file diff --git a/tests/test-requirements.txt b/tests/test-requirements.txt index f32614e0..5cd8f020 100644 --- a/tests/test-requirements.txt +++ b/tests/test-requirements.txt @@ -6,4 +6,6 @@ robotframework-databaselibrary python-dotenv websockets pymongo +omegaconf +pyyaml \ No newline at end of file diff --git a/wizard.py b/wizard.py index e4d52395..e8efaf80 100755 --- a/wizard.py +++ b/wizard.py @@ -258,7 +258,7 @@ def run_service_setup( if "speaker-recognition" in selected_services: cmd.extend(["--speaker-service-url", "http://speaker-service:8085"]) if "asr-services" in selected_services: - cmd.extend(["--parakeet-asr-url", 
"http://host.docker.internal:8767"]) + cmd.extend(["--parakeet-asr-url", "host.docker.internal:8767"]) # Pass transcription provider choice from wizard if transcription_provider: