diff --git a/README.md b/README.md index a4383ba2..3ff922d3 100644 --- a/README.md +++ b/README.md @@ -170,7 +170,7 @@ Usecases are numerous - OMI Mentor is one of them. Friend/Omi/pendants are a sma Regardless - this repo will try to do the minimal of this - multiple OMI-like audio devices feeding audio data - and from it: - Memories -- Action items +- Action items - Home automation ## Golden Goals (Not Yet Achieved) @@ -179,4 +179,3 @@ Regardless - this repo will try to do the minimal of this - multiple OMI-like au - **Home automation integration** (planned) - **Multi-device coordination** (planned) - **Visual context capture** (smart glasses integration planned) - diff --git a/app/src/components/StatusIndicator.tsx b/app/src/components/StatusIndicator.tsx index df55f80f..593df0d3 100644 --- a/app/src/components/StatusIndicator.tsx +++ b/app/src/components/StatusIndicator.tsx @@ -15,12 +15,12 @@ const StatusIndicator: React.FC = ({ inactiveColor = '#FF3B30', // Red }) => { return ( - (false); const [retryAttempts, setRetryAttempts] = useState(0); const { addEvent } = useConnectionLog(); - + const audioSubscriptionRef = useRef(null); const uiUpdateIntervalRef = useRef(null); const localPacketCounterRef = useRef(0); const retryTimeoutRef = useRef(null); const shouldRetryRef = useRef(false); const currentOnAudioDataRef = useRef<((bytes: Uint8Array) => void) | null>(null); - + // Retry configuration const MAX_RETRY_ATTEMPTS = 10; const INITIAL_RETRY_DELAY = 1000; // 1 second @@ -37,23 +37,23 @@ export const useAudioListener = ( const stopAudioListener = useCallback(async () => { console.log('Attempting to stop audio listener...'); - + // Stop retry mechanism shouldRetryRef.current = false; setIsRetrying(false); setRetryAttempts(0); currentOnAudioDataRef.current = null; - + if (retryTimeoutRef.current) { clearTimeout(retryTimeoutRef.current); retryTimeoutRef.current = null; } - + if (uiUpdateIntervalRef.current) { clearInterval(uiUpdateIntervalRef.current); uiUpdateIntervalRef.current = null; } - + if (audioSubscriptionRef.current) { try { await omiConnection.stopAudioBytesListener(audioSubscriptionRef.current); @@ -147,7 +147,7 @@ export const useAudioListener = ( setIsRetrying(true); const success = await attemptStartAudioListener(currentOnAudioDataRef.current); - + if (success) { console.log('[AudioListener] Retry successful'); return; @@ -157,7 +157,7 @@ export const useAudioListener = ( if (shouldRetryRef.current) { const delay = getRetryDelay(currentAttempt); console.log(`[AudioListener] Scheduling retry in ${Math.round(delay)}ms`); - + retryTimeoutRef.current = setTimeout(() => { if (shouldRetryRef.current) { retryStartAudioListener(); @@ -171,7 +171,7 @@ export const useAudioListener = ( Alert.alert('Not Connected', 'Please connect to a device first to start audio listener.'); return; } - + if (isListeningAudio) { console.log('[AudioListener] Audio listener is already active. Stopping first.'); await stopAudioListener(); @@ -180,7 +180,7 @@ export const useAudioListener = ( // Store the callback for retry attempts currentOnAudioDataRef.current = onAudioData; shouldRetryRef.current = true; - + setAudioPacketsReceived(0); // Reset counter on start localPacketCounterRef.current = 0; setRetryAttempts(0); @@ -197,7 +197,7 @@ export const useAudioListener = ( // Try to start audio listener const success = await attemptStartAudioListener(onAudioData); - + if (!success && shouldRetryRef.current) { console.log('[AudioListener] Initial attempt failed, starting retry mechanism'); setIsRetrying(true); @@ -227,4 +227,4 @@ export const useAudioListener = ( isRetrying, retryAttempts, }; -}; \ No newline at end of file +}; diff --git a/app/src/hooks/useBluetoothManager.ts b/app/src/hooks/useBluetoothManager.ts index f2f4aba2..1543ad4f 100644 --- a/app/src/hooks/useBluetoothManager.ts +++ b/app/src/hooks/useBluetoothManager.ts @@ -49,7 +49,7 @@ export const useBluetoothManager = () => { PermissionsAndroid.PERMISSIONS.ACCESS_FINE_LOCATION, ]; } - + console.log('[BTManager] Android permissions to request:', permissionsToRequest); const statuses = await PermissionsAndroid.requestMultiple(permissionsToRequest); console.log('[BTManager] Android permission statuses:', statuses); @@ -100,7 +100,7 @@ export const useBluetoothManager = () => { checkAndRequestPermissions(); } }, [bluetoothState, checkAndRequestPermissions]); // Rerun if BT state changes or on initial mount - + return { bleManager, bluetoothState, @@ -108,4 +108,4 @@ export const useBluetoothManager = () => { requestBluetoothPermission: checkAndRequestPermissions, isPermissionsLoading, }; -}; \ No newline at end of file +}; diff --git a/app/src/hooks/useDeviceConnection.ts b/app/src/hooks/useDeviceConnection.ts index 96468dac..03d6a4cf 100644 --- a/app/src/hooks/useDeviceConnection.ts +++ b/app/src/hooks/useDeviceConnection.ts @@ -189,4 +189,4 @@ export const useDeviceConnection = ( getRawBatteryLevel, connectedDeviceId }; -}; \ No newline at end of file +}; diff --git a/app/src/hooks/useDeviceScanning.ts b/app/src/hooks/useDeviceScanning.ts index 9380036c..96778c32 100644 --- a/app/src/hooks/useDeviceScanning.ts +++ b/app/src/hooks/useDeviceScanning.ts @@ -49,7 +49,7 @@ export const useDeviceScanning = ( const startScan = useCallback(async () => { console.log('[Scanner] startScan called'); setError(null); - setDevices([]); + setDevices([]); if (scanning) { console.log('[Scanner] Scan already in progress. Stopping previous scan first.'); @@ -80,7 +80,7 @@ export const useDeviceScanning = ( setError('Bluetooth is not enabled. Please turn on Bluetooth.'); return; } - + const currentState = await bleManager.state(); if (currentState !== BluetoothState.PoweredOn) { console.warn(`[Scanner] Bluetooth state is ${currentState}, not PoweredOn. Cannot scan.`); @@ -143,4 +143,4 @@ export const useDeviceScanning = ( }, [handleStopScan]); return { devices, scanning, startScan, stopScan: handleStopScan, error }; -}; \ No newline at end of file +}; diff --git a/app/src/hooks/usePhoneAudioRecorder.ts b/app/src/hooks/usePhoneAudioRecorder.ts index d80fbabb..25bc3755 100644 --- a/app/src/hooks/usePhoneAudioRecorder.ts +++ b/app/src/hooks/usePhoneAudioRecorder.ts @@ -34,7 +34,7 @@ export const usePhoneAudioRecorder = (): UsePhoneAudioRecorder => { const [isInitializing, setIsInitializing] = useState(false); const [error, setError] = useState(null); const [audioLevel, setAudioLevel] = useState(0); - + const onAudioDataRef = useRef<((pcmBuffer: Uint8Array) => void) | null>(null); const mountedRef = useRef(true); @@ -53,13 +53,13 @@ export const usePhoneAudioRecorder = (): UsePhoneAudioRecorder => { try { const audioData = event.data; console.log('[PhoneAudioRecorder] processAudioDataEvent called, data type:', typeof audioData); - + if (typeof audioData === 'string') { // Base64 encoded data (native platforms) - decode using react-native-base64 console.log('[PhoneAudioRecorder] Decoding Base64 string, length:', audioData.length); const binaryString = base64.decode(audioData); console.log('[PhoneAudioRecorder] Decoded to binary string, length:', binaryString.length); - + const bytes = new Uint8Array(binaryString.length); for (let i = 0; i < binaryString.length; i++) { bytes[i] = binaryString.charCodeAt(i); @@ -148,10 +148,10 @@ export const usePhoneAudioRecorder = (): UsePhoneAudioRecorder => { intervalAnalysis: 500, // Analysis every 500ms onAudioStream: async (event: AudioDataEvent) => { // EXACT payload handling from guide - const payload = typeof event.data === "string" - ? event.data + const payload = typeof event.data === "string" + ? event.data : Buffer.from(event.data as unknown as ArrayBuffer).toString("base64"); - + // Convert to our expected format if (onAudioDataRef.current && mountedRef.current) { const pcmBuffer = processAudioDataEvent(event); @@ -163,7 +163,7 @@ export const usePhoneAudioRecorder = (): UsePhoneAudioRecorder => { }; const result = await startRecorderInternal(config); - + if (!result) { throw new Error('Failed to start recording'); } @@ -185,7 +185,7 @@ export const usePhoneAudioRecorder = (): UsePhoneAudioRecorder => { // Stop recording const stopRecording = useCallback(async (): Promise => { console.log('[PhoneAudioRecorder] Stopping recording...'); - + // Early return if not recording if (!isRecording) { console.log('[PhoneAudioRecorder] Not recording, nothing to stop'); @@ -194,7 +194,7 @@ export const usePhoneAudioRecorder = (): UsePhoneAudioRecorder => { setStateSafe(setIsInitializing, false); return; } - + onAudioDataRef.current = null; setStateSafe(setAudioLevel, 0); @@ -231,13 +231,13 @@ export const usePhoneAudioRecorder = (): UsePhoneAudioRecorder => { console.log('[PhoneAudioRecorder] Component unmounting, setting mountedRef to false'); }; }, []); // Empty dependency array - only runs on mount/unmount - + // Separate effect for stopping recording when needed useEffect(() => { return () => { // Stop recording if active when dependencies change if (isRecording) { - stopRecorderInternal().catch(err => + stopRecorderInternal().catch(err => console.error('[PhoneAudioRecorder] Cleanup stop error:', err) ); } @@ -252,4 +252,4 @@ export const usePhoneAudioRecorder = (): UsePhoneAudioRecorder => { startRecording, stopRecording, }; -}; \ No newline at end of file +}; diff --git a/app/src/utils/storage.ts b/app/src/utils/storage.ts index e6aa6e95..c61e78bd 100644 --- a/app/src/utils/storage.ts +++ b/app/src/utils/storage.ts @@ -207,4 +207,4 @@ export const clearAuthData = async (): Promise => { } catch (error) { console.error('[Storage] Error clearing auth data:', error); } -}; \ No newline at end of file +}; diff --git a/backends/advanced/init.py b/backends/advanced/init.py index eaf9f92f..b9b890eb 100644 --- a/backends/advanced/init.py +++ b/backends/advanced/init.py @@ -324,9 +324,11 @@ def setup_transcription(self): elif choice == "2": self.console.print("[blue][INFO][/blue] Offline Parakeet ASR selected") - parakeet_url = self.prompt_value( - "Parakeet ASR URL (without http:// prefix)", "host.docker.internal:8767" + existing_parakeet_url = ( + read_env_value(".env", "PARAKEET_ASR_URL") + or "http://host.docker.internal:8767" ) + parakeet_url = self.prompt_value("Parakeet ASR URL", existing_parakeet_url) # Write URL to .env for ${PARAKEET_ASR_URL} placeholder in config.yml self.config["PARAKEET_ASR_URL"] = parakeet_url @@ -348,9 +350,12 @@ def setup_transcription(self): self.console.print( "[blue][INFO][/blue] Offline VibeVoice ASR selected (built-in speaker diarization)" ) + existing_vibevoice_url = ( + read_env_value(".env", "VIBEVOICE_ASR_URL") + or "http://host.docker.internal:8767" + ) vibevoice_url = self.prompt_value( - "VibeVoice ASR URL (without http:// prefix)", - "host.docker.internal:8767", + "VibeVoice ASR URL", existing_vibevoice_url ) # Write URL to .env for ${VIBEVOICE_ASR_URL} placeholder in config.yml @@ -374,9 +379,13 @@ def setup_transcription(self): self.console.print( "[blue][INFO][/blue] Qwen3-ASR selected (52 languages, streaming + batch via vLLM)" ) - qwen3_url = self.prompt_value( - "Qwen3-ASR URL", "http://host.docker.internal:8767" + existing_qwen3_url_raw = read_env_value(".env", "QWEN3_ASR_URL") + existing_qwen3_url = ( + f"http://{existing_qwen3_url_raw}" + if existing_qwen3_url_raw + else "http://host.docker.internal:8767" ) + qwen3_url = self.prompt_value("Qwen3-ASR URL", existing_qwen3_url) # Write URL to .env for ${QWEN3_ASR_URL} placeholder in config.yml self.config["QWEN3_ASR_URL"] = qwen3_url.replace("http://", "").rstrip("/") @@ -527,21 +536,38 @@ def setup_streaming_provider(self): def setup_llm(self): """Configure LLM provider - updates config.yml and .env""" - self.print_section("LLM Provider Configuration") - - self.console.print( - "[blue][INFO][/blue] LLM configuration will be saved to config.yml" - ) - self.console.print() + # Check if LLM provider was provided via command line (from wizard.py) + if hasattr(self.args, "llm_provider") and self.args.llm_provider: + provider = self.args.llm_provider + self.console.print( + f"[green]✅[/green] LLM provider: {provider} (configured via wizard)" + ) + choice = {"openai": "1", "ollama": "2", "none": "3"}.get(provider, "1") + else: + # Standalone init.py run — read existing config as default + existing_choice = "1" + full_config = self.config_manager.get_full_config() + existing_llm = full_config.get("defaults", {}).get("llm", "") + if existing_llm == "local-llm": + existing_choice = "2" + elif existing_llm == "openai-llm": + existing_choice = "1" + + self.print_section("LLM Provider Configuration") + self.console.print( + "[blue][INFO][/blue] LLM configuration will be saved to config.yml" + ) + self.console.print() - choices = { - "1": "OpenAI (GPT-4, GPT-3.5 - requires API key)", - "2": "Ollama (local models - runs locally)", - "3": "OpenAI-Compatible (custom endpoint - Groq, Together AI, LM Studio, etc.)", - "4": "Skip (no memory extraction)", - } + choices = { + "1": "OpenAI (GPT-4, GPT-3.5 - requires API key)", + "2": "Ollama (local models - runs locally)", + "3": "Skip (no memory extraction)", + } - choice = self.prompt_choice("Which LLM provider will you use?", choices, "1") + choice = self.prompt_choice( + "Which LLM provider will you use?", choices, existing_choice + ) if choice == "1": self.console.print("[blue][INFO][/blue] OpenAI selected") @@ -717,14 +743,33 @@ def setup_llm(self): def setup_memory(self): """Configure memory provider - updates config.yml""" - self.print_section("Memory Storage Configuration") + # Check if memory provider was provided via command line (from wizard.py) + if hasattr(self.args, "memory_provider") and self.args.memory_provider: + provider = self.args.memory_provider + self.console.print( + f"[green]✅[/green] Memory provider: {provider} (configured via wizard)" + ) + choice = {"chronicle": "1", "openmemory_mcp": "2"}.get(provider, "1") + else: + # Standalone init.py run — read existing config as default + existing_choice = "1" + full_config = self.config_manager.get_full_config() + existing_provider = full_config.get("memory", {}).get( + "provider", "chronicle" + ) + if existing_provider == "openmemory_mcp": + existing_choice = "2" - choices = { - "1": "Chronicle Native (Qdrant + custom extraction)", - "2": "OpenMemory MCP (cross-client compatible, external server)", - } + self.print_section("Memory Storage Configuration") - choice = self.prompt_choice("Choose your memory storage backend:", choices, "1") + choices = { + "1": "Chronicle Native (Qdrant + custom extraction)", + "2": "OpenMemory MCP (cross-client compatible, external server)", + } + + choice = self.prompt_choice( + "Choose your memory storage backend:", choices, existing_choice + ) if choice == "1": self.console.print( @@ -852,13 +897,26 @@ def setup_neo4j(self): def setup_obsidian(self): """Configure Obsidian integration (optional feature flag only - Neo4j credentials handled by setup_neo4j)""" - if hasattr(self.args, "enable_obsidian") and self.args.enable_obsidian: + has_enable = hasattr(self.args, "enable_obsidian") and self.args.enable_obsidian + has_disable = hasattr(self.args, "no_obsidian") and self.args.no_obsidian + + if has_enable: enable_obsidian = True self.console.print( f"[green]✅[/green] Obsidian: enabled (configured via wizard)" ) + elif has_disable: + enable_obsidian = False + self.console.print( + f"[blue][INFO][/blue] Obsidian: disabled (configured via wizard)" + ) else: - # Interactive prompt (fallback) + # Standalone init.py run — read existing config as default + full_config = self.config_manager.get_full_config() + existing_enabled = ( + full_config.get("memory", {}).get("obsidian", {}).get("enabled", False) + ) + self.console.print() self.console.print("[bold cyan]Obsidian Integration (Optional)[/bold cyan]") self.console.print( @@ -868,11 +926,13 @@ def setup_obsidian(self): try: enable_obsidian = Confirm.ask( - "Enable Obsidian integration?", default=False + "Enable Obsidian integration?", default=existing_enabled ) except EOFError: - self.console.print("Using default: No") - enable_obsidian = False + self.console.print( + f"Using default: {'Yes' if existing_enabled else 'No'}" + ) + enable_obsidian = existing_enabled if enable_obsidian: self.config_manager.update_memory_config( @@ -887,12 +947,33 @@ def setup_obsidian(self): def setup_knowledge_graph(self): """Configure Knowledge Graph (Neo4j-based entity/relationship extraction - enabled by default)""" - if ( + has_enable = ( hasattr(self.args, "enable_knowledge_graph") and self.args.enable_knowledge_graph - ): + ) + has_disable = ( + hasattr(self.args, "no_knowledge_graph") and self.args.no_knowledge_graph + ) + + if has_enable: enable_kg = True + self.console.print( + f"[green]✅[/green] Knowledge Graph: enabled (configured via wizard)" + ) + elif has_disable: + enable_kg = False + self.console.print( + f"[blue][INFO][/blue] Knowledge Graph: disabled (configured via wizard)" + ) else: + # Standalone init.py run — read existing config as default + full_config = self.config_manager.get_full_config() + existing_enabled = ( + full_config.get("memory", {}) + .get("knowledge_graph", {}) + .get("enabled", True) + ) + self.console.print() self.console.print( "[bold cyan]Knowledge Graph (Entity Extraction)[/bold cyan]" @@ -903,10 +984,14 @@ def setup_knowledge_graph(self): self.console.print() try: - enable_kg = Confirm.ask("Enable Knowledge Graph?", default=True) + enable_kg = Confirm.ask( + "Enable Knowledge Graph?", default=existing_enabled + ) except EOFError: - self.console.print("Using default: Yes") - enable_kg = True + self.console.print( + f"Using default: {'Yes' if existing_enabled else 'No'}" + ) + enable_kg = existing_enabled if enable_kg: self.config_manager.update_memory_config( @@ -1454,15 +1539,31 @@ def main(): "--langfuse-host", help="LangFuse host URL (default: http://langfuse-web:3000 for local)", ) - parser.add_argument( - "--langfuse-public-url", - help="LangFuse browser-accessible URL for deep-links (default: http://localhost:3002)", - ) parser.add_argument( "--streaming-provider", choices=["deepgram", "smallest", "qwen3-asr"], help="Streaming provider when different from batch (enables batch re-transcription)", ) + parser.add_argument( + "--llm-provider", + choices=["openai", "ollama", "none"], + help="LLM provider for memory extraction (default: prompt user)", + ) + parser.add_argument( + "--memory-provider", + choices=["chronicle", "openmemory_mcp"], + help="Memory storage backend (default: prompt user)", + ) + parser.add_argument( + "--no-obsidian", + action="store_true", + help="Explicitly disable Obsidian integration (complementary to --enable-obsidian)", + ) + parser.add_argument( + "--no-knowledge-graph", + action="store_true", + help="Explicitly disable Knowledge Graph (complementary to --enable-knowledge-graph)", + ) args = parser.parse_args() diff --git a/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py index 6973dffd..95f05eb2 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py @@ -129,7 +129,12 @@ def get_job_stats() -> Dict[str, Any]: deferred_jobs += len(queue.deferred_job_registry) total_jobs = ( - queued_jobs + started_jobs + finished_jobs + failed_jobs + canceled_jobs + deferred_jobs + queued_jobs + + started_jobs + + finished_jobs + + failed_jobs + + canceled_jobs + + deferred_jobs ) return { @@ -168,7 +173,9 @@ def get_jobs( f"🔍 DEBUG get_jobs: Filtering - queue_name={queue_name}, job_type={job_type}, client_id={client_id}" ) all_jobs = [] - seen_job_ids = set() # Track which job IDs we've already processed to avoid duplicates + seen_job_ids = ( + set() + ) # Track which job IDs we've already processed to avoid duplicates queues_to_check = [queue_name] if queue_name else QUEUE_NAMES logger.info(f"🔍 DEBUG get_jobs: Checking queues: {queues_to_check}") @@ -179,8 +186,14 @@ def get_jobs( # Collect jobs from all registries (using RQ standard status names) registries = [ (queue.job_ids, "queued"), - (queue.started_job_registry.get_job_ids(), "started"), # RQ standard, not "processing" - (queue.finished_job_registry.get_job_ids(), "finished"), # RQ standard, not "completed" + ( + queue.started_job_registry.get_job_ids(), + "started", + ), # RQ standard, not "processing" + ( + queue.finished_job_registry.get_job_ids(), + "finished", + ), # RQ standard, not "completed" (queue.failed_job_registry.get_job_ids(), "failed"), ( queue.deferred_job_registry.get_job_ids(), @@ -202,7 +215,9 @@ def get_jobs( user_id = job.kwargs.get("user_id", "") if job.kwargs else "" # Extract just the function name (e.g., "listen_for_speech_job" from "module.listen_for_speech_job") - func_name = job.func_name.split(".")[-1] if job.func_name else "unknown" + func_name = ( + job.func_name.split(".")[-1] if job.func_name else "unknown" + ) # Debug: Log job details before filtering logger.debug( @@ -218,14 +233,18 @@ def get_jobs( # Apply client_id filter (partial match in meta) if client_id: - job_client_id = job.meta.get("client_id", "") if job.meta else "" + job_client_id = ( + job.meta.get("client_id", "") if job.meta else "" + ) if client_id not in job_client_id: logger.debug( f"🔍 DEBUG get_jobs: Filtered out {job_id} - client_id '{client_id}' not in job_client_id '{job_client_id}'" ) continue - logger.debug(f"🔍 DEBUG get_jobs: Including job {job_id} in results") + logger.debug( + f"🔍 DEBUG get_jobs: Including job {job_id} in results" + ) all_jobs.append( { @@ -239,12 +258,24 @@ def get_jobs( "queue": qname, }, "result": job.result if hasattr(job, "result") else None, - "meta": job.meta if job.meta else {}, # Include job metadata - "error_message": str(job.exc_info) if job.exc_info else None, - "created_at": job.created_at.isoformat() if job.created_at else None, - "started_at": job.started_at.isoformat() if job.started_at else None, - "completed_at": job.ended_at.isoformat() if job.ended_at else None, - "retry_count": job.retries_left if hasattr(job, "retries_left") else 0, + "meta": ( + job.meta if job.meta else {} + ), # Include job metadata + "error_message": ( + str(job.exc_info) if job.exc_info else None + ), + "created_at": ( + job.created_at.isoformat() if job.created_at else None + ), + "started_at": ( + job.started_at.isoformat() if job.started_at else None + ), + "completed_at": ( + job.ended_at.isoformat() if job.ended_at else None + ), + "retry_count": ( + job.retries_left if hasattr(job, "retries_left") else 0 + ), "max_retries": 3, # Default max retries "progress_percent": (job.meta or {}) .get("batch_progress", {}) @@ -345,7 +376,9 @@ def is_job_complete(job): return True -def start_streaming_jobs(session_id: str, user_id: str, client_id: str) -> Dict[str, str]: +def start_streaming_jobs( + session_id: str, user_id: str, client_id: str +) -> Dict[str, str]: """ Enqueue jobs for streaming audio session (initial session setup). @@ -401,7 +434,9 @@ def start_streaming_jobs(session_id: str, user_id: str, client_id: str) -> Dict[ # Store job ID for cleanup (keyed by client_id for easy WebSocket cleanup) try: - redis_conn.set(f"speech_detection_job:{client_id}", speech_job.id, ex=86400) # 24 hour TTL + redis_conn.set( + f"speech_detection_job:{client_id}", speech_job.id, ex=86400 + ) # 24 hour TTL logger.info(f"📌 Stored speech detection job ID for client {client_id}") except Exception as e: logger.warning(f"⚠️ Failed to store job ID for {client_id}: {e}") @@ -421,7 +456,10 @@ def start_streaming_jobs(session_id: str, user_id: str, client_id: str) -> Dict[ failure_ttl=86400, # Cleanup failed jobs after 24h job_id=f"audio-persist_{session_id}", description=f"Audio persistence for session {session_id}", - meta={"client_id": client_id, "session_level": True}, # Mark as session-level job + meta={ + "client_id": client_id, + "session_level": True, + }, # Mark as session-level job ) # Log job enqueue with TTL information for debugging actual_ttl = redis_conn.ttl(f"rq:job:{audio_job.id}") @@ -735,11 +773,15 @@ async def cleanup_stuck_stream_workers(request): stream_keys = await redis_client.keys("audio:stream:*") for stream_key in stream_keys: - stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key + stream_name = ( + stream_key.decode() if isinstance(stream_key, bytes) else stream_key + ) try: # First check stream age - delete old streams (>1 hour) immediately - stream_info = await redis_client.execute_command("XINFO", "STREAM", stream_name) + stream_info = await redis_client.execute_command( + "XINFO", "STREAM", stream_name + ) # Parse stream info info_dict = {} @@ -761,7 +803,9 @@ async def cleanup_stuck_stream_workers(request): if stream_length == 0: should_delete_stream = True stream_age = 0 - elif last_entry and isinstance(last_entry, list) and len(last_entry) > 0: + elif ( + last_entry and isinstance(last_entry, list) and len(last_entry) > 0 + ): try: last_id = last_entry[0] if isinstance(last_id, bytes): @@ -789,7 +833,9 @@ async def cleanup_stuck_stream_workers(request): continue # Get consumer groups - groups = await redis_client.execute_command("XINFO", "GROUPS", stream_name) + groups = await redis_client.execute_command( + "XINFO", "GROUPS", stream_name + ) if not groups: cleanup_results[stream_name] = { @@ -803,7 +849,11 @@ async def cleanup_stuck_stream_workers(request): group_dict = {} group = groups[0] for i in range(0, len(group), 2): - key = group[i].decode() if isinstance(group[i], bytes) else str(group[i]) + key = ( + group[i].decode() + if isinstance(group[i], bytes) + else str(group[i]) + ) value = group[i + 1] if isinstance(value, bytes): try: @@ -892,7 +942,9 @@ async def cleanup_stuck_stream_workers(request): ) # Acknowledge it immediately - await redis_client.xack(stream_name, group_name, msg_id) + await redis_client.xack( + stream_name, group_name, msg_id + ) cleaned_count += 1 except Exception as claim_error: logger.warning( @@ -908,7 +960,11 @@ async def cleanup_stuck_stream_workers(request): if is_dead and consumer_pending == 0: try: await redis_client.execute_command( - "XGROUP", "DELCONSUMER", stream_name, group_name, consumer_name + "XGROUP", + "DELCONSUMER", + stream_name, + group_name, + consumer_name, ) deleted_consumers += 1 logger.info( @@ -954,5 +1010,6 @@ async def cleanup_stuck_stream_workers(request): except Exception as e: logger.error(f"Error cleaning up stuck workers: {e}", exc_info=True) return JSONResponse( - status_code=500, content={"error": f"Failed to cleanup stuck workers: {str(e)}"} + status_code=500, + content={"error": f"Failed to cleanup stuck workers: {str(e)}"}, ) diff --git a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py index 6a96883b..cd01b099 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py @@ -59,9 +59,15 @@ async def mark_session_complete( mark_time = time.time() await redis_client.hset( session_key, - mapping={"status": "finished", "completed_at": str(mark_time), "completion_reason": reason}, + mapping={ + "status": "finished", + "completed_at": str(mark_time), + "completion_reason": reason, + }, + ) + logger.info( + f"✅ Session {session_id[:12]} marked finished: {reason} [TIME: {mark_time:.3f}]" ) - logger.info(f"✅ Session {session_id[:12]} marked finished: {reason} [TIME: {mark_time:.3f}]") async def request_conversation_close( @@ -91,7 +97,9 @@ async def request_conversation_close( if not await redis_client.exists(session_key): return False await redis_client.hset(session_key, "conversation_close_requested", reason) - logger.info(f"🔒 Conversation close requested for session {session_id[:12]}: {reason}") + logger.info( + f"🔒 Conversation close requested for session {session_id[:12]}: {reason}" + ) return True @@ -130,7 +138,9 @@ async def get_session_info(redis_client, session_id: str) -> Optional[Dict]: "provider": session_data.get(b"provider", b"").decode(), "mode": session_data.get(b"mode", b"").decode(), "status": session_data.get(b"status", b"").decode(), - "websocket_connected": session_data.get(b"websocket_connected", b"false").decode() + "websocket_connected": session_data.get( + b"websocket_connected", b"false" + ).decode() == "true", "completion_reason": session_data.get(b"completion_reason", b"").decode(), "chunks_published": int(session_data.get(b"chunks_published", b"0")), @@ -142,8 +152,12 @@ async def get_session_info(redis_client, session_id: str) -> Optional[Dict]: # Speech detection events "last_event": session_data.get(b"last_event", b"").decode(), "speech_detected_at": session_data.get(b"speech_detected_at", b"").decode(), - "speaker_check_status": session_data.get(b"speaker_check_status", b"").decode(), - "identified_speakers": session_data.get(b"identified_speakers", b"").decode(), + "speaker_check_status": session_data.get( + b"speaker_check_status", b"" + ).decode(), + "identified_speakers": session_data.get( + b"identified_speakers", b"" + ).decode(), } except Exception as e: @@ -167,7 +181,9 @@ async def get_all_sessions(redis_client, limit: int = 100) -> List[Dict]: session_keys = [] cursor = b"0" while cursor and len(session_keys) < limit: - cursor, keys = await redis_client.scan(cursor, match="audio:session:*", count=limit) + cursor, keys = await redis_client.scan( + cursor, match="audio:session:*", count=limit + ) session_keys.extend(keys[: limit - len(session_keys)]) # Get info for each session @@ -223,7 +239,9 @@ async def increment_session_conversation_count(redis_client, session_id: str) -> logger.info(f"📊 Conversation count for session {session_id}: {count}") return count except Exception as e: - logger.error(f"Error incrementing conversation count for session {session_id}: {e}") + logger.error( + f"Error incrementing conversation count for session {session_id}: {e}" + ) return 0 @@ -332,10 +350,14 @@ async def get_streaming_status(request): current_time = time.time() for stream_key in stream_keys: - stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key + stream_name = ( + stream_key.decode() if isinstance(stream_key, bytes) else stream_key + ) try: # Check if stream exists - stream_info = await redis_client.execute_command("XINFO", "STREAM", stream_name) + stream_info = await redis_client.execute_command( + "XINFO", "STREAM", stream_name + ) # Parse stream info (returns flat list of key-value pairs) info_dict = {} @@ -391,7 +413,9 @@ async def get_streaming_status(request): session_idle_seconds = session_data.get("idle_seconds", 0) # Get consumer groups - groups = await redis_client.execute_command("XINFO", "GROUPS", stream_name) + groups = await redis_client.execute_command( + "XINFO", "GROUPS", stream_name + ) stream_data = { "stream_length": info_dict.get("length", 0), @@ -411,7 +435,11 @@ async def get_streaming_status(request): for group in groups: group_dict = {} for i in range(0, len(group), 2): - key = group[i].decode() if isinstance(group[i], bytes) else str(group[i]) + key = ( + group[i].decode() + if isinstance(group[i], bytes) + else str(group[i]) + ) value = group[i + 1] if isinstance(value, bytes): try: @@ -456,7 +484,9 @@ async def get_streaming_status(request): consumer_pending_total += consumer_pending # Track minimum idle time - min_consumer_idle_ms = min(min_consumer_idle_ms, consumer_idle_ms) + min_consumer_idle_ms = min( + min_consumer_idle_ms, consumer_idle_ms + ) # Consumer is active if idle < 5 minutes (300000ms) if consumer_idle_ms < 300000: @@ -492,7 +522,9 @@ async def get_streaming_status(request): # Determine if stream is active or completed # Active: has active consumers OR pending messages OR recent activity (< 5 min) # Completed: no active consumers and idle > 5 minutes but < 1 hour - total_pending = sum(group["pending"] for group in stream_data["consumer_groups"]) + total_pending = sum( + group["pending"] for group in stream_data["consumer_groups"] + ) is_active = ( has_active_consumer or total_pending > 0 @@ -552,7 +584,8 @@ async def get_streaming_status(request): except Exception as e: logger.error(f"Error getting streaming status: {e}", exc_info=True) return JSONResponse( - status_code=500, content={"error": f"Failed to get streaming status: {str(e)}"} + status_code=500, + content={"error": f"Failed to get streaming status: {str(e)}"}, ) @@ -597,7 +630,11 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): if should_clean: old_sessions.append( - {"session_id": session_id, "age_seconds": age_seconds, "status": status} + { + "session_id": session_id, + "age_seconds": age_seconds, + "status": status, + } ) await redis_client.delete(key) cleaned_sessions += 1 @@ -608,11 +645,15 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): old_streams = [] for stream_key in stream_keys: - stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key + stream_name = ( + stream_key.decode() if isinstance(stream_key, bytes) else stream_key + ) try: # Check stream info to get last activity - stream_info = await redis_client.execute_command("XINFO", "STREAM", stream_name) + stream_info = await redis_client.execute_command( + "XINFO", "STREAM", stream_name + ) # Parse stream info info_dict = {} @@ -635,7 +676,9 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): # Empty stream - safe to delete should_delete = True reason = "empty" - elif last_entry and isinstance(last_entry, list) and len(last_entry) > 0: + elif ( + last_entry and isinstance(last_entry, list) and len(last_entry) > 0 + ): # Extract timestamp from last entry ID last_id = last_entry[0] if isinstance(last_id, bytes): @@ -654,7 +697,11 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): except (ValueError, IndexError): # If we can't parse timestamp, check if first entry is old first_entry = info_dict.get("first-entry") - if first_entry and isinstance(first_entry, list) and len(first_entry) > 0: + if ( + first_entry + and isinstance(first_entry, list) + and len(first_entry) > 0 + ): try: first_id = first_entry[0] if isinstance(first_id, bytes): @@ -697,5 +744,6 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): except Exception as e: logger.error(f"Error cleaning up old sessions: {e}", exc_info=True) return JSONResponse( - status_code=500, content={"error": f"Failed to cleanup old sessions: {str(e)}"} + status_code=500, + content={"error": f"Failed to cleanup old sessions: {str(e)}"}, ) diff --git a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py index 274861c8..6d1b627a 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py @@ -93,7 +93,10 @@ async def get_config_diagnostics(): "resolution": "Check config/defaults.yml and config/config.yml syntax", } ) - diagnostics["components"]["omegaconf"] = {"status": "unhealthy", "message": str(e)} + diagnostics["components"]["omegaconf"] = { + "status": "unhealthy", + "message": str(e), + } # Test model registry try: @@ -230,7 +233,10 @@ async def get_config_diagnostics(): "resolution": "Check logs for detailed error information", } ) - diagnostics["components"]["model_registry"] = {"status": "unhealthy", "message": str(e)} + diagnostics["components"]["model_registry"] = { + "status": "unhealthy", + "message": str(e), + } # Check environment variables (only warn about keys relevant to configured providers) env_checks = [ @@ -243,7 +249,9 @@ async def get_config_diagnostics(): # Add LLM API key check based on active provider llm_model = registry.get_default("llm") if llm_model and llm_model.model_provider == "openai": - env_checks.append(("OPENAI_API_KEY", "Required for OpenAI LLM and embeddings")) + env_checks.append( + ("OPENAI_API_KEY", "Required for OpenAI LLM and embeddings") + ) elif llm_model and llm_model.model_provider == "groq": env_checks.append(("GROQ_API_KEY", "Required for Groq LLM")) @@ -252,7 +260,9 @@ async def get_config_diagnostics(): if stt_model: provider = stt_model.model_provider if provider == "deepgram": - env_checks.append(("DEEPGRAM_API_KEY", "Required for Deepgram transcription")) + env_checks.append( + ("DEEPGRAM_API_KEY", "Required for Deepgram transcription") + ) elif provider == "smallest": env_checks.append( ("SMALLEST_API_KEY", "Required for Smallest.ai Pulse transcription") @@ -320,7 +330,9 @@ async def get_observability_config(): from advanced_omi_backend.config_loader import load_config cfg = load_config() - public_url = cfg.get("observability", {}).get("langfuse", {}).get("public_url", "") + public_url = ( + cfg.get("observability", {}).get("langfuse", {}).get("public_url", "") + ) if public_url: # Strip trailing slash and build session URL session_base_url = f"{public_url.rstrip('/')}/project/chronicle/sessions" @@ -374,7 +386,8 @@ async def save_diarization_settings_controller(settings: dict): if key in ["min_speakers", "max_speakers"]: if not isinstance(value, int) or value < 1 or value > 20: raise HTTPException( - status_code=400, detail=f"Invalid value for {key}: must be integer 1-20" + status_code=400, + detail=f"Invalid value for {key}: must be integer 1-20", ) elif key == "diarization_source": if not isinstance(value, str) or value not in ["pyannote", "deepgram"]: @@ -385,14 +398,17 @@ async def save_diarization_settings_controller(settings: dict): else: if not isinstance(value, (int, float)) or value < 0: raise HTTPException( - status_code=400, detail=f"Invalid value for {key}: must be positive number" + status_code=400, + detail=f"Invalid value for {key}: must be positive number", ) filtered_settings[key] = value # Reject if NO valid keys provided (completely invalid request) if not filtered_settings: - raise HTTPException(status_code=400, detail="No valid diarization settings provided") + raise HTTPException( + status_code=400, detail="No valid diarization settings provided" + ) # Get current settings and merge with new values current_settings = load_diarization_settings() @@ -454,7 +470,8 @@ async def save_misc_settings_controller(settings: dict): if key in boolean_keys: if not isinstance(value, bool): raise HTTPException( - status_code=400, detail=f"Invalid value for {key}: must be boolean" + status_code=400, + detail=f"Invalid value for {key}: must be boolean", ) elif key == "transcription_job_timeout_seconds": if not isinstance(value, int) or value < 60 or value > 7200: @@ -467,7 +484,9 @@ async def save_misc_settings_controller(settings: dict): # Reject if NO valid keys provided if not filtered_settings: - raise HTTPException(status_code=400, detail="No valid misc settings provided") + raise HTTPException( + status_code=400, detail="No valid misc settings provided" + ) # Save using OmegaConf if save_misc_settings(filtered_settings): @@ -605,7 +624,9 @@ async def update_speaker_configuration(user: User, primary_speakers: list[dict]) } except Exception as e: - logger.exception(f"Error updating speaker configuration for user {user.user_id}") + logger.exception( + f"Error updating speaker configuration for user {user.user_id}" + ) raise e @@ -757,9 +778,13 @@ async def validate_memory_config(config_yaml: str): try: parsed = _yaml.load(config_yaml) except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid YAML syntax: {str(e)}") + raise HTTPException( + status_code=400, detail=f"Invalid YAML syntax: {str(e)}" + ) if not isinstance(parsed, dict): - raise HTTPException(status_code=400, detail="Configuration must be a YAML object") + raise HTTPException( + status_code=400, detail="Configuration must be a YAML object" + ) # Minimal checks # provider optional; timeout_seconds optional; extraction enabled/prompt optional return {"message": "Configuration is valid", "status": "success"} @@ -768,7 +793,9 @@ async def validate_memory_config(config_yaml: str): raise except Exception as e: logger.exception("Error validating memory config") - raise HTTPException(status_code=500, detail=f"Error validating memory config: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Error validating memory config: {str(e)}" + ) async def reload_memory_config(): @@ -870,7 +897,9 @@ async def set_memory_provider(provider: str): # If MEMORY_PROVIDER wasn't found, add it if not provider_found: - updated_lines.append(f"\n# Memory Provider Configuration\nMEMORY_PROVIDER={provider}\n") + updated_lines.append( + f"\n# Memory Provider Configuration\nMEMORY_PROVIDER={provider}\n" + ) # Create backup backup_path = f"{env_path}.bak" @@ -950,19 +979,23 @@ async def save_llm_operations(operations: dict): for op_name, op_value in operations.items(): if not isinstance(op_value, dict): - raise HTTPException(status_code=400, detail=f"Operation '{op_name}' must be a dict") + raise HTTPException( + status_code=400, detail=f"Operation '{op_name}' must be a dict" + ) extra_keys = set(op_value.keys()) - valid_keys if extra_keys: raise HTTPException( - status_code=400, detail=f"Invalid keys for '{op_name}': {extra_keys}" + status_code=400, + detail=f"Invalid keys for '{op_name}': {extra_keys}", ) if "temperature" in op_value and op_value["temperature"] is not None: t = op_value["temperature"] if not isinstance(t, (int, float)) or t < 0 or t > 2: raise HTTPException( - status_code=400, detail=f"Invalid temperature for '{op_name}': must be 0-2" + status_code=400, + detail=f"Invalid temperature for '{op_name}': must be 0-2", ) if "max_tokens" in op_value and op_value["max_tokens"] is not None: @@ -976,13 +1009,18 @@ async def save_llm_operations(operations: dict): if "model" in op_value and op_value["model"] is not None: if not registry.get_by_name(op_value["model"]): raise HTTPException( - status_code=400, detail=f"Model '{op_value['model']}' not found in registry" + status_code=400, + detail=f"Model '{op_value['model']}' not found in registry", ) - if "response_format" in op_value and op_value["response_format"] is not None: + if ( + "response_format" in op_value + and op_value["response_format"] is not None + ): if op_value["response_format"] != "json": raise HTTPException( - status_code=400, detail=f"response_format must be 'json' or null" + status_code=400, + detail=f"response_format must be 'json' or null", ) if save_config_section("llm_operations", operations): @@ -1157,7 +1195,10 @@ async def validate_chat_config_yaml(prompt_text: str) -> dict: if len(prompt_text) < 10: return {"valid": False, "error": "Prompt too short (minimum 10 characters)"} if len(prompt_text) > 10000: - return {"valid": False, "error": "Prompt too long (maximum 10000 characters)"} + return { + "valid": False, + "error": "Prompt too long (maximum 10000 characters)", + } return {"valid": True, "message": "Configuration is valid"} @@ -1291,8 +1332,13 @@ async def validate_plugins_config_yaml(yaml_content: str) -> dict: } # Check required fields - if "enabled" in plugin_config and not isinstance(plugin_config["enabled"], bool): - return {"valid": False, "error": f"Plugin '{plugin_id}': 'enabled' must be boolean"} + if "enabled" in plugin_config and not isinstance( + plugin_config["enabled"], bool + ): + return { + "valid": False, + "error": f"Plugin '{plugin_id}': 'enabled' must be boolean", + } if ( "access_level" in plugin_config @@ -1448,11 +1494,14 @@ async def get_plugins_metadata() -> dict: for plugin_id, plugin_class in discovered_plugins.items(): # Get orchestration config (or empty dict if not configured) orchestration_config = orchestration_configs.get( - plugin_id, {"enabled": False, "events": [], "condition": {"type": "always"}} + plugin_id, + {"enabled": False, "events": [], "condition": {"type": "always"}}, ) # Get complete metadata including schema - metadata = get_plugin_metadata(plugin_id, plugin_class, orchestration_config) + metadata = get_plugin_metadata( + plugin_id, plugin_class, orchestration_config + ) plugins_metadata.append(metadata) logger.info(f"Retrieved metadata for {len(plugins_metadata)} plugins") @@ -1527,7 +1576,9 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict: _yaml.dump(plugins_data, f) updated_files.append(str(plugins_yml_path)) - logger.info(f"Updated orchestration config for '{plugin_id}' in {plugins_yml_path}") + logger.info( + f"Updated orchestration config for '{plugin_id}' in {plugins_yml_path}" + ) # 2. Update plugins/{plugin_id}/config.yml (settings with env var references) if "settings" in config: @@ -1562,7 +1613,9 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict: from advanced_omi_backend.services.plugin_service import save_plugin_env # Filter out masked values (unchanged secrets) - changed_vars = {k: v for k, v in config["env_vars"].items() if v != "••••••••••••"} + changed_vars = { + k: v for k, v in config["env_vars"].items() if v != "••••••••••••" + } if changed_vars: env_path = save_plugin_env(plugin_id, changed_vars) @@ -1582,7 +1635,9 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict: except Exception as reload_err: logger.warning(f"Auto-reload failed, manual restart needed: {reload_err}") - message = f"Plugin '{plugin_id}' configuration updated and reloaded successfully." + message = ( + f"Plugin '{plugin_id}' configuration updated and reloaded successfully." + ) if reload_result is None: message = f"Plugin '{plugin_id}' configuration updated. Restart backend for changes to take effect." @@ -1658,13 +1713,19 @@ async def test_plugin_connection(plugin_id: str, config: dict) -> dict: # Call plugin's test_connection static method result = await plugin_class.test_connection(test_config) - logger.info(f"Test connection for '{plugin_id}': {result.get('message', 'No message')}") + logger.info( + f"Test connection for '{plugin_id}': {result.get('message', 'No message')}" + ) return result except Exception as e: logger.exception(f"Error testing connection for plugin '{plugin_id}'") - return {"success": False, "message": f"Connection test failed: {str(e)}", "status": "error"} + return { + "success": False, + "message": f"Connection test failed: {str(e)}", + "status": "error", + } # Plugin Lifecycle Management Functions (create / write-code / delete) @@ -1705,7 +1766,10 @@ async def create_plugin( # Validate name if not plugin_name.replace("_", "").isalnum(): - return {"success": False, "error": "Plugin name must be alphanumeric with underscores only"} + return { + "success": False, + "error": "Plugin name must be alphanumeric with underscores only", + } if not re.match(r"^[a-z][a-z0-9_]*$", plugin_name): return { @@ -1718,11 +1782,17 @@ async def create_plugin( # Collision check if plugin_dir.exists(): - return {"success": False, "error": f"Plugin '{plugin_name}' already exists at {plugin_dir}"} + return { + "success": False, + "error": f"Plugin '{plugin_name}' already exists at {plugin_dir}", + } discovered = discover_plugins() if plugin_name in discovered: - return {"success": False, "error": f"Plugin '{plugin_name}' is already registered"} + return { + "success": False, + "error": f"Plugin '{plugin_name}' is already registered", + } class_name = _snake_to_pascal(plugin_name) + "Plugin" created_files: list[str] = [] @@ -1740,7 +1810,9 @@ async def create_plugin( else: # Write standard boilerplate events_str = ( - ", ".join(f'"{e}"' for e in events) if events else '"conversation.complete"' + ", ".join(f'"{e}"' for e in events) + if events + else '"conversation.complete"' ) boilerplate = ( inspect.cleandoc( @@ -1861,7 +1933,10 @@ async def write_plugin_code( plugin_dir = plugins_dir / plugin_id if not plugin_dir.exists(): - return {"success": False, "error": f"Plugin '{plugin_id}' not found at {plugin_dir}"} + return { + "success": False, + "error": f"Plugin '{plugin_id}' not found at {plugin_dir}", + } updated_files: list[str] = [] @@ -1949,7 +2024,9 @@ async def delete_plugin(plugin_id: str, remove_files: bool = False) -> dict: "error": f"Plugin '{plugin_id}' not found in plugins.yml or on disk", } - logger.info(f"Deleted plugin '{plugin_id}' (yml={removed_from_yml}, files={files_removed})") + logger.info( + f"Deleted plugin '{plugin_id}' (yml={removed_from_yml}, files={files_removed})" + ) return { "success": True, "plugin_id": plugin_id, diff --git a/backends/advanced/src/advanced_omi_backend/llm_client.py b/backends/advanced/src/advanced_omi_backend/llm_client.py index 8b5f2d43..a5184ceb 100644 --- a/backends/advanced/src/advanced_omi_backend/llm_client.py +++ b/backends/advanced/src/advanced_omi_backend/llm_client.py @@ -73,7 +73,9 @@ def __init__( ) self.logger.info(f"OpenAI client initialized, base_url: {self.base_url}") except ImportError: - self.logger.error("OpenAI library not installed. Install with: pip install openai") + self.logger.error( + "OpenAI library not installed. Install with: pip install openai" + ) raise except Exception as e: self.logger.error(f"Failed to initialize OpenAI client: {e}") @@ -128,14 +130,18 @@ def health_check(self) -> Dict: "status": "✅ Connected", "base_url": self.base_url, "default_model": self.model, - "api_key_configured": bool(self.api_key and self.api_key != "dummy"), + "api_key_configured": bool( + self.api_key and self.api_key != "dummy" + ), } else: return { "status": "⚠️ Configuration incomplete", "base_url": self.base_url, "default_model": self.model, - "api_key_configured": bool(self.api_key and self.api_key != "dummy"), + "api_key_configured": bool( + self.api_key and self.api_key != "dummy" + ), } except Exception as e: self.logger.error(f"Health check failed: {e}") @@ -233,7 +239,9 @@ async def async_generate( # Fallback: use singleton client client = get_llm_client() loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, lambda: client.generate(prompt, model, temperature)) + return await loop.run_in_executor( + None, lambda: client.generate(prompt, model, temperature) + ) async def async_chat_with_tools( diff --git a/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py b/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py index dffa4f1e..1c83331e 100644 --- a/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py +++ b/backends/advanced/src/advanced_omi_backend/observability/otel_setup.py @@ -88,7 +88,9 @@ def init_otel() -> None: langfuse = is_langfuse_enabled() if not galileo and not langfuse: - logger.info("No OTEL backends configured (Galileo/Langfuse), skipping initialization") + logger.info( + "No OTEL backends configured (Galileo/Langfuse), skipping initialization" + ) return try: @@ -104,12 +106,15 @@ def init_otel() -> None: project = os.getenv("GALILEO_PROJECT", "chronicle") logstream = os.getenv("GALILEO_LOG_STREAM", "default") - galileo_processor = otel.GalileoSpanProcessor(project=project, logstream=logstream) + galileo_processor = otel.GalileoSpanProcessor( + project=project, logstream=logstream + ) tracer_provider.add_span_processor(galileo_processor) backends.append("Galileo") except ImportError: logger.warning( - "Galileo packages not installed. " "Install with: uv pip install '.[galileo]'" + "Galileo packages not installed. " + "Install with: uv pip install '.[galileo]'" ) except Exception as e: logger.error(f"Failed to add Galileo span processor: {e}") @@ -124,7 +129,8 @@ def init_otel() -> None: backends.append("Langfuse") except ImportError: logger.warning( - "Langfuse OTEL packages not installed. " "Ensure langfuse>=3.13.0 is installed." + "Langfuse OTEL packages not installed. " + "Ensure langfuse>=3.13.0 is installed." ) except Exception as e: logger.error(f"Failed to add Langfuse span processor: {e}") @@ -152,7 +158,8 @@ def init_otel() -> None: ) except ImportError: logger.warning( - "OTEL SDK packages not installed. " "Install opentelemetry-api and opentelemetry-sdk." + "OTEL SDK packages not installed. " + "Install opentelemetry-api and opentelemetry-sdk." ) except Exception as e: logger.error(f"Failed to initialize OTEL: {e}") diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py index cb0f7137..8f68cf21 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py @@ -73,7 +73,9 @@ async def list_jobs( @router.get("/jobs/{job_id}/status") -async def get_job_status(job_id: str, current_user: User = Depends(current_active_user)): +async def get_job_status( + job_id: str, current_user: User = Depends(current_active_user) +): """Get just the status of a specific job (lightweight endpoint).""" try: job = Job.fetch(job_id, connection=redis_conn) @@ -193,7 +195,9 @@ async def cancel_job(job_id: str, current_user: User = Depends(current_active_us @router.get("/jobs/by-client/{client_id}") -async def get_jobs_by_client(client_id: str, current_user: User = Depends(current_active_user)): +async def get_jobs_by_client( + client_id: str, current_user: User = Depends(current_active_user) +): """Get all jobs associated with a specific client device.""" try: from rq.registry import ( @@ -242,11 +246,17 @@ def process_job_and_dependents(job, queue_name, base_status): all_jobs.append( { "job_id": job.id, - "job_type": (job.func_name.split(".")[-1] if job.func_name else "unknown"), + "job_type": ( + job.func_name.split(".")[-1] if job.func_name else "unknown" + ), "queue": queue_name, "status": status, - "created_at": (job.created_at.isoformat() if job.created_at else None), - "started_at": (job.started_at.isoformat() if job.started_at else None), + "created_at": ( + job.created_at.isoformat() if job.created_at else None + ), + "started_at": ( + job.started_at.isoformat() if job.started_at else None + ), "ended_at": job.ended_at.isoformat() if job.ended_at else None, "description": job.description or "", "result": job.result, @@ -323,13 +333,17 @@ def process_job_and_dependents(job, queue_name, base_status): # Sort by created_at all_jobs.sort(key=lambda x: x["created_at"] or "", reverse=False) - logger.info(f"Found {len(all_jobs)} jobs for client {client_id} (including dependents)") + logger.info( + f"Found {len(all_jobs)} jobs for client {client_id} (including dependents)" + ) return {"client_id": client_id, "jobs": all_jobs, "total": len(all_jobs)} except Exception as e: logger.error(f"Failed to get jobs for client {client_id}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get jobs for client: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get jobs for client: {str(e)}" + ) @router.get("/events") @@ -349,7 +363,9 @@ async def get_events( if not router_instance: return {"events": [], "total": 0} - events = router_instance.get_recent_events(limit=limit, event_type=event_type or None) + events = router_instance.get_recent_events( + limit=limit, event_type=event_type or None + ) return {"events": events, "total": len(events)} except Exception as e: logger.error(f"Failed to get events: {e}") @@ -475,7 +491,9 @@ async def get_queue_worker_details(current_user: User = Depends(current_active_u except Exception as e: logger.error(f"Failed to get queue worker details: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get worker details: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get worker details: {str(e)}" + ) @router.get("/streams") @@ -506,7 +524,9 @@ async def get_stream_stats( async def get_stream_info(stream_key): try: - stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key + stream_name = ( + stream_key.decode() if isinstance(stream_key, bytes) else stream_key + ) # Get basic stream info info = await audio_service.redis.xinfo_stream(stream_name) @@ -566,7 +586,9 @@ async def get_stream_info(stream_key): "name": group_dict.get("name", "unknown"), "consumers": group_dict.get("consumers", 0), "pending": group_dict.get("pending", 0), - "last_delivered_id": group_dict.get("last-delivered-id", "N/A"), + "last_delivered_id": group_dict.get( + "last-delivered-id", "N/A" + ), "consumer_details": consumers, } ) @@ -577,7 +599,9 @@ async def get_stream_info(stream_key): "stream_name": stream_name, "length": info[b"length"], "first_entry_id": ( - info[b"first-entry"][0].decode() if info[b"first-entry"] else None + info[b"first-entry"][0].decode() + if info[b"first-entry"] + else None ), "last_entry_id": ( info[b"last-entry"][0].decode() if info[b"last-entry"] else None @@ -589,7 +613,9 @@ async def get_stream_info(stream_key): return None # Fetch all stream info in parallel - streams_info_results = await asyncio.gather(*[get_stream_info(key) for key in stream_keys]) + streams_info_results = await asyncio.gather( + *[get_stream_info(key) for key in stream_keys] + ) streams_info = [info for info in streams_info_results if info is not None] return { @@ -615,7 +641,9 @@ class FlushAllJobsRequest(BaseModel): @router.post("/flush") -async def flush_jobs(request: FlushJobsRequest, current_user: User = Depends(current_active_user)): +async def flush_jobs( + request: FlushJobsRequest, current_user: User = Depends(current_active_user) +): """Flush old inactive jobs based on age and status.""" if not current_user.is_superuser: raise HTTPException(status_code=403, detail="Admin access required") @@ -631,7 +659,9 @@ async def flush_jobs(request: FlushJobsRequest, current_user: User = Depends(cur from advanced_omi_backend.controllers.queue_controller import get_queue - cutoff_time = datetime.now(timezone.utc) - timedelta(hours=request.older_than_hours) + cutoff_time = datetime.now(timezone.utc) - timedelta( + hours=request.older_than_hours + ) total_removed = 0 # Get all queues @@ -663,7 +693,9 @@ async def flush_jobs(request: FlushJobsRequest, current_user: User = Depends(cur except Exception as e: logger.error(f"Error deleting job {job_id}: {e}") - if "canceled" in request.statuses: # RQ standard (US spelling), not "cancelled" + if ( + "canceled" in request.statuses + ): # RQ standard (US spelling), not "cancelled" registry = CanceledJobRegistry(queue=queue) for job_id in registry.get_job_ids(): try: @@ -745,8 +777,12 @@ async def flush_all_jobs( registries.append(("finished", FinishedJobRegistry(queue=queue))) for registry_name, registry in registries: - job_ids = list(registry.get_job_ids()) # Convert to list to avoid iterator issues - logger.info(f"Flushing {len(job_ids)} jobs from {queue_name}/{registry_name}") + job_ids = list( + registry.get_job_ids() + ) # Convert to list to avoid iterator issues + logger.info( + f"Flushing {len(job_ids)} jobs from {queue_name}/{registry_name}" + ) for job_id in job_ids: try: @@ -756,7 +792,9 @@ async def flush_all_jobs( # Skip session-level jobs (e.g., speech_detection, audio_persistence) # These run for the entire session and should not be killed by test cleanup if job.meta and job.meta.get("session_level"): - logger.info(f"Skipping session-level job {job_id} ({job.description})") + logger.info( + f"Skipping session-level job {job_id} ({job.description})" + ) continue # Handle running jobs differently to avoid worker deadlock @@ -767,7 +805,9 @@ async def flush_all_jobs( from rq.command import send_stop_job_command send_stop_job_command(redis_conn, job_id) - logger.info(f"Sent stop command to worker for job {job_id}") + logger.info( + f"Sent stop command to worker for job {job_id}" + ) # Don't delete yet - let worker move it to canceled/failed registry # It will be cleaned up on next flush or by worker cleanup continue @@ -778,9 +818,13 @@ async def flush_all_jobs( # If stop fails, try to cancel it (may already be finishing) try: job.cancel() - logger.info(f"Cancelled job {job_id} after stop failed") + logger.info( + f"Cancelled job {job_id} after stop failed" + ) except Exception as cancel_error: - logger.warning(f"Could not cancel job {job_id}: {cancel_error}") + logger.warning( + f"Could not cancel job {job_id}: {cancel_error}" + ) # For non-running jobs, safe to delete immediately job.delete() @@ -795,7 +839,9 @@ async def flush_all_jobs( f"Removed stale job reference {job_id} from {registry_name} registry" ) except Exception as reg_error: - logger.error(f"Could not remove {job_id} from registry: {reg_error}") + logger.error( + f"Could not remove {job_id} from registry: {reg_error}" + ) # Also clean up audio streams and consumer locks deleted_keys = 0 @@ -809,7 +855,9 @@ async def flush_all_jobs( # Delete audio streams cursor = 0 while True: - cursor, keys = await async_redis.scan(cursor, match="audio:*", count=1000) + cursor, keys = await async_redis.scan( + cursor, match="audio:*", count=1000 + ) if keys: await async_redis.delete(*keys) deleted_keys += len(keys) @@ -819,7 +867,9 @@ async def flush_all_jobs( # Delete consumer locks cursor = 0 while True: - cursor, keys = await async_redis.scan(cursor, match="consumer:*", count=1000) + cursor, keys = await async_redis.scan( + cursor, match="consumer:*", count=1000 + ) if keys: await async_redis.delete(*keys) deleted_keys += len(keys) @@ -848,7 +898,9 @@ async def flush_all_jobs( except Exception as e: logger.error(f"Failed to flush all jobs: {e}") - raise HTTPException(status_code=500, detail=f"Failed to flush all jobs: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to flush all jobs: {str(e)}" + ) @router.get("/sessions") @@ -868,7 +920,9 @@ async def get_redis_sessions( session_keys = [] cursor = b"0" while cursor and len(session_keys) < limit: - cursor, keys = await redis_client.scan(cursor, match="audio:session:*", count=limit) + cursor, keys = await redis_client.scan( + cursor, match="audio:session:*", count=limit + ) session_keys.extend(keys[: limit - len(session_keys)]) # Get session info @@ -880,8 +934,12 @@ async def get_redis_sessions( session_id = key.decode().replace("audio:session:", "") # Get conversation count for this session - conversation_count_key = f"session:conversation_count:{session_id}" - conversation_count_bytes = await redis_client.get(conversation_count_key) + conversation_count_key = ( + f"session:conversation_count:{session_id}" + ) + conversation_count_bytes = await redis_client.get( + conversation_count_key + ) conversation_count = ( int(conversation_count_bytes.decode()) if conversation_count_bytes @@ -892,16 +950,25 @@ async def get_redis_sessions( { "session_id": session_id, "user_id": session_data.get(b"user_id", b"").decode(), - "client_id": session_data.get(b"client_id", b"").decode(), - "stream_name": session_data.get(b"stream_name", b"").decode(), + "client_id": session_data.get( + b"client_id", b"" + ).decode(), + "stream_name": session_data.get( + b"stream_name", b"" + ).decode(), "provider": session_data.get(b"provider", b"").decode(), "mode": session_data.get(b"mode", b"").decode(), "status": session_data.get(b"status", b"").decode(), - "started_at": session_data.get(b"started_at", b"").decode(), + "started_at": session_data.get( + b"started_at", b"" + ).decode(), "chunks_published": int( - session_data.get(b"chunks_published", b"0").decode() or 0 + session_data.get(b"chunks_published", b"0").decode() + or 0 ), - "last_chunk_at": session_data.get(b"last_chunk_at", b"").decode(), + "last_chunk_at": session_data.get( + b"last_chunk_at", b"" + ).decode(), "conversation_count": conversation_count, } ) @@ -944,7 +1011,9 @@ async def clear_old_sessions( session_keys = [] cursor = b"0" while cursor: - cursor, keys = await redis_client.scan(cursor, match="audio:session:*", count=100) + cursor, keys = await redis_client.scan( + cursor, match="audio:session:*", count=100 + ) session_keys.extend(keys) # Check each session and delete if old @@ -963,13 +1032,18 @@ async def clear_old_sessions( except Exception as e: logger.error(f"Error processing session {key}: {e}") - return {"deleted_count": deleted_count, "cutoff_seconds": older_than_seconds} + return { + "deleted_count": deleted_count, + "cutoff_seconds": older_than_seconds, + } finally: await redis_client.aclose() except Exception as e: logger.error(f"Failed to clear sessions: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"Failed to clear sessions: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to clear sessions: {str(e)}" + ) @router.get("/dashboard") @@ -1021,11 +1095,17 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100): if status_name == "queued": job_ids = queue.job_ids[:limit] elif status_name == "started": # RQ standard, not "processing" - job_ids = list(StartedJobRegistry(queue=queue).get_job_ids())[:limit] + job_ids = list(StartedJobRegistry(queue=queue).get_job_ids())[ + :limit + ] elif status_name == "finished": # RQ standard, not "completed" - job_ids = list(FinishedJobRegistry(queue=queue).get_job_ids())[:limit] + job_ids = list(FinishedJobRegistry(queue=queue).get_job_ids())[ + :limit + ] elif status_name == "failed": - job_ids = list(FailedJobRegistry(queue=queue).get_job_ids())[:limit] + job_ids = list(FailedJobRegistry(queue=queue).get_job_ids())[ + :limit + ] else: continue @@ -1036,7 +1116,9 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100): # Check user permission if not current_user.is_superuser: - job_user_id = job.kwargs.get("user_id") if job.kwargs else None + job_user_id = ( + job.kwargs.get("user_id") if job.kwargs else None + ) if job_user_id != str(current_user.user_id): continue @@ -1045,24 +1127,38 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100): { "job_id": job.id, "job_type": ( - job.func_name.split(".")[-1] if job.func_name else "unknown" + job.func_name.split(".")[-1] + if job.func_name + else "unknown" + ), + "user_id": ( + job.kwargs.get("user_id") + if job.kwargs + else None ), - "user_id": (job.kwargs.get("user_id") if job.kwargs else None), "status": status_name, "priority": "normal", # RQ doesn't have priority concept "data": {"description": job.description or ""}, "result": job.result, "meta": job.meta if job.meta else {}, "kwargs": job.kwargs if job.kwargs else {}, - "error_message": (str(job.exc_info) if job.exc_info else None), + "error_message": ( + str(job.exc_info) if job.exc_info else None + ), "created_at": ( - job.created_at.isoformat() if job.created_at else None + job.created_at.isoformat() + if job.created_at + else None ), "started_at": ( - job.started_at.isoformat() if job.started_at else None + job.started_at.isoformat() + if job.started_at + else None ), "ended_at": ( - job.ended_at.isoformat() if job.ended_at else None + job.ended_at.isoformat() + if job.ended_at + else None ), "retry_count": 0, # RQ doesn't track this by default "max_retries": 0, @@ -1178,7 +1274,11 @@ def get_job_status(job): # Check user permission if not current_user.is_superuser: - job_user_id = job.kwargs.get("user_id") if job.kwargs else None + job_user_id = ( + job.kwargs.get("user_id") + if job.kwargs + else None + ) if job_user_id != str(current_user.user_id): continue @@ -1194,13 +1294,19 @@ def get_job_status(job): "queue": queue_name, "status": get_job_status(job), "created_at": ( - job.created_at.isoformat() if job.created_at else None + job.created_at.isoformat() + if job.created_at + else None ), "started_at": ( - job.started_at.isoformat() if job.started_at else None + job.started_at.isoformat() + if job.started_at + else None ), "ended_at": ( - job.ended_at.isoformat() if job.ended_at else None + job.ended_at.isoformat() + if job.ended_at + else None ), "description": job.description or "", "result": job.result, @@ -1263,12 +1369,20 @@ async def fetch_events(): ) queued_jobs = results[0] if not isinstance(results[0], Exception) else [] - started_jobs = results[1] if not isinstance(results[1], Exception) else [] # RQ standard - finished_jobs = results[2] if not isinstance(results[2], Exception) else [] # RQ standard + started_jobs = ( + results[1] if not isinstance(results[1], Exception) else [] + ) # RQ standard + finished_jobs = ( + results[2] if not isinstance(results[2], Exception) else [] + ) # RQ standard failed_jobs = results[3] if not isinstance(results[3], Exception) else [] - stats = results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0} + stats = ( + results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0} + ) streaming_status = ( - results[5] if not isinstance(results[5], Exception) else {"active_sessions": []} + results[5] + if not isinstance(results[5], Exception) + else {"active_sessions": []} ) events = results[6] if not isinstance(results[6], Exception) else [] recent_conversations = [] @@ -1287,7 +1401,9 @@ async def fetch_events(): { "conversation_id": conv.conversation_id, "user_id": str(conv.user_id) if conv.user_id else None, - "created_at": (conv.created_at.isoformat() if conv.created_at else None), + "created_at": ( + conv.created_at.isoformat() if conv.created_at else None + ), "title": conv.title, "summary": conv.summary, "transcript_text": ( @@ -1315,4 +1431,6 @@ async def fetch_events(): except Exception as e: logger.error(f"Failed to get dashboard data: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"Failed to get dashboard data: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get dashboard data: {str(e)}" + ) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py index 277d7dc1..0f580be7 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py @@ -29,6 +29,7 @@ # Request models for memory config endpoints class MemoryConfigRequest(BaseModel): """Request model for memory configuration validation and updates.""" + config_yaml: str @@ -68,8 +69,7 @@ async def get_diarization_settings(current_user: User = Depends(current_superuse @router.post("/diarization-settings") async def save_diarization_settings( - settings: dict, - current_user: User = Depends(current_superuser) + settings: dict, current_user: User = Depends(current_superuser) ): """Save diarization settings. Admin only.""" return await system_controller.save_diarization_settings_controller(settings) @@ -83,32 +83,36 @@ async def get_misc_settings(current_user: User = Depends(current_superuser)): @router.post("/misc-settings") async def save_misc_settings( - settings: dict, - current_user: User = Depends(current_superuser) + settings: dict, current_user: User = Depends(current_superuser) ): """Save miscellaneous configuration settings. Admin only.""" return await system_controller.save_misc_settings_controller(settings) @router.get("/cleanup-settings") -async def get_cleanup_settings( - current_user: User = Depends(current_superuser) -): +async def get_cleanup_settings(current_user: User = Depends(current_superuser)): """Get cleanup configuration settings. Admin only.""" return await system_controller.get_cleanup_settings_controller(current_user) @router.post("/cleanup-settings") async def save_cleanup_settings( - auto_cleanup_enabled: bool = Body(..., description="Enable automatic cleanup of soft-deleted conversations"), - retention_days: int = Body(..., ge=1, le=365, description="Number of days to keep soft-deleted conversations"), - current_user: User = Depends(current_superuser) + auto_cleanup_enabled: bool = Body( + ..., description="Enable automatic cleanup of soft-deleted conversations" + ), + retention_days: int = Body( + ..., + ge=1, + le=365, + description="Number of days to keep soft-deleted conversations", + ), + current_user: User = Depends(current_superuser), ): """Save cleanup configuration settings. Admin only.""" return await system_controller.save_cleanup_settings_controller( auto_cleanup_enabled=auto_cleanup_enabled, retention_days=retention_days, - user=current_user + user=current_user, ) @@ -120,11 +124,12 @@ async def get_speaker_configuration(current_user: User = Depends(current_active_ @router.post("/speaker-configuration") async def update_speaker_configuration( - primary_speakers: list[dict], - current_user: User = Depends(current_active_user) + primary_speakers: list[dict], current_user: User = Depends(current_active_user) ): """Update current user's primary speakers configuration.""" - return await system_controller.update_speaker_configuration(current_user, primary_speakers) + return await system_controller.update_speaker_configuration( + current_user, primary_speakers + ) @router.get("/enrolled-speakers") @@ -141,6 +146,7 @@ async def get_speaker_service_status(current_user: User = Depends(current_superu # LLM Operations Configuration Endpoints + @router.get("/admin/llm-operations") async def get_llm_operations(current_user: User = Depends(current_superuser)): """Get LLM operation configurations. Admin only.""" @@ -149,8 +155,7 @@ async def get_llm_operations(current_user: User = Depends(current_superuser)): @router.post("/admin/llm-operations") async def save_llm_operations( - operations: dict, - current_user: User = Depends(current_superuser) + operations: dict, current_user: User = Depends(current_superuser) ): """Save LLM operation configurations. Admin only.""" return await system_controller.save_llm_operations(operations) @@ -159,7 +164,7 @@ async def save_llm_operations( @router.post("/admin/llm-operations/test") async def test_llm_model( model_name: Optional[str] = Body(None, embed=True), - current_user: User = Depends(current_superuser) + current_user: User = Depends(current_superuser), ): """Test an LLM model connection with a trivial prompt. Admin only.""" return await system_controller.test_llm_model(model_name) @@ -171,10 +176,11 @@ async def get_memory_config_raw(current_user: User = Depends(current_superuser)) """Get memory configuration YAML from config.yml. Admin only.""" return await system_controller.get_memory_config_raw() + @router.post("/admin/memory/config/raw") async def update_memory_config_raw( config_yaml: str = Body(..., media_type="text/plain"), - current_user: User = Depends(current_superuser) + current_user: User = Depends(current_superuser), ): """Save memory YAML to config.yml and hot-reload. Admin only.""" return await system_controller.update_memory_config_raw(config_yaml) @@ -191,8 +197,7 @@ async def validate_memory_config_raw( @router.post("/admin/memory/config/validate") async def validate_memory_config( - request: MemoryConfigRequest, - current_user: User = Depends(current_superuser) + request: MemoryConfigRequest, current_user: User = Depends(current_superuser) ): """Validate memory configuration YAML sent as JSON (used by tests). Admin only.""" return await system_controller.validate_memory_config(request.config_yaml) @@ -212,6 +217,7 @@ async def delete_all_user_memories(current_user: User = Depends(current_active_u # Chat Configuration Management Endpoints + @router.get("/admin/chat/config", response_class=Response) async def get_chat_config(current_user: User = Depends(current_superuser)): """Get chat configuration as YAML. Admin only.""" @@ -225,13 +231,12 @@ async def get_chat_config(current_user: User = Depends(current_superuser)): @router.post("/admin/chat/config") async def save_chat_config( - request: Request, - current_user: User = Depends(current_superuser) + request: Request, current_user: User = Depends(current_superuser) ): """Save chat configuration from YAML. Admin only.""" try: yaml_content = await request.body() - yaml_str = yaml_content.decode('utf-8') + yaml_str = yaml_content.decode("utf-8") result = await system_controller.save_chat_config_yaml(yaml_str) return JSONResponse(content=result) except ValueError as e: @@ -243,13 +248,12 @@ async def save_chat_config( @router.post("/admin/chat/config/validate") async def validate_chat_config( - request: Request, - current_user: User = Depends(current_superuser) + request: Request, current_user: User = Depends(current_superuser) ): """Validate chat configuration YAML. Admin only.""" try: yaml_content = await request.body() - yaml_str = yaml_content.decode('utf-8') + yaml_str = yaml_content.decode("utf-8") result = await system_controller.validate_chat_config_yaml(yaml_str) return JSONResponse(content=result) except Exception as e: @@ -259,6 +263,7 @@ async def validate_chat_config( # Plugin Configuration Management Endpoints + @router.get("/admin/plugins/config", response_class=Response) async def get_plugins_config(current_user: User = Depends(current_superuser)): """Get plugins configuration as YAML. Admin only.""" @@ -272,13 +277,12 @@ async def get_plugins_config(current_user: User = Depends(current_superuser)): @router.post("/admin/plugins/config") async def save_plugins_config( - request: Request, - current_user: User = Depends(current_superuser) + request: Request, current_user: User = Depends(current_superuser) ): """Save plugins configuration from YAML. Admin only.""" try: yaml_content = await request.body() - yaml_str = yaml_content.decode('utf-8') + yaml_str = yaml_content.decode("utf-8") result = await system_controller.save_plugins_config_yaml(yaml_str) return JSONResponse(content=result) except ValueError as e: @@ -290,13 +294,12 @@ async def save_plugins_config( @router.post("/admin/plugins/config/validate") async def validate_plugins_config( - request: Request, - current_user: User = Depends(current_superuser) + request: Request, current_user: User = Depends(current_superuser) ): """Validate plugins configuration YAML. Admin only.""" try: yaml_content = await request.body() - yaml_str = yaml_content.decode('utf-8') + yaml_str = yaml_content.decode("utf-8") result = await system_controller.validate_plugins_config_yaml(yaml_str) return JSONResponse(content=result) except Exception as e: @@ -306,6 +309,7 @@ async def validate_plugins_config( # Structured Plugin Configuration Endpoints (Form-based UI) + @router.post("/admin/plugins/reload") async def reload_plugins( request: Request, @@ -360,9 +364,16 @@ async def get_plugins_health(current_user: User = Depends(current_superuser)): """Get plugin health status for all registered plugins. Admin only.""" try: from advanced_omi_backend.services.plugin_service import get_plugin_router + plugin_router = get_plugin_router() if not plugin_router: - return {"total": 0, "initialized": 0, "failed": 0, "registered": 0, "plugins": []} + return { + "total": 0, + "initialized": 0, + "failed": 0, + "registered": 0, + "plugins": [], + } return plugin_router.get_health_summary() except Exception as e: logger.error(f"Failed to get plugins health: {e}") @@ -377,6 +388,7 @@ async def get_plugins_connectivity(current_user: User = Depends(current_superuse """ try: from advanced_omi_backend.services.plugin_service import get_plugin_router + plugin_router = get_plugin_router() if not plugin_router: return {"plugins": {}} @@ -406,6 +418,7 @@ async def get_plugins_metadata(current_user: User = Depends(current_superuser)): class PluginConfigRequest(BaseModel): """Request model for structured plugin configuration updates.""" + orchestration: Optional[dict] = None settings: Optional[dict] = None env_vars: Optional[dict] = None @@ -413,6 +426,7 @@ class PluginConfigRequest(BaseModel): class CreatePluginRequest(BaseModel): """Request model for creating a new plugin.""" + plugin_name: str description: str events: list[str] = [] @@ -421,12 +435,14 @@ class CreatePluginRequest(BaseModel): class WritePluginCodeRequest(BaseModel): """Request model for writing plugin code.""" + code: str config_yml: Optional[str] = None class PluginAssistantRequest(BaseModel): """Request model for plugin assistant chat.""" + messages: list[dict] @@ -434,7 +450,7 @@ class PluginAssistantRequest(BaseModel): async def update_plugin_config_structured( plugin_id: str, config: PluginConfigRequest, - current_user: User = Depends(current_superuser) + current_user: User = Depends(current_superuser), ): """Update plugin configuration from structured JSON (form data). Admin only. @@ -445,7 +461,9 @@ async def update_plugin_config_structured( """ try: config_dict = config.dict(exclude_none=True) - result = await system_controller.update_plugin_config_structured(plugin_id, config_dict) + result = await system_controller.update_plugin_config_structured( + plugin_id, config_dict + ) return JSONResponse(content=result) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -458,7 +476,7 @@ async def update_plugin_config_structured( async def test_plugin_connection( plugin_id: str, config: PluginConfigRequest, - current_user: User = Depends(current_superuser) + current_user: User = Depends(current_superuser), ): """Test plugin connection/configuration without saving. Admin only. @@ -490,7 +508,9 @@ async def create_plugin( plugin_code=request.plugin_code, ) if not result.get("success"): - raise HTTPException(status_code=400, detail=result.get("error", "Unknown error")) + raise HTTPException( + status_code=400, detail=result.get("error", "Unknown error") + ) return JSONResponse(content=result) except HTTPException: raise @@ -513,7 +533,9 @@ async def write_plugin_code( config_yml=request.config_yml, ) if not result.get("success"): - raise HTTPException(status_code=400, detail=result.get("error", "Unknown error")) + raise HTTPException( + status_code=400, detail=result.get("error", "Unknown error") + ) return JSONResponse(content=result) except HTTPException: raise @@ -535,7 +557,9 @@ async def delete_plugin( remove_files=remove_files, ) if not result.get("success"): - raise HTTPException(status_code=400, detail=result.get("error", "Unknown error")) + raise HTTPException( + status_code=400, detail=result.get("error", "Unknown error") + ) return JSONResponse(content=result) except HTTPException: raise @@ -572,25 +596,34 @@ async def event_stream(): @router.get("/streaming/status") -async def get_streaming_status(request: Request, current_user: User = Depends(current_superuser)): +async def get_streaming_status( + request: Request, current_user: User = Depends(current_superuser) +): """Get status of active streaming sessions and Redis Streams health. Admin only.""" return await session_controller.get_streaming_status(request) @router.post("/streaming/cleanup") -async def cleanup_stuck_stream_workers(request: Request, current_user: User = Depends(current_superuser)): +async def cleanup_stuck_stream_workers( + request: Request, current_user: User = Depends(current_superuser) +): """Clean up stuck Redis Stream workers and pending messages. Admin only.""" return await queue_controller.cleanup_stuck_stream_workers(request) @router.post("/streaming/cleanup-sessions") -async def cleanup_old_sessions(request: Request, max_age_seconds: int = 3600, current_user: User = Depends(current_superuser)): +async def cleanup_old_sessions( + request: Request, + max_age_seconds: int = 3600, + current_user: User = Depends(current_superuser), +): """Clean up old session tracking metadata. Admin only.""" return await session_controller.cleanup_old_sessions(request, max_age_seconds) # Memory Provider Configuration Endpoints + @router.get("/admin/memory/provider") async def get_memory_provider(current_user: User = Depends(current_superuser)): """Get current memory provider configuration. Admin only.""" @@ -600,7 +633,7 @@ async def get_memory_provider(current_user: User = Depends(current_superuser)): @router.post("/admin/memory/provider") async def set_memory_provider( provider: str = Body(..., embed=True), - current_user: User = Depends(current_superuser) + current_user: User = Depends(current_superuser), ): """Set memory provider and restart backend services. Admin only.""" return await system_controller.set_memory_provider(provider) diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/base.py b/backends/advanced/src/advanced_omi_backend/services/memory/base.py index 9eddddbc..242af2e6 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/base.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/base.py @@ -136,7 +136,9 @@ async def search_memories( pass @abstractmethod - async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryEntry]: + async def get_all_memories( + self, user_id: str, limit: int = 100 + ) -> List[MemoryEntry]: """Get all memories for a specific user. Args: @@ -265,7 +267,10 @@ async def reprocess_memory( @abstractmethod async def delete_memory( - self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None + self, + memory_id: str, + user_id: Optional[str] = None, + user_email: Optional[str] = None, ) -> bool: """Delete a specific memory by ID. @@ -471,7 +476,11 @@ async def add_memories(self, memories: List[MemoryEntry]) -> List[str]: @abstractmethod async def search_memories( - self, query_embedding: List[float], user_id: str, limit: int, score_threshold: float = 0.0 + self, + query_embedding: List[float], + user_id: str, + limit: int, + score_threshold: float = 0.0, ) -> List[MemoryEntry]: """Search memories using vector similarity. diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py index 2363e5a8..87b6c342 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py @@ -69,10 +69,15 @@ async def initialize(self) -> None: try: # Initialize LLM provider - if self.config.llm_provider in [LLMProviderEnum.OPENAI, LLMProviderEnum.OLLAMA]: + if self.config.llm_provider in [ + LLMProviderEnum.OPENAI, + LLMProviderEnum.OLLAMA, + ]: self.llm_provider = OpenAIProvider(self.config.llm_config) else: - raise ValueError(f"Unsupported LLM provider: {self.config.llm_provider}") + raise ValueError( + f"Unsupported LLM provider: {self.config.llm_provider}" + ) # Initialize vector store if self.config.vector_store_provider == VectorStoreProvider.QDRANT: @@ -175,7 +180,9 @@ async def add_memory( memory_logger.debug(f"🧠 fact_memories_text: {fact_memories_text}") # Simple deduplication of extracted memories within the same call fact_memories_text = self._deduplicate_memories(fact_memories_text) - memory_logger.debug(f"🧠 fact_memories_text after deduplication: {fact_memories_text}") + memory_logger.debug( + f"🧠 fact_memories_text after deduplication: {fact_memories_text}" + ) # Generate embeddings embeddings = await asyncio.wait_for( self.llm_provider.generate_embeddings(fact_memories_text), @@ -206,7 +213,12 @@ async def add_memory( memory_logger.info(f"🔍 Not allowing update for {source_id}") # Add all extracted memories normally memory_entries = self._create_memory_entries( - fact_memories_text, embeddings, client_id, source_id, user_id, user_email + fact_memories_text, + embeddings, + client_id, + source_id, + user_id, + user_email, ) # Store new entries in vector database @@ -216,10 +228,14 @@ async def add_memory( # Update database relationships if helper provided if created_ids and db_helper: - await self._update_database_relationships(db_helper, source_id, created_ids) + await self._update_database_relationships( + db_helper, source_id, created_ids + ) if created_ids: - memory_logger.info(f"✅ Upserted {len(created_ids)} memories for {source_id}") + memory_logger.info( + f"✅ Upserted {len(created_ids)} memories for {source_id}" + ) return True, created_ids # No memories created - this is a valid outcome (duplicates, no extractable facts, etc.) @@ -276,7 +292,9 @@ async def search_memories( memory_logger.error(f"Search memories failed: {e}") return [] - async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryEntry]: + async def get_all_memories( + self, user_id: str, limit: int = 100 + ) -> List[MemoryEntry]: """Get all memories for a specific user. Retrieves all stored memories for the given user without @@ -294,7 +312,9 @@ async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryE try: memories = await self.vector_store.get_memories(user_id, limit) - memory_logger.info(f"📚 Retrieved {len(memories)} memories for user {user_id}") + memory_logger.info( + f"📚 Retrieved {len(memories)} memories for user {user_id}" + ) return memories except Exception as e: memory_logger.error(f"Get all memories failed: {e}") @@ -330,7 +350,9 @@ async def get_memories_by_source( await self.initialize() try: - memories = await self.vector_store.get_memories_by_source(user_id, source_id, limit) + memories = await self.vector_store.get_memories_by_source( + user_id, source_id, limit + ) memory_logger.info( f"📚 Retrieved {len(memories)} memories for source {source_id} (user {user_id})" ) @@ -416,7 +438,9 @@ async def update_memory( new_embedding = existing_memory.embedding else: # No existing embedding, generate one - embeddings = await self.llm_provider.generate_embeddings([new_content]) + embeddings = await self.llm_provider.generate_embeddings( + [new_content] + ) new_embedding = embeddings[0] # Update in vector store @@ -435,11 +459,16 @@ async def update_memory( return success except Exception as e: - memory_logger.error(f"Error updating memory {memory_id}: {e}", exc_info=True) + memory_logger.error( + f"Error updating memory {memory_id}: {e}", exc_info=True + ) return False async def delete_memory( - self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None + self, + memory_id: str, + user_id: Optional[str] = None, + user_email: Optional[str] = None, ) -> bool: """Delete a specific memory by ID. @@ -538,7 +567,9 @@ async def reprocess_memory( try: # 1. Get existing memories for this conversation - existing_memories = await self.vector_store.get_memories_by_source(user_id, source_id) + existing_memories = await self.vector_store.get_memories_by_source( + user_id, source_id + ) # 2. If no existing memories, fall back to normal extraction if not existing_memories: @@ -635,9 +666,13 @@ async def reprocess_memory( self.llm_provider.generate_embeddings(texts_needing_embeddings), timeout=self.config.timeout_seconds, ) - text_to_embedding = dict(zip(texts_needing_embeddings, embeddings, strict=True)) + text_to_embedding = dict( + zip(texts_needing_embeddings, embeddings, strict=True) + ) except Exception as e: - memory_logger.warning(f"Batch embedding generation failed for reprocess: {e}") + memory_logger.warning( + f"Batch embedding generation failed for reprocess: {e}" + ) # 8. Apply the actions (reuses existing infrastructure) created_ids = await self._apply_memory_actions( @@ -651,14 +686,17 @@ async def reprocess_memory( ) memory_logger.info( - f"✅ Reprocess complete for {source_id}: " f"{len(created_ids)} memories affected" + f"✅ Reprocess complete for {source_id}: " + f"{len(created_ids)} memories affected" ) return True, created_ids except Exception as e: memory_logger.error(f"❌ Reprocess memory failed for {source_id}: {e}") # Fall back to normal extraction on any unexpected error - memory_logger.info(f"🔄 Falling back to normal extraction after reprocess error") + memory_logger.info( + f"🔄 Falling back to normal extraction after reprocess error" + ) return await self.add_memory( transcript, client_id, @@ -699,7 +737,8 @@ def _format_speaker_diff(transcript_diff: list) -> str: ) elif change_type == "new_segment": lines.append( - f"- New segment: {change.get('speaker', '?')}: " f"\"{change.get('text', '')}\"" + f"- New segment: {change.get('speaker', '?')}: " + f"\"{change.get('text', '')}\"" ) return "\n".join(lines) @@ -834,7 +873,9 @@ async def _process_memory_updates( for mem in candidates: retrieved_old_memory.append({"id": mem.id, "text": mem.content}) except Exception as e_search: - memory_logger.warning(f"Search failed while preparing updates: {e_search}") + memory_logger.warning( + f"Search failed while preparing updates: {e_search}" + ) # Dedupe by id and prepare temp mapping uniq = {} @@ -854,7 +895,9 @@ async def _process_memory_updates( f"🔍 Asking LLM for actions with {len(retrieved_old_memory)} old memories " f"and {len(memories_text)} new facts" ) - memory_logger.debug(f"🧠 Individual facts being sent to LLM: {memories_text}") + memory_logger.debug( + f"🧠 Individual facts being sent to LLM: {memories_text}" + ) # add update or delete etc actions using DEFAULT_UPDATE_MEMORY_PROMPT actions_obj = await self.llm_provider.propose_memory_actions( @@ -862,7 +905,9 @@ async def _process_memory_updates( new_facts=memories_text, custom_prompt=None, ) - memory_logger.info(f"📝 UpdateMemory LLM returned: {type(actions_obj)} - {actions_obj}") + memory_logger.info( + f"📝 UpdateMemory LLM returned: {type(actions_obj)} - {actions_obj}" + ) except Exception as e_actions: memory_logger.error(f"LLM propose_memory_actions failed: {e_actions}") actions_obj = {} @@ -899,7 +944,9 @@ def _normalize_actions(self, actions_obj: Any) -> List[dict]: if isinstance(memory_field, list): actions_list = memory_field elif isinstance(actions_obj.get("facts"), list): - actions_list = [{"event": "ADD", "text": str(t)} for t in actions_obj["facts"]] + actions_list = [ + {"event": "ADD", "text": str(t)} for t in actions_obj["facts"] + ] else: # Pick first list field found for v in actions_obj.values(): @@ -909,7 +956,9 @@ def _normalize_actions(self, actions_obj: Any) -> List[dict]: elif isinstance(actions_obj, list): actions_list = actions_obj - memory_logger.info(f"📋 Normalized to {len(actions_list)} actions: {actions_list}") + memory_logger.info( + f"📋 Normalized to {len(actions_list)} actions: {actions_list}" + ) except Exception as normalize_err: memory_logger.warning(f"Failed to normalize actions: {normalize_err}") actions_list = [] @@ -959,7 +1008,9 @@ async def _apply_memory_actions( memory_logger.warning(f"Skipping action with no text: {resp}") continue - memory_logger.debug(f"Processing action: {event_type} - {action_text[:50]}...") + memory_logger.debug( + f"Processing action: {event_type} - {action_text[:50]}..." + ) base_metadata = { "source": "offline_streaming", @@ -981,7 +1032,9 @@ async def _apply_memory_actions( ) emb = gen[0] if gen else None except Exception as gen_err: - memory_logger.warning(f"Embedding generation failed for action text: {gen_err}") + memory_logger.warning( + f"Embedding generation failed for action text: {gen_err}" + ) emb = None if event_type == "ADD": @@ -1003,7 +1056,9 @@ async def _apply_memory_actions( updated_at=current_time, ) ) - memory_logger.info(f"➕ Added new memory: {memory_id} - {action_text[:50]}...") + memory_logger.info( + f"➕ Added new memory: {memory_id} - {action_text[:50]}..." + ) elif event_type == "UPDATE": provided_id = resp.get("id") @@ -1023,11 +1078,15 @@ async def _apply_memory_actions( f"🔄 Updated memory: {actual_id} - {action_text[:50]}..." ) else: - memory_logger.warning(f"Failed to update memory {actual_id}") + memory_logger.warning( + f"Failed to update memory {actual_id}" + ) except Exception as update_err: memory_logger.error(f"Update memory failed: {update_err}") else: - memory_logger.warning(f"Skipping UPDATE due to missing ID or embedding") + memory_logger.warning( + f"Skipping UPDATE due to missing ID or embedding" + ) elif event_type == "DELETE": provided_id = resp.get("id") @@ -1038,14 +1097,20 @@ async def _apply_memory_actions( if deleted: memory_logger.info(f"🗑️ Deleted memory {actual_id}") else: - memory_logger.warning(f"Failed to delete memory {actual_id}") + memory_logger.warning( + f"Failed to delete memory {actual_id}" + ) except Exception as delete_err: memory_logger.error(f"Delete memory failed: {delete_err}") else: - memory_logger.warning(f"Skipping DELETE due to missing ID: {provided_id}") + memory_logger.warning( + f"Skipping DELETE due to missing ID: {provided_id}" + ) elif event_type == "NONE": - memory_logger.debug(f"NONE action - no changes for: {action_text[:50]}...") + memory_logger.debug( + f"NONE action - no changes for: {action_text[:50]}..." + ) continue else: memory_logger.warning(f"Unknown event type: {event_type}") @@ -1108,7 +1173,9 @@ async def example_usage(): print(f"🔍 Found {len(results)} search results") # Get all memories - all_memories = await memory_service.get_all_memories(user_id="user789", limit=100) + all_memories = await memory_service.get_all_memories( + user_id="user789", limit=100 + ) print(f"📚 Total memories: {len(all_memories)}") # Clean up test data diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py index 4d440fba..27f1430c 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py @@ -147,7 +147,9 @@ def __init__(self, config: Dict[str, Any]): # Ignore provider-specific envs; use registry as single source of truth registry = get_models_registry() if not registry: - raise RuntimeError("config.yml not found or invalid; cannot initialize model registry") + raise RuntimeError( + "config.yml not found or invalid; cannot initialize model registry" + ) self._registry = registry @@ -167,8 +169,12 @@ def __init__(self, config: Dict[str, Any]): self.embedding_model = ( self.embed_def.model_name if self.embed_def else self.llm_def.model_name ) - self.embedding_api_key = self.embed_def.api_key if self.embed_def else self.api_key - self.embedding_base_url = self.embed_def.model_url if self.embed_def else self.base_url + self.embedding_api_key = ( + self.embed_def.api_key if self.embed_def else self.api_key + ) + self.embedding_base_url = ( + self.embed_def.model_url if self.embed_def else self.base_url + ) # CRITICAL: Validate API keys are present - fail fast instead of hanging if not self.api_key or self.api_key.strip() == "": @@ -178,7 +184,9 @@ def __init__(self, config: Dict[str, Any]): f"Cannot proceed without valid API credentials." ) - if self.embed_def and (not self.embedding_api_key or self.embedding_api_key.strip() == ""): + if self.embed_def and ( + not self.embedding_api_key or self.embedding_api_key.strip() == "" + ): raise RuntimeError( f"API key is missing or empty for embedding provider '{self.embed_def.model_provider}' (model: {self.embedding_model}). " f"Please set the API key in config.yml or environment variables." @@ -225,7 +233,9 @@ async def _embed_for_chunking(texts: List[str]) -> List[List[float]]: model=self.embedding_model, ) - chunking_config = self._registry.memory.get("extraction", {}).get("chunking", {}) + chunking_config = self._registry.memory.get("extraction", {}).get( + "chunking", {} + ) dialogue_turns = [line for line in text.split("\n") if line.strip()] text_chunks = await semantic_chunk_text( text, @@ -433,7 +443,9 @@ async def propose_reprocess_actions( else: try: registry = get_prompt_registry() - system_prompt = await registry.get_prompt("memory.reprocess_speaker_update") + system_prompt = await registry.get_prompt( + "memory.reprocess_speaker_update" + ) except Exception as e: memory_logger.debug( f"Registry prompt fetch failed for " @@ -519,12 +531,16 @@ def _parse_memories_content(content: str) -> List[str]: for key in ("facts", "preferences"): value = parsed.get(key) if isinstance(value, list): - collected.extend([str(item).strip() for item in value if str(item).strip()]) + collected.extend( + [str(item).strip() for item in value if str(item).strip()] + ) # If the dict didn't contain expected keys, try to flatten any list values if not collected: for value in parsed.values(): if isinstance(value, list): - collected.extend([str(item).strip() for item in value if str(item).strip()]) + collected.extend( + [str(item).strip() for item in value if str(item).strip()] + ) if collected: return collected except Exception: @@ -559,13 +575,17 @@ def _try_parse_list_or_object(text: str) -> List[str] | None: for key in ("facts", "preferences"): value = data.get(key) if isinstance(value, list): - collected.extend([str(item).strip() for item in value if str(item).strip()]) + collected.extend( + [str(item).strip() for item in value if str(item).strip()] + ) if collected: return collected # As a last attempt, flatten any list values for value in data.values(): if isinstance(value, list): - collected.extend([str(item).strip() for item in value if str(item).strip()]) + collected.extend( + [str(item).strip() for item in value if str(item).strip()] + ) return collected if collected else None except Exception: return None diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py b/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py index 02e5b37c..520e94ad 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py @@ -76,32 +76,120 @@ async def transcribe( # Generate mock words with timestamps (spread across audio duration) words = [ - {"word": "This", "start": 0.0, "end": 0.3, "confidence": 0.99, "speaker": 0}, + { + "word": "This", + "start": 0.0, + "end": 0.3, + "confidence": 0.99, + "speaker": 0, + }, {"word": "is", "start": 0.3, "end": 0.5, "confidence": 0.99, "speaker": 0}, {"word": "a", "start": 0.5, "end": 0.6, "confidence": 0.99, "speaker": 0}, - {"word": "mock", "start": 0.6, "end": 0.9, "confidence": 0.99, "speaker": 0}, - {"word": "transcription", "start": 0.9, "end": 1.5, "confidence": 0.98, "speaker": 0}, + { + "word": "mock", + "start": 0.6, + "end": 0.9, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "transcription", + "start": 0.9, + "end": 1.5, + "confidence": 0.98, + "speaker": 0, + }, {"word": "for", "start": 1.5, "end": 1.7, "confidence": 0.99, "speaker": 0}, - {"word": "testing", "start": 1.7, "end": 2.1, "confidence": 0.99, "speaker": 0}, - {"word": "purposes", "start": 2.1, "end": 2.6, "confidence": 0.97, "speaker": 0}, + { + "word": "testing", + "start": 1.7, + "end": 2.1, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "purposes", + "start": 2.1, + "end": 2.6, + "confidence": 0.97, + "speaker": 0, + }, {"word": "It", "start": 2.6, "end": 2.8, "confidence": 0.99, "speaker": 0}, - {"word": "contains", "start": 2.8, "end": 3.2, "confidence": 0.99, "speaker": 0}, - {"word": "enough", "start": 3.2, "end": 3.5, "confidence": 0.99, "speaker": 0}, - {"word": "words", "start": 3.5, "end": 3.8, "confidence": 0.99, "speaker": 0}, + { + "word": "contains", + "start": 2.8, + "end": 3.2, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "enough", + "start": 3.2, + "end": 3.5, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "words", + "start": 3.5, + "end": 3.8, + "confidence": 0.99, + "speaker": 0, + }, {"word": "to", "start": 3.8, "end": 3.9, "confidence": 0.99, "speaker": 0}, - {"word": "meet", "start": 3.9, "end": 4.1, "confidence": 0.99, "speaker": 0}, - {"word": "minimum", "start": 4.1, "end": 4.5, "confidence": 0.98, "speaker": 0}, - {"word": "length", "start": 4.5, "end": 4.8, "confidence": 0.99, "speaker": 0}, - {"word": "requirements", "start": 4.8, "end": 5.4, "confidence": 0.98, "speaker": 0}, + { + "word": "meet", + "start": 3.9, + "end": 4.1, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "minimum", + "start": 4.1, + "end": 4.5, + "confidence": 0.98, + "speaker": 0, + }, + { + "word": "length", + "start": 4.5, + "end": 4.8, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "requirements", + "start": 4.8, + "end": 5.4, + "confidence": 0.98, + "speaker": 0, + }, {"word": "for", "start": 5.4, "end": 5.6, "confidence": 0.99, "speaker": 0}, - {"word": "automated", "start": 5.6, "end": 6.1, "confidence": 0.98, "speaker": 0}, - {"word": "testing", "start": 6.1, "end": 6.5, "confidence": 0.99, "speaker": 0}, + { + "word": "automated", + "start": 5.6, + "end": 6.1, + "confidence": 0.98, + "speaker": 0, + }, + { + "word": "testing", + "start": 6.1, + "end": 6.5, + "confidence": 0.99, + "speaker": 0, + }, ] # Mock segments (single speaker for simplicity) segments = [{"speaker": 0, "start": 0.0, "end": 6.5, "text": mock_transcript}] - return {"text": mock_transcript, "words": words, "segments": segments if diarize else []} + return { + "text": mock_transcript, + "words": words, + "segments": segments if diarize else [], + } async def connect(self, client_id: Optional[str] = None): """Initialize the mock provider (no-op).""" diff --git a/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py b/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py index 96c52f57..835d52d9 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py @@ -43,7 +43,10 @@ def is_meaningful_speech(combined_results: dict) -> bool: if not combined_results.get("text"): return False - transcript_data = {"text": combined_results["text"], "words": combined_results.get("words", [])} + transcript_data = { + "text": combined_results["text"], + "words": combined_results.get("words", []), + } speech_analysis = analyze_speech(transcript_data) return speech_analysis["has_speech"] @@ -83,19 +86,25 @@ def analyze_speech(transcript_data: dict) -> dict: settings = get_speech_detection_settings() words = transcript_data.get("words", []) - logger.info(f"🔬 analyze_speech: words_list_length={len(words)}, settings={settings}") + logger.info( + f"🔬 analyze_speech: words_list_length={len(words)}, settings={settings}" + ) if words and len(words) > 0: logger.info(f"📝 First 3 words: {words[:3]}") # Method 1: Word-level analysis (preferred - has confidence scores and timing) if words: # Filter by confidence threshold - valid_words = [w for w in words if (w.get("confidence") or 0) >= settings["min_confidence"]] + valid_words = [ + w for w in words if (w.get("confidence") or 0) >= settings["min_confidence"] + ] if len(valid_words) < settings["min_words"]: # Not enough valid words in word-level data - fall through to text-only analysis # This handles cases where word-level data is incomplete or low confidence - logger.debug(f"Only {len(valid_words)} valid words, falling back to text-only analysis") + logger.debug( + f"Only {len(valid_words)} valid words, falling back to text-only analysis" + ) # Continue to Method 2 (don't return early) else: # Calculate speech duration from word timing @@ -113,7 +122,9 @@ def analyze_speech(transcript_data: dict) -> dict: # If no timing data (duration = 0), fall back to text-only analysis # This happens with some streaming transcription services if speech_duration == 0: - logger.debug("Word timing data missing, falling back to text-only analysis") + logger.debug( + "Word timing data missing, falling back to text-only analysis" + ) # Continue to Method 2 (text-only fallback) else: # Check minimum duration threshold when we have timing data @@ -245,7 +256,9 @@ async def generate_title_and_summary( # Fallback words = text.split()[:6] fallback_title = " ".join(words) - fallback_title = fallback_title[:40] + "..." if len(fallback_title) > 40 else fallback_title + fallback_title = ( + fallback_title[:40] + "..." if len(fallback_title) > 40 else fallback_title + ) fallback_summary = text[:120] + "..." if len(text) > 120 else text return fallback_title or "Conversation", fallback_summary or "No content" @@ -330,7 +343,10 @@ async def generate_detailed_summary( """ summary = await async_generate(prompt, operation="detailed_summary") - return summary.strip().strip('"').strip("'") or "No meaningful content to summarize" + return ( + summary.strip().strip('"').strip("'") + or "No meaningful content to summarize" + ) except Exception as e: logger.warning(f"Failed to generate detailed summary: {e}") @@ -373,7 +389,10 @@ def extract_speakers_from_segments(segments: list) -> List[str]: async def track_speech_activity( - speech_analysis: Dict[str, Any], last_word_count: int, conversation_id: str, redis_client + speech_analysis: Dict[str, Any], + last_word_count: int, + conversation_id: str, + redis_client, ) -> tuple[float, int]: """ Track new speech activity and update last speech timestamp using audio timestamps. @@ -477,7 +496,9 @@ async def update_job_progress_metadata( "conversation_id": conversation_id, "client_id": client_id, # Ensure client_id is always present "transcript": ( - combined["text"][:500] + "..." if len(combined["text"]) > 500 else combined["text"] + combined["text"][:500] + "..." + if len(combined["text"]) > 500 + else combined["text"] ), # First 500 chars "transcript_length": len(combined["text"]), "speakers": speakers, @@ -508,7 +529,9 @@ async def mark_conversation_deleted(conversation_id: str, deletion_reason: str) f"🗑️ Marking conversation {conversation_id} as deleted - reason: {deletion_reason}" ) - conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + conversation = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) if conversation: conversation.deleted = True conversation.deletion_reason = deletion_reason diff --git a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py index ba2a4ee0..e818a2d8 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/conversation_jobs.py @@ -109,7 +109,9 @@ async def handle_end_of_conversation( from advanced_omi_backend.models.conversation import Conversation - conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + conversation = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) if conversation: # Convert string to enum try: @@ -124,7 +126,9 @@ async def handle_end_of_conversation( f"💾 Saved conversation {conversation_id[:12]} end_reason: {conversation.end_reason}" ) else: - logger.warning(f"⚠️ Conversation {conversation_id} not found for end reason tracking") + logger.warning( + f"⚠️ Conversation {conversation_id} not found for end reason tracking" + ) # Increment conversation count for this session conversation_count_key = f"session:conversation_count:{session_id}" @@ -140,7 +144,9 @@ async def handle_end_of_conversation( ) if status_raw: - status_str = status_raw.decode() if isinstance(status_raw, bytes) else status_raw + status_str = ( + status_raw.decode() if isinstance(status_raw, bytes) else status_raw + ) ws_connected = ( ws_connected_raw.decode() if isinstance(ws_connected_raw, bytes) @@ -264,7 +270,9 @@ def _validate_segments(segments: list) -> list: start = seg.get("start", 0.0) end = seg.get("end", 0.0) if end <= start: - logger.debug(f"Segment {i} has invalid timing (start={start}, end={end}), correcting") + logger.debug( + f"Segment {i} has invalid timing (start={start}, end={end}), correcting" + ) estimated_duration = len(text.split()) * 0.5 # ~0.5 seconds per word seg["end"] = start + estimated_duration @@ -313,7 +321,9 @@ async def _initialize_conversation( conversation = None if existing_conversation_id_bytes: existing_conversation_id = existing_conversation_id_bytes.decode() - logger.info(f"🔍 Found Redis key with conversation_id={existing_conversation_id}") + logger.info( + f"🔍 Found Redis key with conversation_id={existing_conversation_id}" + ) # Try to fetch the existing conversation by conversation_id conversation = await Conversation.find_one( @@ -328,13 +338,16 @@ async def _initialize_conversation( f"processing_status={processing_status}" ) else: - logger.warning(f"⚠️ Conversation {existing_conversation_id} not found in database!") + logger.warning( + f"⚠️ Conversation {existing_conversation_id} not found in database!" + ) # Verify it's a placeholder conversation (always_persist=True, processing_status='pending_transcription') if ( conversation and getattr(conversation, "always_persist", False) - and getattr(conversation, "processing_status", None) == "pending_transcription" + and getattr(conversation, "processing_status", None) + == "pending_transcription" ): logger.info( f"🔄 Reusing placeholder conversation {conversation.conversation_id} for session {session_id}" @@ -353,7 +366,9 @@ async def _initialize_conversation( ) conversation = None else: - logger.info(f"🔍 No Redis key found for {conversation_key}, creating new conversation") + logger.info( + f"🔍 No Redis key found for {conversation_key}, creating new conversation" + ) # If no valid placeholder found, create new conversation if not conversation: @@ -365,14 +380,18 @@ async def _initialize_conversation( ) await conversation.insert() conversation_id = conversation.conversation_id - logger.info(f"✅ Created streaming conversation {conversation_id} for session {session_id}") + logger.info( + f"✅ Created streaming conversation {conversation_id} for session {session_id}" + ) # Attach markers from Redis session (e.g., button events captured during streaming) session_key = f"audio:session:{session_id}" markers_json = await redis_client.hget(session_key, "markers") if markers_json: try: - markers_data = markers_json if isinstance(markers_json, str) else markers_json.decode() + markers_data = ( + markers_json if isinstance(markers_json, str) else markers_json.decode() + ) conversation.markers = json.loads(markers_data) await conversation.save() logger.info( @@ -392,7 +411,9 @@ async def _initialize_conversation( speaker_check_job_id = speech_job.meta.get("speaker_check_job_id") if speaker_check_job_id: try: - speaker_check_job = Job.fetch(speaker_check_job_id, connection=redis_conn) + speaker_check_job = Job.fetch( + speaker_check_job_id, connection=redis_conn + ) speaker_check_job.meta["conversation_id"] = conversation_id speaker_check_job.save_meta() except Exception as e: @@ -416,7 +437,9 @@ async def _initialize_conversation( # Signal audio persistence job to rotate to this conversation's file rotation_signal_key = f"conversation:current:{session_id}" - await redis_client.set(rotation_signal_key, conversation_id, ex=86400) # 24 hour TTL + await redis_client.set( + rotation_signal_key, conversation_id, ex=86400 + ) # 24 hour TTL logger.info( f"🔄 Signaled audio persistence to rotate file for conversation {conversation_id[:12]}" ) @@ -445,12 +468,16 @@ async def _monitor_conversation_loop( close_requested_reason, last_result_count, and last_word_count. """ session_key = f"audio:session:{state.session_id}" - max_runtime = 10740 # 3 hours - 60 seconds (single conversations shouldn't exceed 3 hours) + max_runtime = ( + 10740 # 3 hours - 60 seconds (single conversations shouldn't exceed 3 hours) + ) finalize_received = False # Inactivity timeout configuration - inactivity_timeout_seconds = float(os.getenv("SPEECH_INACTIVITY_THRESHOLD_SECONDS", "60")) + inactivity_timeout_seconds = float( + os.getenv("SPEECH_INACTIVITY_THRESHOLD_SECONDS", "60") + ) inactivity_timeout_minutes = inactivity_timeout_seconds / 60 last_inactivity_log_time = ( time.time() @@ -458,7 +485,9 @@ async def _monitor_conversation_loop( # Test mode: wait for audio queue to drain before timing out # In real usage, ambient noise keeps connection alive. In tests, chunks arrive in bursts. - wait_for_queue_drain = os.getenv("WAIT_FOR_AUDIO_QUEUE_DRAIN", "false").lower() == "true" + wait_for_queue_drain = ( + os.getenv("WAIT_FOR_AUDIO_QUEUE_DRAIN", "false").lower() == "true" + ) logger.info( f"📊 Conversation timeout configured: {inactivity_timeout_minutes} minutes ({inactivity_timeout_seconds}s)" @@ -508,7 +537,9 @@ async def _monitor_conversation_loop( f"🔌 WebSocket disconnected for session {state.session_id[:12]} - " f"ending conversation early" ) - state.timeout_triggered = False # This is a disconnect, not a timeout + state.timeout_triggered = ( + False # This is a disconnect, not a timeout + ) else: logger.info( f"🛑 Session finalizing (reason: {completion_reason_str}), " @@ -518,16 +549,20 @@ async def _monitor_conversation_loop( # Check for conversation close request (set by API, plugins, button press) if not finalize_received: - close_reason = await redis_client.hget(session_key, "conversation_close_requested") + close_reason = await redis_client.hget( + session_key, "conversation_close_requested" + ) if close_reason: await redis_client.hdel(session_key, "conversation_close_requested") state.close_requested_reason = ( - close_reason.decode() if isinstance(close_reason, bytes) else close_reason + close_reason.decode() + if isinstance(close_reason, bytes) + else close_reason ) - logger.info(f"🔒 Conversation close requested: {state.close_requested_reason}") - state.timeout_triggered = ( - True # Session stays active (same restart behavior as inactivity timeout) + logger.info( + f"🔒 Conversation close requested: {state.close_requested_reason}" ) + state.timeout_triggered = True # Session stays active (same restart behavior as inactivity timeout) finalize_received = True break @@ -586,7 +621,9 @@ async def _monitor_conversation_loop( # Can't reliably detect inactivity, so skip timeout check this iteration inactivity_duration = 0 if speech_analysis.get("fallback", False): - logger.debug("⚠️ Skipping inactivity check (no audio timestamps available)") + logger.debug( + "⚠️ Skipping inactivity check (no audio timestamps available)" + ) current_time = time.time() @@ -698,11 +735,15 @@ async def _save_streaming_transcript( """ from advanced_omi_backend.models.conversation import Conversation - logger.info(f"📝 Retrieving final streaming transcript for conversation {conversation_id[:12]}") + logger.info( + f"📝 Retrieving final streaming transcript for conversation {conversation_id[:12]}" + ) final_transcript = await aggregator.get_combined_results(session_id) # Fetch conversation from database to ensure we have latest state - conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + conversation = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) if not conversation: logger.error(f"❌ Conversation {conversation_id} not found in database") raise ValueError(f"Conversation {conversation_id} not found") @@ -976,7 +1017,9 @@ async def open_conversation_job( logger.info(f"📊 Using completion_reason from session: {state.end_reason}") elif state.close_requested_reason: state.end_reason = "close_requested" - logger.info(f"📊 Conversation closed by request: {state.close_requested_reason}") + logger.info( + f"📊 Conversation closed by request: {state.close_requested_reason}" + ) elif state.timeout_triggered: state.end_reason = "inactivity_timeout" elif time.time() - state.start_time > 10740: @@ -984,7 +1027,9 @@ async def open_conversation_job( else: state.end_reason = "user_stopped" - logger.info(f"📊 Conversation {conversation_id[:12]} end_reason determined: {state.end_reason}") + logger.info( + f"📊 Conversation {conversation_id[:12]} end_reason determined: {state.end_reason}" + ) # Phase 4-7: Post-processing (wrapped in try/finally for guaranteed cleanup) end_of_conversation_handled = False @@ -1054,7 +1099,9 @@ async def open_conversation_job( end_reason=state.end_reason, ) - logger.info(f"📦 MongoDB audio chunks ready for conversation {conversation_id[:12]}") + logger.info( + f"📦 MongoDB audio chunks ready for conversation {conversation_id[:12]}" + ) # Phase 6: Save streaming transcript version_id = await _save_streaming_transcript( @@ -1108,7 +1155,9 @@ async def open_conversation_job( @async_job(redis=True, beanie=True) -async def generate_title_summary_job(conversation_id: str, *, redis_client=None) -> Dict[str, Any]: +async def generate_title_summary_job( + conversation_id: str, *, redis_client=None +) -> Dict[str, Any]: """ Generate title, short summary, and detailed summary for a conversation using LLM. @@ -1132,12 +1181,16 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) ) set_otel_session(conversation_id) - logger.info(f"📝 Starting title/summary generation for conversation {conversation_id}") + logger.info( + f"📝 Starting title/summary generation for conversation {conversation_id}" + ) start_time = time.time() # Get the conversation - conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + conversation = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) if not conversation: logger.error(f"Conversation {conversation_id} not found") return {"success": False, "error": "Conversation not found"} @@ -1147,7 +1200,9 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) segments = conversation.segments or [] if not transcript_text and (not segments or len(segments) == 0): - logger.warning(f"⚠️ No transcript or segments available for conversation {conversation_id}") + logger.warning( + f"⚠️ No transcript or segments available for conversation {conversation_id}" + ) return { "success": False, "error": "No transcript or segments available", @@ -1179,7 +1234,9 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) else: logger.info(f"📚 No memories found for context enrichment") except Exception as mem_error: - logger.warning(f"⚠️ Could not fetch memory context (continuing without): {mem_error}") + logger.warning( + f"⚠️ Could not fetch memory context (continuing without): {mem_error}" + ) # Generate title+summary (one call) and detailed summary in parallel import asyncio @@ -1203,7 +1260,9 @@ async def generate_title_summary_job(conversation_id: str, *, redis_client=None) logger.info(f"✅ Generated title: '{conversation.title}'") logger.info(f"✅ Generated summary: '{conversation.summary}'") - logger.info(f"✅ Generated detailed summary: {len(conversation.detailed_summary)} chars") + logger.info( + f"✅ Generated detailed summary: {len(conversation.detailed_summary)} chars" + ) # Update processing status for placeholder/reprocessing conversations if getattr(conversation, "processing_status", None) in [ @@ -1300,12 +1359,16 @@ async def dispatch_conversation_complete_event_job( """ from advanced_omi_backend.models.conversation import Conversation - logger.info(f"📌 Dispatching conversation.complete event for conversation {conversation_id}") + logger.info( + f"📌 Dispatching conversation.complete event for conversation {conversation_id}" + ) start_time = time.time() # Get the conversation to include in event data - conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + conversation = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) if not conversation: logger.error(f"Conversation {conversation_id} not found") return {"success": False, "error": "Conversation not found"} diff --git a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py index 6cbf5af3..ae43c9c2 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/transcription_jobs.py @@ -79,7 +79,9 @@ async def apply_speaker_recognition( speaker_client = SpeakerRecognitionClient() if not speaker_client.enabled: - logger.info(f"🎤 Speaker recognition disabled, using original speaker labels") + logger.info( + f"🎤 Speaker recognition disabled, using original speaker labels" + ) return segments logger.info( @@ -120,7 +122,9 @@ def get_speaker_at_time(timestamp: float, speaker_segments: list) -> str: updated_count = 0 for seg in segments: seg_mid = (seg.start + seg.end) / 2.0 - identified_speaker = get_speaker_at_time(seg_mid, speaker_identified_segments) + identified_speaker = get_speaker_at_time( + seg_mid, speaker_identified_segments + ) if identified_speaker and identified_speaker != "Unknown": original_speaker = seg.speaker @@ -183,7 +187,9 @@ async def transcribe_full_audio_job( start_time = time.time() # Get the conversation - conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + conversation = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) if not conversation: raise ValueError(f"Conversation {conversation_id} not found") @@ -200,18 +206,23 @@ async def transcribe_full_audio_job( logger.info(f"Using transcription provider: {provider_name}") # Reconstruct audio from MongoDB chunks - logger.info(f"📦 Reconstructing audio from MongoDB chunks for conversation {conversation_id}") + logger.info( + f"📦 Reconstructing audio from MongoDB chunks for conversation {conversation_id}" + ) try: # Reconstruct WAV from MongoDB chunks (already in memory as bytes) wav_data = await reconstruct_wav_from_conversation(conversation_id) logger.info( - f"📦 Reconstructed audio from MongoDB chunks: " f"{len(wav_data) / 1024 / 1024:.2f} MB" + f"📦 Reconstructed audio from MongoDB chunks: " + f"{len(wav_data) / 1024 / 1024:.2f} MB" ) except ValueError as e: # No chunks found for conversation - raise FileNotFoundError(f"No audio chunks found for conversation {conversation_id}: {e}") + raise FileNotFoundError( + f"No audio chunks found for conversation {conversation_id}: {e}" + ) except Exception as e: logger.error(f"Failed to reconstruct audio from MongoDB: {e}", exc_info=True) raise RuntimeError(f"Audio reconstruction failed: {e}") @@ -294,7 +305,9 @@ def _on_batch_progress(event: dict) -> None: description=f"conversation={conversation_id[:12]}, words={len(words)}", ) except Exception as e: - logger.exception(f"⚠️ Error triggering transcript plugins in batch mode: {e}") + logger.exception( + f"⚠️ Error triggering transcript plugins in batch mode: {e}" + ) logger.info(f"🔍 DEBUG: Plugin processing complete, moving to speech validation") @@ -349,7 +362,9 @@ def _on_batch_progress(event: dict) -> None: f"Job {job_id} hash not found (likely already completed or expired)" ) else: - logger.debug(f"Job {job_id} not found or already completed: {e}") + logger.debug( + f"Job {job_id} not found or already completed: {e}" + ) if cancelled_jobs: logger.info( @@ -579,7 +594,9 @@ async def create_audio_only_conversation( # Update status to show batch transcription is starting placeholder_conversation.processing_status = "batch_transcription" placeholder_conversation.title = "Audio Recording (Batch Transcription...)" - placeholder_conversation.summary = "Processing audio with offline transcription..." + placeholder_conversation.summary = ( + "Processing audio with offline transcription..." + ) await placeholder_conversation.save() # Audio chunks are already linked to this conversation_id @@ -606,7 +623,9 @@ async def create_audio_only_conversation( ) await conversation.insert() - logger.info(f"✅ Created batch transcription conversation {session_id[:12]} for fallback") + logger.info( + f"✅ Created batch transcription conversation {session_id[:12]} for fallback" + ) return conversation @@ -752,14 +771,18 @@ async def transcription_fallback_check_job( sample_rate, channels, sample_width = 16000, 1, 2 session_key = f"audio:session:{session_id}" try: - audio_format_raw = await redis_client.hget(session_key, "audio_format") + audio_format_raw = await redis_client.hget( + session_key, "audio_format" + ) if audio_format_raw: audio_format = json.loads(audio_format_raw) sample_rate = int(audio_format.get("rate", 16000)) channels = int(audio_format.get("channels", 1)) sample_width = int(audio_format.get("width", 2)) except Exception as e: - logger.warning(f"Failed to read audio_format from Redis for {session_id}: {e}") + logger.warning( + f"Failed to read audio_format from Redis for {session_id}: {e}" + ) bytes_per_second = sample_rate * channels * sample_width logger.info( @@ -768,7 +791,9 @@ async def transcription_fallback_check_job( ) # Create conversation placeholder - conversation = await create_audio_only_conversation(session_id, user_id, client_id) + conversation = await create_audio_only_conversation( + session_id, user_id, client_id + ) # Save audio to MongoDB chunks for batch transcription num_chunks = await convert_audio_to_chunks( @@ -785,7 +810,9 @@ async def transcription_fallback_check_job( ) except Exception as e: - logger.error(f"❌ Failed to extract audio from Redis stream: {e}", exc_info=True) + logger.error( + f"❌ Failed to extract audio from Redis stream: {e}", exc_info=True + ) raise else: logger.info( @@ -794,7 +821,9 @@ async def transcription_fallback_check_job( ) # Create conversation placeholder for batch transcription - conversation = await create_audio_only_conversation(session_id, user_id, client_id) + conversation = await create_audio_only_conversation( + session_id, user_id, client_id + ) # Enqueue batch transcription job version_id = f"batch_fallback_{session_id[:12]}" @@ -890,10 +919,14 @@ async def stream_speech_detection_job( # Get conversation count conversation_count_key = f"session:conversation_count:{session_id}" conversation_count_bytes = await redis_client.get(conversation_count_key) - conversation_count = int(conversation_count_bytes) if conversation_count_bytes else 0 + conversation_count = ( + int(conversation_count_bytes) if conversation_count_bytes else 0 + ) # Check if speaker filtering is enabled - speaker_filter_enabled = os.getenv("RECORD_ONLY_ENROLLED_SPEAKERS", "false").lower() == "true" + speaker_filter_enabled = ( + os.getenv("RECORD_ONLY_ENROLLED_SPEAKERS", "false").lower() == "true" + ) logger.info( f"📊 Conversation #{conversation_count + 1}, Speaker filter: {'enabled' if speaker_filter_enabled else 'disabled'}" ) @@ -936,17 +969,24 @@ async def stream_speech_detection_job( ) # Exit if grace period expired without speech - if session_closed_at and (time.time() - session_closed_at) > final_check_grace_period: + if ( + session_closed_at + and (time.time() - session_closed_at) > final_check_grace_period + ): logger.info(f"✅ Session ended without speech (grace period expired)") break # Consume any stale conversation close request (defensive — shouldn't normally # appear since services.py gates on conversation:current, but handles race conditions) - close_reason = await redis_client.hget(session_key, "conversation_close_requested") + close_reason = await redis_client.hget( + session_key, "conversation_close_requested" + ) if close_reason: await redis_client.hdel(session_key, "conversation_close_requested") close_reason_str = ( - close_reason.decode() if isinstance(close_reason, bytes) else close_reason + close_reason.decode() + if isinstance(close_reason, bytes) + else close_reason ) logger.info( f"🔒 Conversation close requested ({close_reason_str}) during speech detection — " @@ -963,11 +1003,15 @@ async def stream_speech_detection_job( # Health check: detect transcription errors early during grace period if session_closed_at: # Check for streaming consumer errors in session metadata - error_status = await redis_client.hget(session_key, "transcription_error") + error_status = await redis_client.hget( + session_key, "transcription_error" + ) if error_status: error_msg = error_status.decode() logger.error(f"❌ Transcription service error: {error_msg}") - logger.error(f"❌ Session failed - transcription service unavailable") + logger.error( + f"❌ Session failed - transcription service unavailable" + ) break # Check if we've been waiting too long with no results at all @@ -977,7 +1021,9 @@ async def stream_speech_detection_job( logger.error( f"❌ No transcription activity after {grace_elapsed:.1f}s - possible API key or connectivity issue" ) - logger.error(f"❌ Session failed - check transcription service configuration") + logger.error( + f"❌ Session failed - check transcription service configuration" + ) break await asyncio.sleep(2) @@ -1017,7 +1063,9 @@ async def stream_speech_detection_job( "last_event", f"speech_detected:{datetime.utcnow().isoformat()}", ) - await redis_client.hset(session_key, "speech_detected_at", datetime.utcnow().isoformat()) + await redis_client.hset( + session_key, "speech_detected_at", datetime.utcnow().isoformat() + ) # Step 2: If speaker filter enabled, check for enrolled speakers identified_speakers = [] @@ -1069,7 +1117,9 @@ async def stream_speech_detection_job( result = speaker_check_job.result enrolled_present = result.get("enrolled_present", False) identified_speakers = result.get("identified_speakers", []) - logger.info(f"✅ Speaker check completed: enrolled={enrolled_present}") + logger.info( + f"✅ Speaker check completed: enrolled={enrolled_present}" + ) # Update session event for speaker check complete await redis_client.hset( @@ -1098,7 +1148,9 @@ async def stream_speech_detection_job( "last_event", f"speaker_check_failed:{datetime.utcnow().isoformat()}", ) - await redis_client.hset(session_key, "speaker_check_status", "failed") + await redis_client.hset( + session_key, "speaker_check_status", "failed" + ) break await asyncio.sleep(poll_interval) waited += poll_interval @@ -1151,7 +1203,9 @@ async def stream_speech_detection_job( ) # Track the job - await redis_client.set(open_job_key, open_job.id, ex=10800) # 3 hours to match job timeout + await redis_client.set( + open_job_key, open_job.id, ex=10800 + ) # 3 hours to match job timeout # Store metadata in speech detection job if current_job: @@ -1164,23 +1218,31 @@ async def stream_speech_detection_job( current_job.meta.update( { "conversation_job_id": open_job.id, - "speaker_check_job_id": (speaker_check_job.id if speaker_check_job else None), + "speaker_check_job_id": ( + speaker_check_job.id if speaker_check_job else None + ), "detected_speakers": identified_speakers, - "speech_detected_at": datetime.fromtimestamp(speech_detected_at).isoformat(), + "speech_detected_at": datetime.fromtimestamp( + speech_detected_at + ).isoformat(), "session_id": session_id, "client_id": client_id, # For job grouping } ) current_job.save_meta() - logger.info(f"✅ Started conversation job {open_job.id}, exiting speech detection") + logger.info( + f"✅ Started conversation job {open_job.id}, exiting speech detection" + ) return { "session_id": session_id, "user_id": user_id, "client_id": client_id, "conversation_job_id": open_job.id, - "speech_detected_at": datetime.fromtimestamp(speech_detected_at).isoformat(), + "speech_detected_at": datetime.fromtimestamp( + speech_detected_at + ).isoformat(), "runtime_seconds": time.time() - start_time, } @@ -1208,7 +1270,9 @@ async def stream_speech_detection_job( # Check if this is an always_persist conversation that needs to be marked as failed # NOTE: We check MongoDB directly because the conversation:current Redis key might have been # deleted by the audio persistence job cleanup (which runs in parallel). - logger.info(f"🔍 Checking MongoDB for always_persist conversation with client_id: {client_id}") + logger.info( + f"🔍 Checking MongoDB for always_persist conversation with client_id: {client_id}" + ) # Find conversation by client_id that matches this session # session_id == client_id for streaming sessions (set in _initialize_streaming_session) diff --git a/config_manager.py b/config_manager.py index 6d85bba7..21d24b00 100644 --- a/config_manager.py +++ b/config_manager.py @@ -94,6 +94,26 @@ def _detect_service_path(self) -> Optional[str]: logger.debug("Could not auto-detect service path from cwd") return None + def ensure_config_yml(self) -> None: + """Create config.yml from template if it doesn't exist. + + Raises: + RuntimeError: If config.yml doesn't exist and template is not found. + """ + if self.config_yml_path.exists(): + return + + template_path = self.config_yml_path.parent / "config.yml.template" + if not template_path.exists(): + raise RuntimeError( + f"config.yml.template not found at {template_path}. " + "Cannot create config.yml." + ) + + self.config_yml_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(template_path, self.config_yml_path) + logger.info(f"Created {self.config_yml_path} from template") + def _load_config_yml(self) -> Dict[str, Any]: """Load config.yml file.""" if not self.config_yml_path.exists(): diff --git a/extras/asr-services/common/base_service.py b/extras/asr-services/common/base_service.py index 2b81df3f..ff1222ca 100644 --- a/extras/asr-services/common/base_service.py +++ b/extras/asr-services/common/base_service.py @@ -13,11 +13,7 @@ from abc import ABC, abstractmethod from typing import Optional -from common.response_models import ( - HealthResponse, - InfoResponse, - TranscriptionResult, -) +from common.response_models import HealthResponse, InfoResponse, TranscriptionResult from fastapi import FastAPI, File, Form, HTTPException, UploadFile from fastapi.responses import JSONResponse, StreamingResponse @@ -236,14 +232,17 @@ def _ndjson_generator(): """Wrap sync generator as NDJSON lines, clean up temp file when done.""" try: for event in service.transcribe_with_progress( - tmp_filename, context_info=context_info, + tmp_filename, + context_info=context_info, ): yield json.dumps(event) + "\n" finally: try: os.unlink(tmp_filename) except Exception as e: - logger.warning(f"Failed to delete temp file {tmp_filename}: {e}") + logger.warning( + f"Failed to delete temp file {tmp_filename}: {e}" + ) return StreamingResponse( _ndjson_generator(), diff --git a/extras/asr-services/providers/vibevoice/service.py b/extras/asr-services/providers/vibevoice/service.py index 085af599..5779ef86 100644 --- a/extras/asr-services/providers/vibevoice/service.py +++ b/extras/asr-services/providers/vibevoice/service.py @@ -117,7 +117,8 @@ def transcribe_with_progress(self, audio_file_path: str, context_info=None): if self.transcriber is None: raise RuntimeError("Service not initialized") yield from self.transcriber._transcribe_batched_with_progress( - audio_file_path, hotwords=context_info, + audio_file_path, + hotwords=context_info, ) @@ -149,12 +150,16 @@ def _run_lora_training( finetune_script = vibevoice_dir / "finetuning-asr" / "lora_finetune.py" if not finetune_script.exists(): - raise FileNotFoundError(f"VibeVoice LoRA fine-tuning script not found at {finetune_script}") + raise FileNotFoundError( + f"VibeVoice LoRA fine-tuning script not found at {finetune_script}" + ) # Use importlib to load the script as a module import importlib.util - spec = importlib.util.spec_from_file_location("lora_finetune", str(finetune_script)) + spec = importlib.util.spec_from_file_location( + "lora_finetune", str(finetune_script) + ) lora_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(lora_module) @@ -172,12 +177,16 @@ def _run_lora_training( _finetune_state["status"] = "completed" _finetune_state["progress"] = "done" _finetune_state["last_completed_job_id"] = job_id - logger.info(f"LoRA fine-tuning completed: job_id={job_id}, adapter at {adapter_output_dir}") + logger.info( + f"LoRA fine-tuning completed: job_id={job_id}, adapter at {adapter_output_dir}" + ) except Exception as e: _finetune_state["status"] = "failed" _finetune_state["error"] = str(e) - logger.error(f"LoRA fine-tuning failed: job_id={job_id}, error={e}", exc_info=True) + logger.error( + f"LoRA fine-tuning failed: job_id={job_id}, error={e}", exc_info=True + ) def add_finetune_routes(app, service: VibeVoiceService) -> None: @@ -228,12 +237,14 @@ async def start_finetune( adapter_output_dir = str(adapter_base_dir / "latest") # Update state and launch training in background thread - _finetune_state.update({ - "status": "training", - "job_id": job_id, - "progress": "queued", - "error": None, - }) + _finetune_state.update( + { + "status": "training", + "job_id": job_id, + "progress": "queued", + "error": None, + } + ) loop = asyncio.get_event_loop() loop.run_in_executor( @@ -247,11 +258,13 @@ async def start_finetune( job_id, ) - return JSONResponse(content={ - "job_id": job_id, - "status": "training_started", - "adapter_output_dir": adapter_output_dir, - }) + return JSONResponse( + content={ + "job_id": job_id, + "status": "training_started", + "adapter_output_dir": adapter_output_dir, + } + ) @app.get("/fine-tune/status") async def finetune_status(): @@ -279,10 +292,12 @@ async def reload_adapter(adapter_path: Optional[str] = Form(None)): service.transcriber.load_lora_adapter, path, ) - return JSONResponse(content={ - "status": "adapter_loaded", - "adapter_path": path, - }) + return JSONResponse( + content={ + "status": "adapter_loaded", + "adapter_path": path, + } + ) except Exception as e: logger.error(f"Failed to reload adapter: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Failed to load adapter: {e}") diff --git a/extras/friend-lite-sdk/friend_lite/bluetooth.py b/extras/friend-lite-sdk/friend_lite/bluetooth.py index 20f8d374..30ce69b3 100644 --- a/extras/friend-lite-sdk/friend_lite/bluetooth.py +++ b/extras/friend-lite-sdk/friend_lite/bluetooth.py @@ -80,6 +80,7 @@ async def subscribe_battery(self, callback: Callable[[int], None]) -> None: *callback* receives a single int (0-100) each time the device reports an updated level. """ + def _on_notify(_sender: int, data: bytearray) -> None: if data: callback(data[0]) @@ -91,7 +92,9 @@ def _on_notify(_sender: int, data: bytearray) -> None: async def subscribe_audio(self, callback: Callable[[int, bytearray], None]) -> None: await self.subscribe(OMI_AUDIO_CHAR_UUID, callback) - async def subscribe(self, uuid: str, callback: Callable[[int, bytearray], None]) -> None: + async def subscribe( + self, uuid: str, callback: Callable[[int, bytearray], None] + ) -> None: if self._client is None: raise RuntimeError("Not connected to device") await self._client.start_notify(uuid, callback) @@ -106,7 +109,9 @@ async def wait_until_disconnected(self, timeout: float | None = None) -> None: class OmiConnection(WearableConnection): """OMI device with button and WiFi sync support.""" - async def subscribe_button(self, callback: Callable[[int, bytearray], None]) -> None: + async def subscribe_button( + self, callback: Callable[[int, bytearray], None] + ) -> None: await self.subscribe(OMI_BUTTON_CHAR_UUID, callback) # -- Haptic ------------------------------------------------------------ @@ -120,7 +125,9 @@ async def play_haptic(self, pattern: int = 1) -> None: raise RuntimeError("Not connected to device") if pattern not in (1, 2, 3): raise ValueError("pattern must be 1 (100ms), 2 (300ms), or 3 (500ms)") - await self._client.write_gatt_char(HAPTIC_CHAR_UUID, bytes([pattern]), response=True) + await self._client.write_gatt_char( + HAPTIC_CHAR_UUID, bytes([pattern]), response=True + ) async def is_haptic_supported(self) -> bool: """Check whether the device has a haptic motor.""" @@ -161,7 +168,9 @@ def _on_notify(_sender: int, data: bytearray) -> None: await self._client.start_notify(STORAGE_WIFI_CHAR_UUID, _on_notify) try: - await self._client.write_gatt_char(STORAGE_WIFI_CHAR_UUID, payload, response=True) + await self._client.write_gatt_char( + STORAGE_WIFI_CHAR_UUID, payload, response=True + ) await asyncio.wait_for(response_event.wait(), timeout=timeout) finally: await self._client.stop_notify(STORAGE_WIFI_CHAR_UUID) @@ -172,7 +181,12 @@ async def setup_wifi(self, ssid: str, password: str) -> int: """Send WiFi AP credentials to device. Returns response code (0=success).""" ssid_bytes = ssid.encode("utf-8") pwd_bytes = password.encode("utf-8") - payload = bytes([0x01, len(ssid_bytes)]) + ssid_bytes + bytes([len(pwd_bytes)]) + pwd_bytes + payload = ( + bytes([0x01, len(ssid_bytes)]) + + ssid_bytes + + bytes([len(pwd_bytes)]) + + pwd_bytes + ) return await self._wifi_command(payload) async def start_wifi(self) -> int: @@ -204,9 +218,13 @@ async def start_storage_read(self, file_num: int = 0, offset: int = 0) -> None: if self._client is None: raise RuntimeError("Not connected to device") payload = bytes([0x00, file_num]) + offset.to_bytes(4, byteorder="big") - await self._client.write_gatt_char(STORAGE_DATA_STREAM_CHAR_UUID, payload, response=True) + await self._client.write_gatt_char( + STORAGE_DATA_STREAM_CHAR_UUID, payload, response=True + ) - async def subscribe_storage_data(self, callback: Callable[[int, bytearray], None]) -> None: + async def subscribe_storage_data( + self, callback: Callable[[int, bytearray], None] + ) -> None: """Subscribe to storage data stream notifications (for BLE storage reads).""" if self._client is None: raise RuntimeError("Not connected to device") diff --git a/extras/local-wearable-client/main.py b/extras/local-wearable-client/main.py index ad89b827..5c123661 100644 --- a/extras/local-wearable-client/main.py +++ b/extras/local-wearable-client/main.py @@ -56,7 +56,9 @@ def check_config() -> bool: """Check that required configuration is present. Returns True if backend streaming is possible.""" if not os.path.exists(ENV_PATH): - logger.warning("No .env file found — copy .env.template to .env and fill in your settings") + logger.warning( + "No .env file found — copy .env.template to .env and fill in your settings" + ) logger.warning("Audio will be saved locally but NOT streamed to the backend") return False @@ -116,21 +118,25 @@ async def scan_all_devices(config: dict) -> list[dict]: for d, adv in discovered.values(): if d.address in known: entry = known[d.address] - devices.append({ - "mac": d.address, - "name": entry.get("name", d.name or "Unknown"), - "type": entry.get("type", detect_device_type(d.name or "")), - "rssi": adv.rssi, - }) + devices.append( + { + "mac": d.address, + "name": entry.get("name", d.name or "Unknown"), + "type": entry.get("type", detect_device_type(d.name or "")), + "rssi": adv.rssi, + } + ) elif auto_discover and d.name: lower = d.name.casefold() if "omi" in lower or "neo" in lower or "friend" in lower: - devices.append({ - "mac": d.address, - "name": d.name, - "type": detect_device_type(d.name), - "rssi": adv.rssi, - }) + devices.append( + { + "mac": d.address, + "name": d.name, + "type": detect_device_type(d.name), + "rssi": adv.rssi, + } + ) devices.sort(key=lambda x: x.get("rssi", -999), reverse=True) return devices @@ -148,7 +154,9 @@ def prompt_device_selection(devices: list[dict]) -> dict | None: print(f" {'#':<4} {'Name':<20} {'MAC':<20} {'Type':<8} {'RSSI'}") print(" " + "-" * 60) for i, d in enumerate(devices, 1): - print(f" {i:<4} {d['name']:<20} {d['mac']:<20} {d['type']:<8} {d.get('rssi', '?')}") + print( + f" {i:<4} {d['name']:<20} {d['mac']:<20} {d['type']:<8} {d.get('rssi', '?')}" + ) print() while True: @@ -284,19 +292,29 @@ def _on_battery(level: int) -> None: asyncio.create_task(process_audio(), name="process_audio"), ] if backend_enabled: - worker_tasks.append(asyncio.create_task(backend_stream_wrapper(), name="backend_stream")) + worker_tasks.append( + asyncio.create_task( + backend_stream_wrapper(), name="backend_stream" + ) + ) disconnect_task = asyncio.create_task( conn.wait_until_disconnected(), name="disconnect" ) - logger.info("Streaming audio from %s [%s]%s", device_name, device["mac"], - "" if backend_enabled else " (local-only, backend disabled)") + logger.info( + "Streaming audio from %s [%s]%s", + device_name, + device["mac"], + "" if backend_enabled else " (local-only, backend disabled)", + ) # Wait for disconnect or any worker to fail all_tasks = [disconnect_task] + worker_tasks try: - done, pending = await asyncio.wait(all_tasks, return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait( + all_tasks, return_when=asyncio.FIRST_COMPLETED + ) except asyncio.CancelledError: # External cancellation (e.g. user disconnect) — clean up all workers for task in all_tasks: @@ -388,7 +406,11 @@ async def wifi_sync( logger.info("Configuring device WiFi AP (SSID=%s)...", ssid) rc = await conn.setup_wifi(ssid, password) if rc != WifiErrorCode.SUCCESS: - error_name = WifiErrorCode(rc).name if rc in WifiErrorCode._value2member_map_ else f"0x{rc:02X}" + error_name = ( + WifiErrorCode(rc).name + if rc in WifiErrorCode._value2member_map_ + else f"0x{rc:02X}" + ) logger.error("WiFi setup failed: %s", error_name) return @@ -403,9 +425,12 @@ def on_tcp_data(data: bytes) -> None: bytes_written[0] += len(data) # Progress update every ~1MB if bytes_written[0] % (1024 * 1024) < len(data): - logger.info("Received %d / %d bytes (%.1f%%)", - bytes_written[0], file_size, - bytes_written[0] / file_size * 100 if file_size else 0) + logger.info( + "Received %d / %d bytes (%.1f%%)", + bytes_written[0], + file_size, + bytes_written[0] / file_size * 100 if file_size else 0, + ) # Tell device to start WiFi AP (creates the network) logger.info("Starting device WiFi AP...") @@ -413,7 +438,11 @@ def on_tcp_data(data: bytes) -> None: if rc == WifiErrorCode.SESSION_ALREADY_RUNNING: logger.info("WiFi AP already running, continuing...") elif rc != WifiErrorCode.SUCCESS: - error_name = WifiErrorCode(rc).name if rc in WifiErrorCode._value2member_map_ else f"0x{rc:02X}" + error_name = ( + WifiErrorCode(rc).name + if rc in WifiErrorCode._value2member_map_ + else f"0x{rc:02X}" + ) logger.error("WiFi start failed: %s", error_name) output_file.close() return @@ -448,13 +477,25 @@ def on_tcp_data(data: bytes) -> None: break if attempt == 10 and not prompted_manual: prompted_manual = True - logger.info(">>> Auto-join may have failed. Please manually join WiFi '%s' (password: %s) <<<", ssid, password) + logger.info( + ">>> Auto-join may have failed. Please manually join WiFi '%s' (password: %s) <<<", + ssid, + password, + ) elif attempt % 10 == 0: - logger.info("Waiting for connection to '%s' AP... (current IP: %s)", ssid, local_ip) + logger.info( + "Waiting for connection to '%s' AP... (current IP: %s)", + ssid, + local_ip, + ) await asyncio.sleep(1) if not local_ip or not local_ip.startswith("192.168.1."): - logger.error("Failed to get IP on device AP network (got: %s). Is your WiFi connected to '%s'?", local_ip, ssid) + logger.error( + "Failed to get IP on device AP network (got: %s). Is your WiFi connected to '%s'?", + local_ip, + ssid, + ) await receiver.stop() output_file.close() if original_wifi: @@ -501,7 +542,11 @@ def on_tcp_data(data: bytes) -> None: except asyncio.CancelledError: pass - logger.info("Transfer complete: %d bytes written to %s", bytes_written[0], output_path) + logger.info( + "Transfer complete: %d bytes written to %s", + bytes_written[0], + output_path, + ) # Reconnect BLE to send cleanup commands logger.info("Reconnecting BLE for cleanup...") @@ -547,9 +592,16 @@ async def run(target_mac: str | None = None) -> None: device = None if target_mac: # --device flag: connect to specific MAC - device = next((d for d in devices if d["mac"].casefold() == target_mac.casefold()), None) + device = next( + (d for d in devices if d["mac"].casefold() == target_mac.casefold()), + None, + ) if not device: - logger.debug("Target device %s not found, retrying in %ds...", target_mac, scan_interval) + logger.debug( + "Target device %s not found, retrying in %ds...", + target_mac, + scan_interval, + ) elif len(devices) == 1: device = devices[0] elif len(devices) > 1: @@ -559,7 +611,12 @@ async def run(target_mac: str | None = None) -> None: return if device: - logger.info("Connecting to %s [%s] (type=%s)", device["name"], device["mac"], device["type"]) + logger.info( + "Connecting to %s [%s] (type=%s)", + device["name"], + device["mac"], + device["type"], + ) await connect_and_stream(device, backend_enabled=backend_enabled) logger.info("Device disconnected, resuming scan...") else: @@ -581,23 +638,27 @@ async def scan_and_print() -> None: for d, adv in discovered.values(): if d.address in known: entry = known[d.address] - devices.append({ - "mac": d.address, - "name": entry.get("name", d.name or "Unknown"), - "type": entry.get("type", detect_device_type(d.name or "")), - "rssi": adv.rssi, - "known": True, - }) + devices.append( + { + "mac": d.address, + "name": entry.get("name", d.name or "Unknown"), + "type": entry.get("type", detect_device_type(d.name or "")), + "rssi": adv.rssi, + "known": True, + } + ) elif auto_discover and d.name: lower = d.name.casefold() if "omi" in lower or "neo" in lower or "friend" in lower: - devices.append({ - "mac": d.address, - "name": d.name, - "type": detect_device_type(d.name), - "rssi": adv.rssi, - "known": False, - }) + devices.append( + { + "mac": d.address, + "name": d.name, + "type": detect_device_type(d.name), + "rssi": adv.rssi, + "known": False, + } + ) if not devices: print("No wearable devices found.") @@ -609,7 +670,9 @@ async def scan_and_print() -> None: print(f"{'Name':<20} {'MAC':<20} {'Type':<8} {'RSSI':<8} {'Known'}") print("-" * 70) for d in devices: - print(f"{d['name']:<20} {d['mac']:<20} {d['type']:<8} {d['rssi']:<8} {'yes' if d['known'] else 'auto'}") + print( + f"{d['name']:<20} {d['mac']:<20} {d['type']:<8} {d['rssi']:<8} {'yes' if d['known'] else 'auto'}" + ) def build_parser() -> argparse.ArgumentParser: @@ -620,15 +683,35 @@ def build_parser() -> argparse.ArgumentParser: sub = parser.add_subparsers(dest="command") sub.add_parser("menu", help="Launch menu bar app (default)") - run_parser = sub.add_parser("run", help="Headless mode — scan, connect, and stream (for launchd)") - run_parser.add_argument("--device", metavar="MAC", help="Connect to a specific device by MAC address") + run_parser = sub.add_parser( + "run", help="Headless mode — scan, connect, and stream (for launchd)" + ) + run_parser.add_argument( + "--device", metavar="MAC", help="Connect to a specific device by MAC address" + ) sub.add_parser("scan", help="One-shot scan — print nearby devices and exit") - wifi_parser = sub.add_parser("wifi-sync", help="Download stored audio from device via WiFi sync") - wifi_parser.add_argument("--device", metavar="MAC", help="Connect to a specific device by MAC address") - wifi_parser.add_argument("--ssid", default="Friend", help="WiFi AP SSID (default: Friend)") - wifi_parser.add_argument("--password", default="12345678", help="WiFi AP password (default: 12345678)") - wifi_parser.add_argument("--interface", metavar="IFACE", help="WiFi interface to use (e.g. en1 for USB adapter)") - wifi_parser.add_argument("--output-dir", default="./wifi_audio", help="Output directory (default: ./wifi_audio)") + wifi_parser = sub.add_parser( + "wifi-sync", help="Download stored audio from device via WiFi sync" + ) + wifi_parser.add_argument( + "--device", metavar="MAC", help="Connect to a specific device by MAC address" + ) + wifi_parser.add_argument( + "--ssid", default="Friend", help="WiFi AP SSID (default: Friend)" + ) + wifi_parser.add_argument( + "--password", default="12345678", help="WiFi AP password (default: 12345678)" + ) + wifi_parser.add_argument( + "--interface", + metavar="IFACE", + help="WiFi interface to use (e.g. en1 for USB adapter)", + ) + wifi_parser.add_argument( + "--output-dir", + default="./wifi_audio", + help="Output directory (default: ./wifi_audio)", + ) sub.add_parser("install", help="Install macOS launchd agent (auto-start on login)") sub.add_parser("uninstall", help="Remove macOS launchd agent") @@ -645,19 +728,22 @@ def main() -> None: command = args.command or "menu" # Default to menu mode if command == "wifi-sync": - asyncio.run(wifi_sync( - target_mac=getattr(args, "device", None), - ssid=args.ssid, - password=args.password, - interface=args.interface, - output_dir=args.output_dir, - )) + asyncio.run( + wifi_sync( + target_mac=getattr(args, "device", None), + ssid=args.ssid, + password=args.password, + interface=args.interface, + output_dir=args.output_dir, + ) + ) elif command == "run": asyncio.run(run(target_mac=getattr(args, "device", None))) elif command == "menu": from menu_app import run_menu_app + run_menu_app() elif command == "scan": @@ -665,22 +751,27 @@ def main() -> None: elif command == "install": from service import install + install() elif command == "uninstall": from service import uninstall + uninstall() elif command == "kickstart": from service import kickstart + kickstart() elif command == "status": from service import status + status() elif command == "logs": from service import logs + logs() diff --git a/extras/local-wearable-client/menu_app.py b/extras/local-wearable-client/menu_app.py index a7a785f7..7da5fedc 100644 --- a/extras/local-wearable-client/menu_app.py +++ b/extras/local-wearable-client/menu_app.py @@ -15,9 +15,13 @@ import yaml from bleak import BleakScanner from dotenv import load_dotenv - -from backend_sender import stream_to_backend -from main import CONFIG_PATH, check_config, connect_and_stream, create_connection, detect_device_type, load_config +from main import ( + CONFIG_PATH, + check_config, + connect_and_stream, + detect_device_type, + load_config, +) logger = logging.getLogger(__name__) @@ -26,6 +30,7 @@ # --- Shared state ----------------------------------------------------------- + @dataclass class SharedState: """Thread-safe state shared between the rumps UI and the asyncio BLE thread.""" @@ -33,7 +38,9 @@ class SharedState: _lock: threading.Lock = field(default_factory=threading.Lock, repr=False) status: str = "idle" # idle | scanning | connecting | connected | error connected_device: Optional[dict] = None # {name, mac, type} - nearby_devices: list[dict] = field(default_factory=list) # [{name, mac, type, rssi}] + nearby_devices: list[dict] = field( + default_factory=list + ) # [{name, mac, type, rssi}] error: Optional[str] = None chunks_sent: int = 0 battery_level: int = -1 # -1 = unknown @@ -42,7 +49,9 @@ def snapshot(self) -> dict: with self._lock: return { "status": self.status, - "connected_device": self.connected_device.copy() if self.connected_device else None, + "connected_device": ( + self.connected_device.copy() if self.connected_device else None + ), "nearby_devices": [d.copy() for d in self.nearby_devices], "error": self.error, "chunks_sent": self.chunks_sent, @@ -57,6 +66,7 @@ def update(self, **kwargs) -> None: # --- Asyncio background thread ---------------------------------------------- + class AsyncioThread: """Runs an asyncio event loop in a daemon thread.""" @@ -83,6 +93,7 @@ def run_coro(self, coro): # --- BLE manager (runs in the asyncio thread) -------------------------------- + class BLEManager: """Manages BLE scanning and device connections in the background asyncio thread.""" @@ -99,7 +110,9 @@ def __init__(self, state: SharedState, bg: AsyncioThread) -> None: self._backoff_seconds: float = 0 # 0 = no backoff active self._BACKOFF_INITIAL: float = 10.0 self._BACKOFF_MAX: float = 300.0 # 5 minutes - self._MIN_HEALTHY_DURATION: float = 30.0 # connections shorter than this trigger backoff + self._MIN_HEALTHY_DURATION: float = ( + 30.0 # connections shorter than this trigger backoff + ) # Restore last connected device for auto-connect last = self.config.get("last_connected") @@ -138,7 +151,10 @@ async def _scan_loop(self) -> None: # If we have a target and not already connecting/connected, try connecting if self._target_mac and not self._connecting: snap = self.state.snapshot() - match = next((d for d in snap["nearby_devices"] if d["mac"] == self._target_mac), None) + match = next( + (d for d in snap["nearby_devices"] if d["mac"] == self._target_mac), + None, + ) if match: await self._connect(match) @@ -168,29 +184,35 @@ async def _do_scan(self) -> None: # Check if known device if d.address in known: entry = known[d.address] - devices.append({ - "mac": d.address, - "name": entry.get("name", d.name or "Unknown"), - "type": entry.get("type", detect_device_type(d.name or "")), - "rssi": adv.rssi, - }) + devices.append( + { + "mac": d.address, + "name": entry.get("name", d.name or "Unknown"), + "type": entry.get("type", detect_device_type(d.name or "")), + "rssi": adv.rssi, + } + ) continue # Auto-discover recognized names if auto_discover and d.name: lower = d.name.casefold() if "omi" in lower or "neo" in lower or "friend" in lower: - devices.append({ - "mac": d.address, - "name": d.name, - "type": detect_device_type(d.name), - "rssi": adv.rssi, - }) + devices.append( + { + "mac": d.address, + "name": d.name, + "type": detect_device_type(d.name), + "rssi": adv.rssi, + } + ) # Sort by signal strength (strongest first) devices.sort(key=lambda x: x.get("rssi", -999), reverse=True) - new_status = "idle" if self.state.snapshot()["status"] != "connected" else "connected" + new_status = ( + "idle" if self.state.snapshot()["status"] != "connected" else "connected" + ) self.state.update(nearby_devices=devices, status=new_status, error=None) logger.info("Scan found %d device(s)", len(devices)) @@ -237,9 +259,15 @@ async def _connect(self, device: dict) -> None: if self._backoff_seconds == 0: self._backoff_seconds = self._BACKOFF_INITIAL else: - self._backoff_seconds = min(self._backoff_seconds * 2, self._BACKOFF_MAX) - logger.info("Connection lasted %.1fs (< %.0fs), backoff %.0fs before next attempt", - elapsed, self._MIN_HEALTHY_DURATION, self._backoff_seconds) + self._backoff_seconds = min( + self._backoff_seconds * 2, self._BACKOFF_MAX + ) + logger.info( + "Connection lasted %.1fs (< %.0fs), backoff %.0fs before next attempt", + elapsed, + self._MIN_HEALTHY_DURATION, + self._backoff_seconds, + ) else: # Healthy connection — reset backoff self._backoff_seconds = 0 @@ -290,6 +318,7 @@ def request_scan(self) -> None: # --- rumps menu bar app ------------------------------------------------------- + class WearableMenuApp(rumps.App): """macOS menu bar app for Chronicle wearable client.""" @@ -344,7 +373,9 @@ def refresh_ui(self, _sender) -> None: dev = snap["connected_device"] bat = snap["battery_level"] bat_str = f" 🔋{bat}%" if bat >= 0 else "" - self.status_item.title = f"Connected: {dev['name']} [{dev['mac'][-8:]}]{bat_str}" + self.status_item.title = ( + f"Connected: {dev['name']} [{dev['mac'][-8:]}]{bat_str}" + ) elif status == "connecting": self.status_item.title = "Connecting..." elif status == "scanning": @@ -357,7 +388,9 @@ def refresh_ui(self, _sender) -> None: # Update device list self._rebuild_device_menu(snap["nearby_devices"], snap["connected_device"]) - def _rebuild_device_menu(self, devices: list[dict], connected: Optional[dict]) -> None: + def _rebuild_device_menu( + self, devices: list[dict], connected: Optional[dict] + ) -> None: """Replace the device submenu items with fresh MenuItem instances.""" connected_mac = connected["mac"] if connected else None @@ -414,11 +447,13 @@ def on_disconnect(self, _sender) -> None: # --- Entry point -------------------------------------------------------------- + def run_menu_app() -> None: """Launch the menu bar app with background BLE thread.""" # Register as accessory app so macOS allows menu bar icons # (non-bundled Python processes default to no-UI policy on Sequoia) from AppKit import NSApplication + NSApplication.sharedApplication().setActivationPolicy_(1) # Accessory logging.basicConfig( diff --git a/extras/local-wearable-client/service.py b/extras/local-wearable-client/service.py index e6831773..825657a0 100644 --- a/extras/local-wearable-client/service.py +++ b/extras/local-wearable-client/service.py @@ -30,11 +30,12 @@ def _find_uv() -> str: ]: if candidate.exists(): return str(candidate) - print("Error: could not find 'uv' binary. Install it: curl -LsSf https://astral.sh/uv/install.sh | sh") + print( + "Error: could not find 'uv' binary. Install it: curl -LsSf https://astral.sh/uv/install.sh | sh" + ) sys.exit(1) - def _opus_dyld_path() -> str: """Get DYLD_LIBRARY_PATH for opuslib on macOS.""" try: @@ -58,7 +59,8 @@ def _create_app_bundle() -> None: applescript = f'do shell script "launchctl kickstart gui/" & (do shell script "id -u") & "/{LABEL}"' result = subprocess.run( ["osacompile", "-o", str(APP_BUNDLE), "-e", applescript], - capture_output=True, text=True, + capture_output=True, + text=True, ) if result.returncode != 0: print(f"osacompile failed: {result.stderr.strip()}") @@ -68,14 +70,16 @@ def _create_app_bundle() -> None: info_plist = APP_BUNDLE / "Contents" / "Info.plist" with open(info_plist, "rb") as f: info = plistlib.load(f) - info.update({ - "CFBundleName": "Chronicle Wearable", - "CFBundleDisplayName": "Chronicle Wearable", - "CFBundleIdentifier": LABEL, - "CFBundleVersion": "1.0", - "CFBundleShortVersionString": "1.0", - "LSUIElement": True, # No dock icon - }) + info.update( + { + "CFBundleName": "Chronicle Wearable", + "CFBundleDisplayName": "Chronicle Wearable", + "CFBundleIdentifier": LABEL, + "CFBundleVersion": "1.0", + "CFBundleShortVersionString": "1.0", + "LSUIElement": True, # No dock icon + } + ) with open(info_plist, "wb") as f: plistlib.dump(info, f) @@ -108,9 +112,13 @@ def _build_plist() -> dict: plist = { "Label": LABEL, "ProgramArguments": [ - uv, "run", - "--project", str(PROJECT_DIR), - "python", str(PROJECT_DIR / "main.py"), "menu", + uv, + "run", + "--project", + str(PROJECT_DIR), + "python", + str(PROJECT_DIR / "main.py"), + "menu", ], "WorkingDirectory": str(PROJECT_DIR), "RunAtLoad": True, @@ -146,7 +154,8 @@ def install() -> None: result = subprocess.run( ["launchctl", "bootstrap", f"gui/{os.getuid()}", str(PLIST_PATH)], - capture_output=True, text=True, + capture_output=True, + text=True, ) if result.returncode == 0: print(f"Service '{LABEL}' installed and loaded.") @@ -168,7 +177,8 @@ def uninstall() -> None: result = subprocess.run( ["launchctl", "bootout", f"gui/{os.getuid()}", str(PLIST_PATH)], - capture_output=True, text=True, + capture_output=True, + text=True, ) if result.returncode == 0: print(f"Service '{LABEL}' unloaded.") @@ -188,7 +198,8 @@ def kickstart() -> None: result = subprocess.run( ["launchctl", "kickstart", f"gui/{os.getuid()}/{LABEL}"], - capture_output=True, text=True, + capture_output=True, + text=True, ) if result.returncode == 0: print(f"Service '{LABEL}' started.") @@ -204,13 +215,16 @@ def status() -> None: result = subprocess.run( ["launchctl", "print", f"gui/{os.getuid()}/{LABEL}"], - capture_output=True, text=True, + capture_output=True, + text=True, ) if result.returncode == 0: # Extract key info from launchctl print output for line in result.stdout.splitlines(): stripped = line.strip() - if any(k in stripped.lower() for k in ["state", "pid", "last exit", "runs"]): + if any( + k in stripped.lower() for k in ["state", "pid", "last exit", "runs"] + ): print(stripped) else: print(f"Service '{LABEL}' is not running.") diff --git a/tests/unit/test_wizard_defaults.py b/tests/unit/test_wizard_defaults.py new file mode 100644 index 00000000..fa6da712 --- /dev/null +++ b/tests/unit/test_wizard_defaults.py @@ -0,0 +1,271 @@ +"""Test wizard.py helper functions for loading previous config as defaults. + +Tests for the functions that read config/config.yml to pre-populate wizard +prompts with previously-configured values, so re-runs default to existing +settings. +""" + +import importlib.util +import sys +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import yaml + +# --------------------------------------------------------------------------- +# Import the pure helper functions directly from wizard.py. +# wizard.py lives at the project root, not inside a package, so we import +# via importlib with an explicit path to avoid adding the root to sys.path +# permanently. +# --------------------------------------------------------------------------- + + +WIZARD_PATH = Path(__file__).parent.parent.parent / "wizard.py" +PROJECT_ROOT = str(WIZARD_PATH.parent) + + +def _load_wizard(): + # wizard.py and setup_utils.py both live in the project root. + # Add the root to sys.path so the relative import resolves. + if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + spec = importlib.util.spec_from_file_location("wizard", WIZARD_PATH) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +# Load once and reuse +_wizard = _load_wizard() + +read_config_yml = _wizard.read_config_yml +get_existing_stt_provider = _wizard.get_existing_stt_provider +get_existing_stream_provider = _wizard.get_existing_stream_provider +select_llm_provider = _wizard.select_llm_provider +select_memory_provider = _wizard.select_memory_provider +select_knowledge_graph = _wizard.select_knowledge_graph + + +# --------------------------------------------------------------------------- +# read_config_yml +# --------------------------------------------------------------------------- + + +def test_read_config_yml_missing_file(tmp_path, monkeypatch): + """Returns empty dict when config/config.yml does not exist.""" + monkeypatch.chdir(tmp_path) + result = read_config_yml() + assert result == {} + + +def test_read_config_yml_valid_file(tmp_path, monkeypatch): + """Parses and returns dict from a valid YAML file.""" + monkeypatch.chdir(tmp_path) + config_dir = tmp_path / "config" + config_dir.mkdir() + (config_dir / "config.yml").write_text( + "defaults:\n llm: openai-llm\n stt: stt-deepgram\n" + ) + result = read_config_yml() + assert result["defaults"]["llm"] == "openai-llm" + assert result["defaults"]["stt"] == "stt-deepgram" + + +def test_read_config_yml_empty_file(tmp_path, monkeypatch): + """Returns empty dict for an empty YAML file (yaml.safe_load returns None).""" + monkeypatch.chdir(tmp_path) + config_dir = tmp_path / "config" + config_dir.mkdir() + (config_dir / "config.yml").write_text("") + result = read_config_yml() + assert result == {} + + +def test_read_config_yml_comment_only_file(tmp_path, monkeypatch): + """Returns empty dict when the file contains only YAML comments.""" + monkeypatch.chdir(tmp_path) + config_dir = tmp_path / "config" + config_dir.mkdir() + (config_dir / "config.yml").write_text("# just a comment\n") + result = read_config_yml() + assert result == {} + + +# --------------------------------------------------------------------------- +# get_existing_stt_provider +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "stt_value, expected", + [ + ("stt-deepgram", "deepgram"), + ("stt-deepgram-stream", "deepgram"), + ("stt-parakeet-batch", "parakeet"), + ("stt-vibevoice", "vibevoice"), + ("stt-qwen3-asr", "qwen3-asr"), + ("stt-smallest", "smallest"), + ("stt-smallest-stream", "smallest"), + ], +) +def test_get_existing_stt_provider_known_values(stt_value, expected): + """Maps known config.yml stt values to wizard provider names.""" + config = {"defaults": {"stt": stt_value}} + assert get_existing_stt_provider(config) == expected + + +def test_get_existing_stt_provider_unknown_returns_none(): + """Returns None for unknown stt values (e.g. custom providers).""" + config = {"defaults": {"stt": "stt-unknown-provider"}} + assert get_existing_stt_provider(config) is None + + +def test_get_existing_stt_provider_missing_key(): + """Returns None when defaults.stt key is absent.""" + assert get_existing_stt_provider({}) is None + assert get_existing_stt_provider({"defaults": {}}) is None + + +# --------------------------------------------------------------------------- +# get_existing_stream_provider +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "stt_stream_value, expected", + [ + ("stt-deepgram-stream", "deepgram"), + ("stt-smallest-stream", "smallest"), + ("stt-qwen3-asr", "qwen3-asr"), + ("stt-qwen3-asr-stream", "qwen3-asr"), + ], +) +def test_get_existing_stream_provider_known_values(stt_stream_value, expected): + """Maps known config.yml stt_stream values to wizard streaming provider names.""" + config = {"defaults": {"stt_stream": stt_stream_value}} + assert get_existing_stream_provider(config) == expected + + +def test_get_existing_stream_provider_unknown_returns_none(): + """Returns None for unknown stt_stream values.""" + config = {"defaults": {"stt_stream": "stt-unknown"}} + assert get_existing_stream_provider(config) is None + + +def test_get_existing_stream_provider_missing_key(): + """Returns None when defaults.stt_stream is absent.""" + assert get_existing_stream_provider({}) is None + assert get_existing_stream_provider({"defaults": {}}) is None + + +# --------------------------------------------------------------------------- +# select_llm_provider — test default resolution logic via EOFError path +# --------------------------------------------------------------------------- + + +def _select_llm_with_eof(config_yml): + """Drive select_llm_provider in non-interactive mode by injecting EOFError.""" + with patch.object(_wizard, "Prompt") as mock_prompt: + mock_prompt.ask.side_effect = EOFError + return select_llm_provider(config_yml) + + +def test_select_llm_provider_defaults_to_openai_when_no_config(): + """Defaults to openai when config is empty.""" + result = _select_llm_with_eof({}) + assert result == "openai" + + +def test_select_llm_provider_defaults_to_openai_for_openai_llm(): + """Picks openai when existing config has defaults.llm = openai-llm.""" + config = {"defaults": {"llm": "openai-llm"}} + result = _select_llm_with_eof(config) + assert result == "openai" + + +def test_select_llm_provider_defaults_to_ollama_for_local_llm(): + """Picks ollama when existing config has defaults.llm = local-llm.""" + config = {"defaults": {"llm": "local-llm"}} + result = _select_llm_with_eof(config) + assert result == "ollama" + + +def test_select_llm_provider_none_config(): + """Treats None config_yml as empty dict (defaults to openai).""" + result = _select_llm_with_eof(None) + assert result == "openai" + + +# --------------------------------------------------------------------------- +# select_memory_provider — test default resolution logic via EOFError path +# --------------------------------------------------------------------------- + + +def _select_memory_with_eof(config_yml): + with patch.object(_wizard, "Prompt") as mock_prompt: + mock_prompt.ask.side_effect = EOFError + return select_memory_provider(config_yml) + + +def test_select_memory_provider_defaults_to_chronicle_when_no_config(): + """Defaults to chronicle when config is empty.""" + result = _select_memory_with_eof({}) + assert result == "chronicle" + + +def test_select_memory_provider_defaults_to_chronicle(): + """Picks chronicle when existing config has memory.provider = chronicle.""" + config = {"memory": {"provider": "chronicle"}} + result = _select_memory_with_eof(config) + assert result == "chronicle" + + +def test_select_memory_provider_defaults_to_openmemory_mcp(): + """Picks openmemory_mcp when existing config has memory.provider = openmemory_mcp.""" + config = {"memory": {"provider": "openmemory_mcp"}} + result = _select_memory_with_eof(config) + assert result == "openmemory_mcp" + + +def test_select_memory_provider_none_config(): + """Treats None config_yml as empty dict (defaults to chronicle).""" + result = _select_memory_with_eof(None) + assert result == "chronicle" + + +# --------------------------------------------------------------------------- +# select_knowledge_graph — test default resolution logic via EOFError path +# --------------------------------------------------------------------------- + + +def _select_kg_with_eof(config_yml): + with patch.object(_wizard, "Confirm") as mock_confirm: + mock_confirm.ask.side_effect = EOFError + return select_knowledge_graph(config_yml) + + +def test_select_knowledge_graph_defaults_to_true_when_no_config(): + """Defaults to True (enabled) when config is empty.""" + result = _select_kg_with_eof({}) + assert result is True + + +def test_select_knowledge_graph_respects_existing_true(): + """Returns True when existing config has knowledge_graph.enabled = True.""" + config = {"memory": {"knowledge_graph": {"enabled": True}}} + result = _select_kg_with_eof(config) + assert result is True + + +def test_select_knowledge_graph_respects_existing_false(): + """Returns False when existing config has knowledge_graph.enabled = False.""" + config = {"memory": {"knowledge_graph": {"enabled": False}}} + result = _select_kg_with_eof(config) + assert result is False + + +def test_select_knowledge_graph_none_config(): + """Treats None config_yml as empty dict (defaults to True).""" + result = _select_kg_with_eof(None) + assert result is True diff --git a/wizard.py b/wizard.py index c3884120..5347e9b0 100755 --- a/wizard.py +++ b/wizard.py @@ -9,7 +9,7 @@ from datetime import datetime from pathlib import Path -import yaml +from config_manager import ConfigManager from rich.console import Console from rich.prompt import Confirm, Prompt @@ -27,6 +27,34 @@ console = Console() + +def get_existing_stt_provider(config_yml: dict): + """Map config.yml defaults.stt value back to wizard provider name, or None.""" + stt = config_yml.get("defaults", {}).get("stt", "") + mapping = { + "stt-deepgram": "deepgram", + "stt-deepgram-stream": "deepgram", + "stt-parakeet-batch": "parakeet", + "stt-vibevoice": "vibevoice", + "stt-qwen3-asr": "qwen3-asr", + "stt-smallest": "smallest", + "stt-smallest-stream": "smallest", + } + return mapping.get(stt) + + +def get_existing_stream_provider(config_yml: dict): + """Map config.yml defaults.stt_stream value back to wizard streaming provider name, or None.""" + stt_stream = config_yml.get("defaults", {}).get("stt_stream", "") + mapping = { + "stt-deepgram-stream": "deepgram", + "stt-smallest-stream": "smallest", + "stt-qwen3-asr": "qwen3-asr", + "stt-qwen3-asr-stream": "qwen3-asr", + } + return mapping.get(stt_stream) + + SERVICES = { "backend": { "advanced": { @@ -153,8 +181,9 @@ def check_service_exists(service_name, service_config): return True, "OK" -def select_services(transcription_provider=None): +def select_services(transcription_provider=None, config_yml=None, memory_provider=None): """Let user select which services to setup""" + config_yml = config_yml or {} console.print("🚀 [bold cyan]Chronicle Service Setup[/bold cyan]") console.print("Select which services to configure:\n") @@ -195,8 +224,25 @@ def select_services(transcription_provider=None): console.print(f" ⏸️ {service_config['description']} - [dim]{msg}[/dim]") continue - # Speaker recognition is recommended by default - default_enable = service_name == "speaker-recognition" + # Determine smart default based on existing config + if service_name == "speaker-recognition": + # Default to True if speaker-recognition .env exists and has a valid (non-placeholder) HF_TOKEN + speaker_env = "extras/speaker-recognition/.env" + existing_hf = read_env_value(speaker_env, "HF_TOKEN") + default_enable = bool( + existing_hf + and not is_placeholder( + existing_hf, + "your_huggingface_token_here", + "your-huggingface-token-here", + "hf_xxxxx", + ) + ) + elif service_name == "openmemory-mcp": + # Default to True if memory provider was selected as openmemory_mcp + default_enable = memory_provider == "openmemory_mcp" + else: + default_enable = False try: enable_service = Confirm.ask( @@ -250,6 +296,9 @@ def run_service_setup( langfuse_secret_key=None, langfuse_host=None, streaming_provider=None, + llm_provider=None, + memory_provider=None, + knowledge_graph_enabled=None, hardware_profile=None, ): """Execute individual service setup script""" @@ -279,9 +328,25 @@ def run_service_setup( if neo4j_password: cmd.extend(["--neo4j-password", neo4j_password]) - # Add Obsidian configuration + # Always pass obsidian choice to avoid double-ask if obsidian_enabled: cmd.extend(["--enable-obsidian"]) + else: + cmd.extend(["--no-obsidian"]) + + # Always pass knowledge graph choice to avoid double-ask + if knowledge_graph_enabled is True: + cmd.extend(["--enable-knowledge-graph"]) + elif knowledge_graph_enabled is False: + cmd.extend(["--no-knowledge-graph"]) + + # Pass LLM provider choice + if llm_provider: + cmd.extend(["--llm-provider", llm_provider]) + + # Pass memory provider choice + if memory_provider: + cmd.extend(["--memory-provider", memory_provider]) # Pass LangFuse keys from langfuse init or external config if langfuse_public_key and langfuse_secret_key: @@ -728,33 +793,26 @@ def setup_hf_token_if_needed(selected_services): return None -def setup_config_file(): - """Setup config/config.yml from template if it doesn't exist""" - config_file = Path("config/config.yml") - config_template = Path("config/config.yml.template") - - if not config_file.exists(): - if config_template.exists(): - # Ensure config/ directory exists - config_file.parent.mkdir(parents=True, exist_ok=True) - shutil.copy(config_template, config_file) - console.print("✅ [green]Created config/config.yml from template[/green]") - else: - console.print( - "⚠️ [yellow]config/config.yml.template not found, skipping config setup[/yellow]" - ) - else: - console.print( - "ℹ️ [blue]config/config.yml already exists, keeping existing configuration[/blue]" - ) - - # Providers that support real-time streaming STREAMING_CAPABLE = {"deepgram", "smallest", "qwen3-asr"} -def select_transcription_provider(): +def select_transcription_provider(config_yml: dict = None): """Ask user which transcription provider they want (batch/primary).""" + config_yml = config_yml or {} + existing_provider = get_existing_stt_provider(config_yml) + + provider_to_choice = { + "deepgram": "1", + "parakeet": "2", + "vibevoice": "3", + "qwen3-asr": "4", + "smallest": "5", + "none": "6", + } + choice_to_provider = {v: k for k, v in provider_to_choice.items()} + default_choice = provider_to_choice.get(existing_provider, "1") + console.print("\n🎤 [bold cyan]Transcription Provider[/bold cyan]") console.print( "Choose your speech-to-text provider (used for [bold]batch[/bold]/high-quality transcription):" @@ -762,6 +820,17 @@ def select_transcription_provider(): console.print( "[dim]If it also supports streaming, it will be used for real-time too by default.[/dim]" ) + if existing_provider: + provider_labels = { + "deepgram": "Deepgram", + "parakeet": "Parakeet ASR", + "vibevoice": "VibeVoice ASR", + "qwen3-asr": "Qwen3-ASR", + "smallest": "Smallest.ai Pulse", + } + console.print( + f"[blue][INFO][/blue] Current: {provider_labels.get(existing_provider, existing_provider)}" + ) console.print() choices = { @@ -774,34 +843,24 @@ def select_transcription_provider(): } for key, desc in choices.items(): - console.print(f" {key}) {desc}") + marker = " [dim](current)[/dim]" if key == default_choice else "" + console.print(f" {key}) {desc}{marker}") console.print() while True: try: - choice = Prompt.ask("Enter choice", default="1") + choice = Prompt.ask("Enter choice", default=default_choice) if choice in choices: - if choice == "1": - return "deepgram" - elif choice == "2": - return "parakeet" - elif choice == "3": - return "vibevoice" - elif choice == "4": - return "qwen3-asr" - elif choice == "5": - return "smallest" - elif choice == "6": - return "none" + return choice_to_provider[choice] console.print( f"[red]Invalid choice. Please select from {list(choices.keys())}[/red]" ) except EOFError: - console.print("Using default: Deepgram") - return "deepgram" + console.print(f"Using default: {choices.get(default_choice, 'Deepgram')}") + return choice_to_provider.get(default_choice, "deepgram") -def select_streaming_provider(batch_provider): +def select_streaming_provider(batch_provider, config_yml: dict = None): """Ask if user wants a different provider for real-time streaming. If the batch provider supports streaming, offer to use the same (saves a step). @@ -810,16 +869,24 @@ def select_streaming_provider(batch_provider): Returns: Streaming provider name if different from batch, or None (same / skipped). """ + config_yml = config_yml or {} if batch_provider in ("none", None): return None + existing_stream = get_existing_stream_provider(config_yml) + if batch_provider in STREAMING_CAPABLE: # Batch provider can already stream — just confirm + # Default to "use different" if a different streaming provider was previously configured + has_different_stream = bool( + existing_stream and existing_stream != batch_provider + ) console.print(f"\n🔊 [bold cyan]Streaming[/bold cyan]") console.print(f"{batch_provider} supports both batch and streaming.") try: use_different = Confirm.ask( - "Use a different provider for real-time streaming?", default=False + "Use a different provider for real-time streaming?", + default=has_different_stream, ) except EOFError: return None @@ -851,13 +918,22 @@ def select_streaming_provider(batch_provider): streaming_choices[skip_key] = "Skip (no real-time streaming)" provider_map[skip_key] = None + # Pre-select the default based on existing config + default_stream_choice = "1" + if existing_stream and existing_stream != batch_provider: + for k, v in provider_map.items(): + if v == existing_stream: + default_stream_choice = k + break + for key, desc in streaming_choices.items(): - console.print(f" {key}) {desc}") + marker = " [dim](current)[/dim]" if key == default_stream_choice else "" + console.print(f" {key}) {desc}{marker}") console.print() while True: try: - choice = Prompt.ask("Enter choice", default="1") + choice = Prompt.ask("Enter choice", default=default_stream_choice) if choice in streaming_choices: result = provider_map[choice] if result: @@ -1007,6 +1083,117 @@ def select_hardware_profile( return None +def select_llm_provider(config_yml: dict = None) -> str: + """Ask user which LLM provider to use for memory extraction. + + Returns: + "openai", "ollama", or "none" + """ + config_yml = config_yml or {} + existing_llm = config_yml.get("defaults", {}).get("llm", "") + llm_to_choice = {"openai-llm": "1", "local-llm": "2"} + default_choice = llm_to_choice.get(existing_llm, "1") + + console.print("\n🤖 [bold cyan]LLM Provider[/bold cyan]") + console.print( + "Choose your language model provider for memory extraction and analysis:" + ) + console.print() + + choices = { + "1": "OpenAI (GPT-4o-mini, requires API key)", + "2": "Ollama (local models, runs on your machine)", + "3": "None (skip memory extraction)", + } + + for key, desc in choices.items(): + marker = " [dim](current)[/dim]" if key == default_choice else "" + console.print(f" {key}) {desc}{marker}") + console.print() + + while True: + try: + choice = Prompt.ask("Enter choice", default=default_choice) + if choice in choices: + return {"1": "openai", "2": "ollama", "3": "none"}[choice] + console.print( + f"[red]Invalid choice. Please select from {list(choices.keys())}[/red]" + ) + except EOFError: + console.print(f"Using default: {choices.get(default_choice, 'OpenAI')}") + return {"1": "openai", "2": "ollama", "3": "none"}.get( + default_choice, "openai" + ) + + +def select_memory_provider(config_yml: dict = None) -> str: + """Ask user which memory storage backend to use. + + This is separate from the 'Setup OpenMemory MCP server?' service question. + That question is about running the extra service; this is about the backend provider. + + Returns: + "chronicle" or "openmemory_mcp" + """ + config_yml = config_yml or {} + existing_provider = config_yml.get("memory", {}).get("provider", "chronicle") + default_choice = "2" if existing_provider == "openmemory_mcp" else "1" + + console.print("\n🧠 [bold cyan]Memory Storage Backend[/bold cyan]") + console.print("Choose where your memories and conversation facts are stored:") + console.print() + + choices = { + "1": "Chronicle Native (Qdrant vector database, self-hosted)", + "2": "OpenMemory MCP (cross-client compatible, requires openmemory-mcp service)", + } + + for key, desc in choices.items(): + marker = " [dim](current)[/dim]" if key == default_choice else "" + console.print(f" {key}) {desc}{marker}") + console.print() + + while True: + try: + choice = Prompt.ask("Enter choice", default=default_choice) + if choice in choices: + return {"1": "chronicle", "2": "openmemory_mcp"}[choice] + console.print( + f"[red]Invalid choice. Please select from {list(choices.keys())}[/red]" + ) + except EOFError: + return {"1": "chronicle", "2": "openmemory_mcp"}.get( + default_choice, "chronicle" + ) + + +def select_knowledge_graph(config_yml: dict = None) -> bool: + """Ask user if Knowledge Graph should be enabled. + + Returns: + True if Knowledge Graph should be enabled, False otherwise. + """ + config_yml = config_yml or {} + existing_enabled = ( + config_yml.get("memory", {}).get("knowledge_graph", {}).get("enabled", True) + ) + + console.print("\n🕸️ [bold cyan]Knowledge Graph[/bold cyan]") + console.print( + "Extracts people, places, organizations, events, and tasks from conversations" + ) + console.print("Uses Neo4j (included in the stack)") + console.print() + + try: + enabled = Confirm.ask("Enable Knowledge Graph?", default=existing_enabled) + except EOFError: + console.print(f"Using default: {'Yes' if existing_enabled else 'No'}") + enabled = existing_enabled + + return enabled + + def main(): """Main orchestration logic""" console.print("🎉 [bold green]Welcome to Chronicle![/bold green]\n") @@ -1018,8 +1205,9 @@ def main(): "[dim]When unsure, just press Enter — the defaults will work.[/dim]\n" ) - # Setup config file from template - setup_config_file() + # Ensure config.yml exists (create from template if needed) + config_mgr = ConfigManager() + config_mgr.ensure_config_yml() # Setup git hooks first setup_git_hooks() @@ -1027,14 +1215,25 @@ def main(): # Show what's available show_service_status() + # Read existing config.yml once — used as defaults for ALL wizard questions below + config_yml = config_mgr.get_full_config() + # Ask about transcription provider FIRST (determines which services are needed) - transcription_provider = select_transcription_provider() + transcription_provider = select_transcription_provider(config_yml) # Ask about streaming provider (if batch provider doesn't stream, or user wants a different one) - streaming_provider = select_streaming_provider(transcription_provider) + streaming_provider = select_streaming_provider(transcription_provider, config_yml) + + # LLM Provider selection (asked once here, passed to init.py — avoids double-ask) + llm_provider = select_llm_provider(config_yml) + + # Memory Provider selection (asked once here, passed to init.py — avoids double-ask) + memory_provider = select_memory_provider(config_yml) # Service Selection (pass transcription_provider so we skip asking about ASR when already chosen) - selected_services = select_services(transcription_provider) + selected_services = select_services( + transcription_provider, config_yml, memory_provider + ) # Auto-add asr-services if any local ASR was chosen (batch or streaming) local_asr_providers = ("parakeet", "vibevoice", "qwen3-asr") @@ -1052,6 +1251,20 @@ def main(): ) selected_services.append("asr-services") + # Auto-add openmemory-mcp service if openmemory_mcp was selected as memory provider + if ( + memory_provider == "openmemory_mcp" + and "openmemory-mcp" not in selected_services + ): + exists, _ = check_service_exists( + "openmemory-mcp", SERVICES["extras"]["openmemory-mcp"] + ) + if exists: + console.print( + "[blue][INFO][/blue] Memory provider is OpenMemory MCP — auto-adding openmemory-mcp service" + ) + selected_services.append("openmemory-mcp") + if not selected_services: console.print("\n[yellow]No services selected. Exiting.[/yellow]") return @@ -1085,13 +1298,17 @@ def main(): "HTTPS enables microphone access in browsers and secure connections" ) + # Default to existing HTTPS_ENABLED setting + existing_https = read_env_value("backends/advanced/.env", "HTTPS_ENABLED") + default_https = existing_https == "true" + try: https_enabled = Confirm.ask( - "Enable HTTPS for selected services?", default=True + "Enable HTTPS for selected services?", default=default_https ) except EOFError: - console.print("Using default: Yes") - https_enabled = True + console.print(f"Using default: {'Yes' if default_https else 'No'}") + https_enabled = default_https if https_enabled: # Try to auto-detect Tailscale address @@ -1155,6 +1372,7 @@ def main(): # Neo4j Configuration (always required - used by Knowledge Graph) neo4j_password = None obsidian_enabled = False + knowledge_graph_enabled = None if "advanced" in selected_services: console.print("\n🗄️ [bold cyan]Neo4j Configuration[/bold cyan]") @@ -1163,19 +1381,19 @@ def main(): ) console.print() - # Prompt for Neo4j password (remembers previous value on re-run) - try: - neo4j_password = prompt_with_existing_masked( - "Neo4j password (min 8 chars)", - env_file_path="backends/advanced/.env", - env_key="NEO4J_PASSWORD", - placeholders=["", "your-neo4j-password"], - is_password=True, - default="neo4jpassword", - ) - except (EOFError, KeyboardInterrupt): - neo4j_password = "neo4jpassword" - console.print("Using default password") + # Read existing Neo4j password and use as default (masked prompt) + existing_neo4j_pw = read_env_value("backends/advanced/.env", "NEO4J_PASSWORD") + neo4j_password = prompt_with_existing_masked( + prompt_text="Neo4j password (min 8 chars)", + existing_value=existing_neo4j_pw, + placeholders=[ + "neo4jpassword", + "your_neo4j_password", + "your-neo4j-password", + ], + is_password=True, + default="neo4jpassword", + ) if not neo4j_password: neo4j_password = "neo4jpassword" @@ -1188,17 +1406,24 @@ def main(): ) console.print() + # Load existing obsidian enabled state from config.yml as default + existing_obsidian = ( + config_yml.get("memory", {}).get("obsidian", {}).get("enabled", False) + ) try: obsidian_enabled = Confirm.ask( - "Enable Obsidian integration?", default=False + "Enable Obsidian integration?", default=existing_obsidian ) except EOFError: - console.print("Using default: No") - obsidian_enabled = False + console.print(f"Using default: {'Yes' if existing_obsidian else 'No'}") + obsidian_enabled = existing_obsidian if obsidian_enabled: console.print("[green]✅[/green] Obsidian integration will be configured") + # Knowledge Graph configuration (asked here once, passed to init.py) + knowledge_graph_enabled = select_knowledge_graph(config_yml) + # Pure Delegation - Run Each Service Setup console.print(f"\n📋 [bold]Setting up {len(selected_services)} services...[/bold]") @@ -1246,6 +1471,9 @@ def main(): langfuse_secret_key=langfuse_secret_key, langfuse_host=langfuse_host, streaming_provider=streaming_provider, + llm_provider=llm_provider, + memory_provider=memory_provider, + knowledge_graph_enabled=knowledge_graph_enabled, hardware_profile=hardware_profile, ): success_count += 1