diff --git a/.env.template b/.env.template index 388edbf5..1a638bb5 100644 --- a/.env.template +++ b/.env.template @@ -218,4 +218,4 @@ WEBUI_MEMORY_REQUEST=128Mi SPEAKER_CPU_LIMIT=2000m SPEAKER_MEMORY_LIMIT=4Gi SPEAKER_CPU_REQUEST=500m -SPEAKER_MEMORY_REQUEST=2Gi \ No newline at end of file +SPEAKER_MEMORY_REQUEST=2Gi diff --git a/.github/workflows/advanced-docker-compose-build.yml b/.github/workflows/advanced-docker-compose-build.yml index 93e72d68..ff406822 100644 --- a/.github/workflows/advanced-docker-compose-build.yml +++ b/.github/workflows/advanced-docker-compose-build.yml @@ -140,10 +140,10 @@ jobs: set -euo pipefail docker compose version OWNER_LC=$(echo "$OWNER" | tr '[:upper:]' '[:lower:]') - + # CUDA variants from pyproject.toml CUDA_VARIANTS=("cpu" "cu121" "cu126" "cu128") - + # Base services (no CUDA variants, no profiles) base_service_specs=( "chronicle-backend|advanced-chronicle-backend|docker-compose.yml|." @@ -151,11 +151,11 @@ jobs: "webui|advanced-webui|docker-compose.yml|." "openmemory-mcp|openmemory-mcp|../../extras/openmemory-mcp/docker-compose.yml|../../extras/openmemory-mcp" ) - + # Build and push base services for spec in "${base_service_specs[@]}"; do IFS='|' read -r svc svc_repo compose_file project_dir <<< "$spec" - + echo "::group::Building and pushing $svc_repo" if [ "$compose_file" = "docker-compose.yml" ] && [ "$project_dir" = "." ]; then docker compose build --pull "$svc" @@ -173,7 +173,7 @@ jobs: echo "::endgroup::" continue fi - + # Tag and push with version target_image="$REGISTRY/$OWNER_LC/$svc_repo:$VERSION" latest_image="$REGISTRY/$OWNER_LC/$svc_repo:latest" @@ -181,21 +181,21 @@ jobs: docker tag "$img_id" "$target_image" echo "Tagging $img_id as $latest_image" docker tag "$img_id" "$latest_image" - + echo "Pushing $target_image" docker push "$target_image" echo "Pushing $latest_image" docker push "$latest_image" - + # Clean up local tags docker image rm -f "$target_image" || true docker image rm -f "$latest_image" || true echo "::endgroup::" - + # Aggressive cleanup to save space docker system prune -af || true done - + # Build and push parakeet-asr with CUDA variants (cu121, cu126, cu128) echo "::group::Building and pushing parakeet-asr CUDA variants" cd ../../extras/asr-services @@ -203,7 +203,7 @@ jobs: echo "Building parakeet-asr-${cuda_variant}" export PYTORCH_CUDA_VERSION="${cuda_variant}" docker compose build parakeet-asr - + img_id=$(docker compose images -q parakeet-asr | head -n1) if [ -n "${img_id:-}" ]; then target_image="$REGISTRY/$OWNER_LC/parakeet-asr-${cuda_variant}:$VERSION" @@ -212,23 +212,23 @@ jobs: docker tag "$img_id" "$target_image" echo "Tagging $img_id as $latest_image" docker tag "$img_id" "$latest_image" - + echo "Pushing $target_image" docker push "$target_image" echo "Pushing $latest_image" docker push "$latest_image" - + # Clean up local tags docker image rm -f "$target_image" || true docker image rm -f "$latest_image" || true fi - + # Aggressive cleanup to save space docker system prune -af || true done cd - > /dev/null echo "::endgroup::" - + # Build and push speaker-recognition with all CUDA variants (including CPU) # Note: speaker-service has profiles, but we can build it directly by setting PYTORCH_CUDA_VERSION echo "::group::Building and pushing speaker-recognition variants" @@ -238,7 +238,7 @@ jobs: export PYTORCH_CUDA_VERSION="${cuda_variant}" # Build speaker-service directly (profiles only affect 'up', not 'build') docker compose build speaker-service - + img_id=$(docker compose images -q speaker-service | head 
-n1) if [ -n "${img_id:-}" ]; then target_image="$REGISTRY/$OWNER_LC/speaker-recognition-${cuda_variant}:$VERSION" @@ -247,23 +247,23 @@ jobs: docker tag "$img_id" "$target_image" echo "Tagging $img_id as $latest_image" docker tag "$img_id" "$latest_image" - + echo "Pushing $target_image" docker push "$target_image" echo "Pushing $latest_image" docker push "$latest_image" - + # Clean up local tags docker image rm -f "$target_image" || true docker image rm -f "$latest_image" || true fi - + # Aggressive cleanup to save space docker system prune -af || true done cd - > /dev/null echo "::endgroup::" - + # Summary echo "::group::Build Summary" echo "Built and pushed images with version tag: ${VERSION}" diff --git a/CLAUDE.md b/CLAUDE.md index fc3d8818..2fa839a8 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -613,4 +613,4 @@ The uv package manager is used for all python projects. Wherever you'd call `pyt - Only use `--no-cache` when explicitly needed (e.g., if cached layers are causing issues or when troubleshooting build problems) - Docker's build cache is efficient and saves significant time during development -- Remember that whenever there's a python command, you should use uv run python3 instead \ No newline at end of file +- Remember that whenever there's a python command, you should use uv run python3 instead diff --git a/Docs/init-system.md b/Docs/init-system.md index 895d727d..9d859226 100644 --- a/Docs/init-system.md +++ b/Docs/init-system.md @@ -3,7 +3,7 @@ ## Quick Links - **πŸ‘‰ [Start Here: Quick Start Guide](../quickstart.md)** - Main setup path for new users -- **πŸ“š [Full Documentation](../CLAUDE.md)** - Comprehensive reference +- **πŸ“š [Full Documentation](../CLAUDE.md)** - Comprehensive reference - **πŸ—οΈ [Architecture Details](overview.md)** - Technical deep dive --- @@ -59,12 +59,12 @@ Each service can be configured independently: cd backends/advanced uv run --with-requirements setup-requirements.txt python init.py -# Speaker Recognition only +# Speaker Recognition only cd extras/speaker-recognition ./setup.sh # ASR Services only -cd extras/asr-services +cd extras/asr-services ./setup.sh # OpenMemory MCP only @@ -80,21 +80,21 @@ cd extras/openmemory-mcp - **Generates**: Complete `.env` file with all required configuration - **Default ports**: Backend (8000), WebUI (5173) -### Speaker Recognition +### Speaker Recognition - **Prompts for**: Hugging Face token, compute mode (cpu/gpu) - **Service port**: 8085 - **WebUI port**: 5173 - **Requires**: HF_TOKEN for pyannote models ### ASR Services -- **Starts**: Parakeet ASR service via Docker Compose +- **Starts**: Parakeet ASR service via Docker Compose - **Service port**: 8767 - **Purpose**: Offline speech-to-text processing - **No configuration required** ### OpenMemory MCP - **Starts**: External OpenMemory MCP server -- **Service port**: 8765 +- **Service port**: 8765 - **WebUI**: Available at http://localhost:8765 - **Purpose**: Cross-client memory compatibility @@ -112,10 +112,10 @@ Note (Linux): If `host.docker.internal` is unavailable, add `extra_hosts: - "hos ## Key Benefits -βœ… **No Unnecessary Building** - Services are only started when you explicitly request them -βœ… **Resource Efficient** - Parakeet ASR won't start if you're using cloud transcription -βœ… **Clean Separation** - Configuration vs service management are separate concerns -βœ… **Unified Control** - Single command to start/stop all services +βœ… **No Unnecessary Building** - Services are only started when you explicitly request them +βœ… **Resource 
Efficient** - Parakeet ASR won't start if you're using cloud transcription +βœ… **Clean Separation** - Configuration vs service management are separate concerns +βœ… **Unified Control** - Single command to start/stop all services βœ… **Selective Starting** - Choose which services to run based on your current needs ## Ports & Access @@ -215,7 +215,7 @@ You can also manage services individually: # Advanced Backend cd backends/advanced && docker compose up --build -d -# Speaker Recognition +# Speaker Recognition cd extras/speaker-recognition && docker compose up --build -d # ASR Services (only if using offline transcription) @@ -225,6 +225,59 @@ cd extras/asr-services && docker compose up --build -d cd extras/openmemory-mcp && docker compose up --build -d ``` +## Startup Flow (Mermaid) diagram + +Chronicle has two layers: +- **Setup** (`wizard.sh` / `wizard.py`) writes config (`.env`, `config/config.yml`, optional SSL/nginx config). +- **Run** (`start.sh` / `services.py`) starts the configured services via `docker compose`. + +```mermaid +flowchart TD + A[wizard.sh] --> B[uv run --with-requirements setup-requirements.txt wizard.py] + B --> C{Select services} + C --> D[backends/advanced/init.py\nwrites backends/advanced/.env + config/config.yml] + C --> E[extras/speaker-recognition/init.py\nwrites extras/speaker-recognition/.env\noptionally ssl/* + nginx.conf] + C --> F[extras/asr-services/init.py\nwrites extras/asr-services/.env] + C --> G[extras/openmemory-mcp/setup.sh] + + A2[start.sh] --> B2[uv run --with-requirements setup-requirements.txt python services.py start ...] + B2 --> H{For each service:\n.env exists?} + H -->|yes| I[services.py runs docker compose\nin each service directory] + H -->|no| J[Skip (not configured)] +``` + +### How `services.py` picks Speaker Recognition variants + +`services.py` reads `extras/speaker-recognition/.env` and decides: +- `COMPUTE_MODE=cpu|gpu|strixhalo` β†’ choose compose profile +- `REACT_UI_HTTPS=true|false` β†’ include `nginx` (HTTPS) vs run only API+UI (HTTP) + +```mermaid +flowchart TD + S[start.sh] --> P[services.py] + P --> R[Read extras/speaker-recognition/.env] + R --> M{COMPUTE_MODE} + M -->|cpu| C1[docker compose --profile cpu up ...] + M -->|gpu| C2[docker compose --profile gpu up ...] + M -->|strixhalo| C3[docker compose --profile strixhalo up ...] + R --> H{REACT_UI_HTTPS} + H -->|true| N1[Start profile default set:\nAPI + web-ui + nginx] + H -->|false| N2[Start only:\nAPI + web-ui (no nginx)] +``` + +### CPU + NVIDIA share the same `Dockerfile` + `pyproject.toml` + +Speaker recognition uses a single dependency definition with per-accelerator β€œextras”: +- `extras/speaker-recognition/pyproject.toml` defines extras like `cpu`, `cu121`, `cu126`, `cu128`, `strixhalo`. +- `extras/speaker-recognition/Dockerfile` takes `ARG PYTORCH_CUDA_VERSION` and runs: + - `uv sync --extra ${PYTORCH_CUDA_VERSION}` + - `uv run --extra ${PYTORCH_CUDA_VERSION} ...` +- `extras/speaker-recognition/docker-compose.yml` sets that build arg per profile: + - CPU profile defaults to `PYTORCH_CUDA_VERSION=cpu` + - GPU profile defaults to `PYTORCH_CUDA_VERSION=cu126` and reserves NVIDIA GPUs + +AMD/ROCm (Strix Halo) uses the same `pyproject.toml` interface (the `strixhalo` extra), but a different build recipe (`extras/speaker-recognition/Dockerfile.strixhalo`) and ROCm device mappings, because the base image provides the torch stack. 
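To make the selection concrete, here is a minimal sketch of the logic described above (this is not the actual `services.py` implementation; the `.env` parsing, the `speaker-service`/`web-ui` service names, and the exact compose invocation are assumptions based on this section):

```python
# Hypothetical sketch of how services.py could pick the speaker-recognition
# compose profile from extras/speaker-recognition/.env (illustrative only).
import subprocess
from pathlib import Path


def read_env(path: Path) -> dict:
    """Parse simple KEY=VALUE lines, ignoring comments and blank lines."""
    env = {}
    for line in path.read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#") and "=" in line:
            key, _, value = line.partition("=")
            env[key.strip()] = value.strip().strip("'\"")
    return env


def start_speaker_recognition(service_dir: Path) -> None:
    env_file = service_dir / ".env"
    if not env_file.exists():
        print("speaker-recognition is not configured; skipping")
        return

    env = read_env(env_file)
    profile = env.get("COMPUTE_MODE", "cpu")          # cpu | gpu | strixhalo
    https = env.get("REACT_UI_HTTPS", "false").lower() == "true"

    cmd = ["docker", "compose", "--profile", profile, "up", "-d"]
    if not https:
        # HTTP mode: bring up only the API and web UI, leaving nginx out.
        cmd += ["speaker-service", "web-ui"]
    subprocess.run(cmd, cwd=service_dir, check=True)
```

The same pattern (read a per-service `.env`, then run `docker compose` in that service's directory) applies to the other services listed above; services without a generated `.env` are simply skipped as not configured.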
+ ## Configuration Files ### Generated Files @@ -250,7 +303,7 @@ cd extras/openmemory-mcp && docker compose up --build -d # Backend health curl http://localhost:8000/health -# Speaker Recognition health +# Speaker Recognition health curl http://localhost:8085/health # ASR service health @@ -267,4 +320,4 @@ cd backends/advanced && docker compose logs chronicle-backend # Speaker Recognition logs cd extras/speaker-recognition && docker compose logs speaker-service -``` \ No newline at end of file +``` diff --git a/README-K8S.md b/README-K8S.md index 8bbe22fa..6a36e038 100644 --- a/README-K8S.md +++ b/README-K8S.md @@ -47,7 +47,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu 2. **Install Ubuntu Server** - Boot from USB/DVD - Choose "Install Ubuntu Server" - - Configure network with static IP (recommended: 192.168.1.42) + - Configure network with static IP (recommended: 192.168.1.42) - Set hostname (e.g., `k8s_control_plane`) - Create user account - Install OpenSSH server @@ -56,10 +56,10 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu ```bash # Update system sudo apt update && sudo apt upgrade -y - + # Install essential packages sudo apt install -y curl wget git vim htop tree - + # Configure firewall sudo ufw allow ssh sudo ufw allow 6443 # Kubernetes API @@ -75,11 +75,11 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu ```bash # Install MicroK8s sudo snap install microk8s --classic - + # Add user to microk8s group sudo usermod -a -G microk8s $USER sudo chown -f -R $USER ~/.kube - + # Log out and back in, or run: newgrp microk8s ``` @@ -88,10 +88,10 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu ```bash # Start MicroK8s sudo microk8s start - + # Wait for all services to be ready sudo microk8s status --wait-ready - + # Generate join token for worker nodes sudo microk8s add-node # This will output a command like: @@ -103,7 +103,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu ```bash # Start MicroK8s sudo microk8s start - + # Wait for all services to be ready sudo microk8s status --wait-ready ``` @@ -115,7 +115,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu sudo microk8s enable ingress sudo microk8s enable storage sudo microk8s enable metrics-server - + # Wait for add-ons to be ready sudo microk8s status --wait-ready ``` @@ -125,7 +125,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu # Create kubectl alias echo 'alias kubectl="microk8s kubectl"' >> ~/.bashrc source ~/.bashrc - + # Verify installation kubectl get nodes kubectl get pods -A @@ -147,11 +147,11 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu ```bash # Install MicroK8s sudo snap install microk8s --classic - + # Add user to microk8s group sudo usermod -a -G microk8s $USER sudo chown -f -R $USER ~/.kube - + # Log out and back in, or run: newgrp microk8s ``` @@ -161,7 +161,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu # Use the join command from the control plane # Replace with your actual join token sudo microk8s join 192.168.1.42:25000/xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx - + # Wait for node to join sudo microk8s status --wait-ready ``` @@ -170,7 +170,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu ```bash # On the control plane, verify the 
worker node joined kubectl get nodes - + # The worker node should show as Ready # Example output: # NAME STATUS ROLES AGE VERSION @@ -182,7 +182,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu ```bash # From your build machine, configure the worker node ./configure-insecure-registry-remote.sh 192.168.1.43 - + # Repeat for each worker node with their respective IPs ``` @@ -194,10 +194,10 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu ```bash # Enable the built-in MicroK8s registry (not enabled by default) sudo microk8s enable registry - + # Wait for registry to be ready sudo microk8s status --wait-ready - + # Verify registry is running kubectl get pods -n container-registry ``` @@ -213,11 +213,11 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu ```bash # From your build machine, configure MicroK8s to trust the insecure registry chmod +x scripts/configure-insecure-registry-remote.sh - + # Run the configuration script with your node IP address # Usage: ./scripts/configure-insecure-registry-remote.sh [ssh_user] ./scripts/configure-insecure-registry-remote.sh 192.168.1.42 - + # Or with custom SSH user: # ./scripts/configure-insecure-registry-remote.sh 192.168.1.42 myuser ``` @@ -234,7 +234,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu ```bash # Apply the hostpath provisioner kubectl apply -f k8s-manifests/hostpath-provisioner-official.yaml - + # Verify storage class kubectl get storageclass ``` @@ -281,19 +281,19 @@ chronicle/ > **Note:** The `--recursive` flag downloads the optional Mycelia submodule (an alternative memory backend with timeline visualization). Most deployments use the default Chronicle memory system and don't need Mycelia. 2. **Install Required Tools** - + **kubectl** (required for Skaffold and Helm): - Visit: https://kubernetes.io/docs/tasks/tools/ - Follow the official installation guide for your platform - + **Skaffold**: - Visit: https://skaffold.dev/docs/install/ - - Follow the official installation guide - + - Follow the official installation guide + **Helm**: - Visit: https://helm.sh/docs/intro/install/ - - Follow the official installation guide - + - Follow the official installation guide + **Verify installations:** ```bash kubectl version --client @@ -311,7 +311,7 @@ chronicle/ ```bash # Copy template (if it exists) # cp backends/advanced/.env.template backends/advanced/.env - + # Note: Most environment variables are automatically set by Skaffold during deployment # including MONGODB_URI, QDRANT_BASE_URL, and other Kubernetes-specific values ``` @@ -320,27 +320,27 @@ chronicle/ ```bash # Copy the template file cp skaffold.env.template skaffold.env - + # Edit skaffold.env with your specific values vim skaffold.env - + # Essential variables to configure: REGISTRY=192.168.1.42:32000 # Use IP address for immediate access # Alternative: REGISTRY=k8s_control_plane:32000 (requires adding 'k8s_control_plane 192.168.1.42' to /etc/hosts) BACKEND_IP=192.168.1.42 BACKEND_NODEPORT=30270 WEBUI_NODEPORT=31011 - + # Optional: Configure speaker recognition service HF_TOKEN=hf_your_huggingface_token_here DEEPGRAM_API_KEY=your_deepgram_api_key_here - + # Note: MONGODB_URI and QDRANT_BASE_URL are automatically generated # by Skaffold based on your infrastructure namespace and service names ``` 3. 
**Configuration Variables Reference** - + **Required Variables:** - `REGISTRY`: Docker registry for image storage - `BACKEND_IP`: IP address of your Kubernetes control plane @@ -348,13 +348,13 @@ chronicle/ - `WEBUI_NODEPORT`: Port for WebUI service (30000-32767) - `INFRASTRUCTURE_NAMESPACE`: Namespace for MongoDB and Qdrant - `APPLICATION_NAMESPACE`: Namespace for your application - + **Optional Variables (for Speaker Recognition):** - `HF_TOKEN`: Hugging Face token for Pyannote models - `DEEPGRAM_API_KEY`: Deepgram API key for speech-to-text - `COMPUTE_MODE`: GPU or CPU mode for ML services - `SIMILARITY_THRESHOLD`: Speaker identification threshold - + **Automatically Generated:** - `MONGODB_URI`: Generated from infrastructure namespace - `QDRANT_BASE_URL`: Generated from infrastructure namespace @@ -365,11 +365,11 @@ chronicle/ ```bash # Note: Most environment variables are handled by Skaffold automatically # If you need custom environment variables, you can: - + # Option 1: Use the script (if it exists) # chmod +x scripts/generate-helm-configmap.sh # ./scripts/generate-helm-configmap.sh - + # Option 2: Add them directly to the Helm chart values # Edit backends/charts/advanced-backend/values.yaml ``` @@ -460,7 +460,7 @@ This directory contains standalone Kubernetes manifests that are not managed by ```bash # Deploy everything in the correct order ./scripts/deploy-all-services.sh - + # This will automatically: # - Deploy infrastructure (MongoDB, Qdrant) # - Deploy main application (Backend, WebUI) @@ -473,13 +473,13 @@ This directory contains standalone Kubernetes manifests that are not managed by ```bash # Deploy infrastructure first skaffold run --profile=infrastructure - + # Wait for infrastructure to be ready kubectl get pods -n root - + # Deploy main application skaffold run --profile=advanced-backend --default-repo=192.168.1.42:32000 - + # Monitor deployment skaffold run --profile=advanced-backend --default-repo=192.168.1.42:32000 --tail ``` @@ -489,10 +489,10 @@ This directory contains standalone Kubernetes manifests that are not managed by # Check all resources kubectl get all -n chronicle kubectl get all -n root - + # Check Ingress kubectl get ingress -n chronicle - + # Check services kubectl get svc -n chronicle ``` @@ -636,7 +636,7 @@ spec: ```bash # Check backend health curl -k https://chronicle.192-168-1-42.nip.io:32623/health - + # Check WebUI curl -k https://chronicle.192-168-1-42.nip.io:32623/ ``` @@ -660,7 +660,7 @@ spec: ```bash # Test registry connectivity (run on Kubernetes node) curl http://k8s_control_plane:32000/v2/ - + # Check MicroK8s containerd config (run on Kubernetes node) sudo cat /var/snap/microk8s/current/args/certs.d/k8s_control_plane:32000/hosts.toml ``` @@ -669,7 +669,7 @@ spec: ```bash # Check storage class (run on build machine) kubectl get storageclass - + # Check persistent volumes (run on build machine) kubectl get pv kubectl get pvc -A @@ -679,7 +679,7 @@ spec: ```bash # Check Ingress controller (run on build machine) kubectl get pods -n ingress-nginx - + # Check Ingress configuration (run on build machine) kubectl describe ingress -n chronicle ``` @@ -696,13 +696,13 @@ spec: # Check GPU operator status (run on build machine) kubectl get pods -n gpu-operator kubectl describe pod -n gpu-operator - + # Check GPU detection on nodes kubectl get nodes -o json | jq '.items[] | {name: .metadata.name, gpu: .status.allocatable."nvidia.com/gpu"}' - + # Check GPU operator logs kubectl logs -n gpu-operator deployment/gpu-operator - + # Verify NVIDIA drivers 
on host (run on Kubernetes node) nvidia-smi ``` @@ -712,18 +712,18 @@ spec: # Check node connectivity (run on build machine) kubectl get nodes kubectl describe node - + # Check node status and conditions kubectl get nodes -o json | jq '.items[] | {name: .metadata.name, status: .status.conditions[] | select(.type=="Ready") | .status, message: .message}' - + # Check if pods can be scheduled kubectl get pods -A -o wide kubectl describe pod -n - + # Check node resources and capacity kubectl top nodes kubectl describe node | grep -A 10 "Allocated resources" - + # Verify network connectivity between nodes # Run on each node: ping @@ -766,7 +766,7 @@ kubectl rollout restart deployment/webui -n chronicle ```bash # Update system packages sudo apt update && sudo apt upgrade -y - + # Update MicroK8s sudo snap refresh microk8s ``` @@ -776,7 +776,7 @@ kubectl rollout restart deployment/webui -n chronicle # Backup environment files (run on build machine) cp backends/advanced/.env backends/advanced/.env.backup cp skaffold.env skaffold.env.backup - + # Backup Kubernetes manifests (run on build machine) kubectl get all -n chronicle -o yaml > chronicle-backup.yaml kubectl get all -n root -o yaml > infrastructure-backup.yaml diff --git a/README.md b/README.md index a4383ba2..3ff922d3 100644 --- a/README.md +++ b/README.md @@ -170,7 +170,7 @@ Usecases are numerous - OMI Mentor is one of them. Friend/Omi/pendants are a sma Regardless - this repo will try to do the minimal of this - multiple OMI-like audio devices feeding audio data - and from it: - Memories -- Action items +- Action items - Home automation ## Golden Goals (Not Yet Achieved) @@ -179,4 +179,3 @@ Regardless - this repo will try to do the minimal of this - multiple OMI-like au - **Home automation integration** (planned) - **Multi-device coordination** (planned) - **Visual context capture** (smart glasses integration planned) - diff --git a/app/README.md b/app/README.md index e85e83e5..7041c19b 100644 --- a/app/README.md +++ b/app/README.md @@ -174,7 +174,7 @@ Stream audio directly from your phone's microphone to Chronicle backend, bypassi #### Requirements - **iOS**: iOS 13+ with microphone permissions -- **Android**: Android API 21+ with microphone permissions +- **Android**: Android API 21+ with microphone permissions - **Network**: Stable connection to Chronicle backend - **Backend**: Advanced backend running with `/ws?codec=pcm` endpoint @@ -191,7 +191,7 @@ Stream audio directly from your phone's microphone to Chronicle backend, bypassi - **Network Connection**: Test backend connectivity - **Authentication**: Verify JWT token is valid -#### Poor Audio Quality +#### Poor Audio Quality - **Check Signal Strength**: Ensure stable network connection - **Reduce Background Noise**: Use in quiet environment - **Restart Recording**: Stop and restart phone audio streaming @@ -365,4 +365,4 @@ BluetoothService.onAudioData = (audioBuffer) => { - **[Backend Setup](../backends/)**: Choose and configure backend services - **[Quick Start Guide](../quickstart.md)**: Complete system setup - **[Advanced Backend](../backends/advanced/)**: Full-featured backend option -- **[Simple Backend](../backends/simple/)**: Basic backend for testing \ No newline at end of file +- **[Simple Backend](../backends/simple/)**: Basic backend for testing diff --git a/app/app/components/BackendStatus.tsx b/app/app/components/BackendStatus.tsx index 4f55d37f..d953ab96 100644 --- a/app/app/components/BackendStatus.tsx +++ b/app/app/components/BackendStatus.tsx @@ -40,21 +40,21 @@ 
export const BackendStatus: React.FC = ({ try { // Convert WebSocket URL to HTTP URL for health check let baseUrl = backendUrl.trim(); - + // Handle different URL formats if (baseUrl.startsWith('ws://')) { baseUrl = baseUrl.replace('ws://', 'http://'); } else if (baseUrl.startsWith('wss://')) { baseUrl = baseUrl.replace('wss://', 'https://'); } - + // Remove any WebSocket path if present baseUrl = baseUrl.split('/ws')[0]; - + // Try health endpoint first const healthUrl = `${baseUrl}/health`; console.log('[BackendStatus] Checking health at:', healthUrl); - + const response = await fetch(healthUrl, { method: 'GET', headers: { @@ -63,7 +63,7 @@ export const BackendStatus: React.FC = ({ ...(jwtToken ? { 'Authorization': `Bearer ${jwtToken}` } : {}), }, }); - + console.log('[BackendStatus] Health check response status:', response.status); if (response.ok) { @@ -73,7 +73,7 @@ export const BackendStatus: React.FC = ({ message: `Connected (${healthData.status || 'OK'})`, lastChecked: new Date(), }); - + if (showAlert) { Alert.alert('Connection Success', 'Successfully connected to backend!'); } @@ -83,7 +83,7 @@ export const BackendStatus: React.FC = ({ message: 'Authentication required', lastChecked: new Date(), }); - + if (showAlert) { Alert.alert('Authentication Required', 'Please login to access the backend.'); } @@ -92,7 +92,7 @@ export const BackendStatus: React.FC = ({ } } catch (error) { console.error('[BackendStatus] Health check error:', error); - + let errorMessage = 'Connection failed'; if (error instanceof Error) { if (error.message.includes('Network request failed')) { @@ -103,13 +103,13 @@ export const BackendStatus: React.FC = ({ errorMessage = error.message; } } - + setHealthStatus({ status: 'unhealthy', message: errorMessage, lastChecked: new Date(), }); - + if (showAlert) { Alert.alert( 'Connection Failed', @@ -125,7 +125,7 @@ export const BackendStatus: React.FC = ({ const timer = setTimeout(() => { checkBackendHealth(false); }, 500); // Debounce - + return () => clearTimeout(timer); } }, [backendUrl, jwtToken]); @@ -163,7 +163,7 @@ export const BackendStatus: React.FC = ({ return ( Backend Connection - + Backend URL: (''); - + // State for User ID const [userId, setUserId] = useState(''); - + // Authentication state const [isAuthenticated, setIsAuthenticated] = useState(false); const [currentUserEmail, setCurrentUserEmail] = useState(null); const [jwtToken, setJwtToken] = useState(null); - + // Bluetooth Management Hook const { bleManager, @@ -65,7 +65,7 @@ export default function App() { // Custom Audio Streamer Hook const audioStreamer = useAudioStreamer(); - + // Phone Audio Recorder Hook const phoneAudioRecorder = usePhoneAudioRecorder(); const [isPhoneAudioMode, setIsPhoneAudioMode] = useState(false); @@ -195,7 +195,7 @@ export default function App() { bluetoothState === BluetoothState.PoweredOn, // Derived from useBluetoothManager requestBluetoothPermission // From useBluetoothManager, should be stable ); - + // Effect for attempting auto-reconnection useEffect(() => { if ( @@ -260,15 +260,15 @@ export default function App() { try { let finalWebSocketUrl = webSocketUrl.trim(); - + // Check if this is the advanced backend (requires authentication) or simple backend const isAdvancedBackend = jwtToken && isAuthenticated; - + if (isAdvancedBackend) { // Advanced backend: include JWT token and device parameters const params = new URLSearchParams(); params.append('token', jwtToken); - + if (userId && userId.trim() !== '') { params.append('device_name', userId.trim()); 
console.log('[App.tsx] Using advanced backend with token and device_name:', userId.trim()); @@ -276,7 +276,7 @@ export default function App() { params.append('device_name', 'phone'); // Default device name console.log('[App.tsx] Using advanced backend with token and default device_name'); } - + const separator = webSocketUrl.includes('?') ? '&' : '?'; finalWebSocketUrl = `${webSocketUrl}${separator}${params.toString()}`; console.log('[App.tsx] Advanced backend WebSocket URL constructed (token hidden for security)'); @@ -318,10 +318,10 @@ export default function App() { try { let finalWebSocketUrl = webSocketUrl.trim(); - + // Convert HTTP/HTTPS to WS/WSS protocol finalWebSocketUrl = finalWebSocketUrl.replace(/^http:/, 'ws:').replace(/^https:/, 'wss:'); - + // Ensure /ws endpoint is included if (!finalWebSocketUrl.includes('/ws')) { // Remove trailing slash if present, then add /ws @@ -333,19 +333,19 @@ export default function App() { const separator = finalWebSocketUrl.includes('?') ? '&' : '?'; finalWebSocketUrl = finalWebSocketUrl + separator + 'codec=pcm'; } - + // Check if this is the advanced backend (requires authentication) or simple backend const isAdvancedBackend = jwtToken && isAuthenticated; - + if (isAdvancedBackend) { // Advanced backend: include JWT token and device parameters const params = new URLSearchParams(); params.append('token', jwtToken); - + const deviceName = userId && userId.trim() !== '' ? userId.trim() : 'phone-mic'; params.append('device_name', deviceName); console.log('[App.tsx] Using advanced backend with token and device_name:', deviceName); - + const separator = finalWebSocketUrl.includes('?') ? '&' : '?'; finalWebSocketUrl = `${finalWebSocketUrl}${separator}${params.toString()}`; console.log('[App.tsx] Advanced backend WebSocket URL constructed for phone audio'); @@ -356,7 +356,7 @@ export default function App() { // Start WebSocket streaming first await audioStreamer.startStreaming(finalWebSocketUrl); - + // Start phone audio recording await phoneAudioRecorder.startRecording(async (pcmBuffer) => { const wsReadyState = audioStreamer.getWebSocketReadyState(); @@ -364,7 +364,7 @@ export default function App() { await audioStreamer.sendAudio(pcmBuffer); } }); - + setIsPhoneAudioMode(true); console.log('[App.tsx] Phone audio streaming started successfully'); } catch (error) { @@ -417,7 +417,7 @@ export default function App() { return () => { console.log('App unmounting - cleaning up OmiConnection, BleManager, AudioStreamer, and PhoneAudioRecorder'); const refs = cleanupRefs.current; - + if (refs.omiConnection.isConnected()) { refs.disconnectFromDevice().catch(err => console.error("Error disconnecting in cleanup:", err)); } @@ -486,7 +486,7 @@ export default function App() { } // Attempt to stop any ongoing connection process // disconnectFromDevice also sets isConnecting to false internally. - await deviceConnection.disconnectFromDevice(); + await deviceConnection.disconnectFromDevice(); setIsAttemptingAutoReconnect(false); // Explicitly set to false to hide the auto-reconnect screen }, [deviceConnection, lastKnownDeviceId, saveLastConnectedDeviceId, setLastKnownDeviceId, setTriedAutoReconnectForCurrentId, setIsAttemptingAutoReconnect]); @@ -495,8 +495,8 @@ export default function App() { - {isAttemptingAutoReconnect - ? `Attempting to reconnect to the last device (${lastKnownDeviceId ? lastKnownDeviceId.substring(0, 10) + '...' : ''})...` + {isAttemptingAutoReconnect + ? `Attempting to reconnect to the last device (${lastKnownDeviceId ? 
lastKnownDeviceId.substring(0, 10) + '...' : ''})...` : 'Initializing Bluetooth...'} @@ -519,12 +519,12 @@ export default function App() { return ( - - @@ -618,7 +618,7 @@ export default function App() { ) : ( - {showOnlyOmi + {showOnlyOmi ? `No OMI/Friend devices found. ${scannedDevices.length} other device(s) hidden by filter.` : 'No devices found.' } @@ -627,7 +627,7 @@ export default function App() { )} )} - + {deviceConnection.connectedDeviceId && filteredDevices.find(d => d.id === deviceConnection.connectedDeviceId) && ( Connected Device @@ -638,9 +638,9 @@ export default function App() { console.log('[App.tsx] Manual disconnect initiated via DeviceListItem.'); // Prevent auto-reconnection by clearing the last known device ID *before* disconnecting. await saveLastConnectedDeviceId(null); - setLastKnownDeviceId(null); - setTriedAutoReconnectForCurrentId(true); - + setLastKnownDeviceId(null); + setTriedAutoReconnectForCurrentId(true); + // TODO: Consider adding setIsDisconnecting(true) here if a visual indicator is needed // and a finally block to set it to false, similar to the old handleDisconnectPress. // For now, focusing on the core logic. @@ -658,7 +658,7 @@ export default function App() { /> )} - + {/* Show disconnect button when connected but scan list isn't visible */} {deviceConnection.connectedDeviceId && !filteredDevices.find(d => d.id === deviceConnection.connectedDeviceId) && ( @@ -671,9 +671,9 @@ export default function App() { onPress={async () => { console.log('[App.tsx] Manual disconnect initiated via standalone disconnect button.'); await saveLastConnectedDeviceId(null); - setLastKnownDeviceId(null); + setLastKnownDeviceId(null); setTriedAutoReconnectForCurrentId(true); - + try { await deviceConnection.disconnectFromDevice(); console.log('[App.tsx] Manual disconnect from device successful.'); diff --git a/backends/advanced/.dockerignore b/backends/advanced/.dockerignore index f0f7f05c..237b041c 100644 --- a/backends/advanced/.dockerignore +++ b/backends/advanced/.dockerignore @@ -18,4 +18,4 @@ !start.sh !start-k8s.sh !worker_orchestrator.py -!Caddyfile \ No newline at end of file +!Caddyfile diff --git a/backends/advanced/.env.bak b/backends/advanced/.env.bak new file mode 100644 index 00000000..77fc4cf7 --- /dev/null +++ b/backends/advanced/.env.bak @@ -0,0 +1,128 @@ +# ======================================== +# Chronicle Backend - Secrets Only +# ======================================== +# This file contains ONLY secret values (API keys, passwords, tokens). +# All other configuration is in config/config.yml. +# +# Setup: +# 1. Copy this file to .env: cp .env.template .env +# 2. Fill in your API keys and secrets below +# 3. Configure non-secret settings in config/config.yml +# 4. 
Run: docker compose up --build -d + +# ======================================== +# Authentication Secrets +# ======================================== + +# JWT signing key (generate a long random string) +AUTH_SECRET_KEY='9532558aa95c37f53dbb04310c1084b2b7541f163559384ed2485065022c5f39' + +# Admin account password +ADMIN_PASSWORD='abc12345' + +# Admin email address +ADMIN_EMAIL='admin@example.com' + +# ======================================== +# LLM API Keys +# ======================================== + +# OpenAI API key (or OpenAI-compatible provider) +OPENAI_API_KEY= + +# ======================================== +# Transcription API Keys +# ======================================== + +# Deepgram API key (for cloud-based transcription) +DEEPGRAM_API_KEY= + +# Smallest.ai API key (for Pulse STT) +# SMALLEST_API_KEY= + +# ======================================== +# Speaker Recognition +# ======================================== + +# Hugging Face token (for PyAnnote speaker recognition models) +HF_TOKEN= + +# ======================================== +# Optional Services +# ======================================== + +# Neo4j configuration (if using Neo4j for Obsidian or Knowledge Graph) +NEO4J_HOST='neo4j' +NEO4J_USER='neo4j' +NEO4J_PASSWORD='abc12345' + +# Langfuse (for LLM observability and prompt management) +LANGFUSE_HOST='http://langfuse-web:3000' +LANGFUSE_PUBLIC_KEY='pk-lf-56e2bb264820104cbae94e739ea18765' +LANGFUSE_SECRET_KEY='sk-lf-b5936c7795d1be9e3f3e0a3988a17b96' +LANGFUSE_BASE_URL='http://langfuse-web:3000' + +# Galileo (OTEL-based LLM observability) +GALILEO_API_KEY= +GALILEO_PROJECT=chronicle +GALILEO_LOG_STREAM=default +# GALILEO_CONSOLE_URL=https://app.galileo.ai # Default; override for self-hosted + +# Qwen3-ASR (offline ASR via vLLM) +# QWEN3_ASR_URL=host.docker.internal:8767 +# QWEN3_ASR_STREAM_URL=host.docker.internal:8769 + +# Tailscale auth key (for remote service access) +TS_AUTHKEY= + +# ======================================== +# Plugin Configuration +# ======================================== +# Plugin-specific configuration is in: backends/advanced/src/advanced_omi_backend/plugins/{plugin_id}/config.yml +# Plugin orchestration (enabled, events) is in: config/plugins.yml +# This section contains ONLY plugin secrets + +# --------------------------------------- +# Home Assistant Plugin +# --------------------------------------- +# Enable in config/plugins.yml +# Configure in backends/advanced/src/advanced_omi_backend/plugins/homeassistant/config.yml + +# Home Assistant server URL +HA_URL=http://homeassistant.local:8123 + +# Home Assistant long-lived access token +# Get from: Profile β†’ Security β†’ Long-Lived Access Tokens +HA_TOKEN= + +# Wake word for voice commands (optional, default: vivi) +HA_WAKE_WORD=vivi + +# Request timeout in seconds (optional, default: 30) +HA_TIMEOUT=30 + +# --------------------------------------- +# Email Summarizer Plugin +# --------------------------------------- +# Enable in config/plugins.yml +# Configure in backends/advanced/src/advanced_omi_backend/plugins/email_summarizer/config.yml + +# SMTP server configuration +# For Gmail: Use App Password (requires 2FA enabled) +# 1. Go to Google Account β†’ Security β†’ 2-Step Verification +# 2. Scroll to "App passwords" β†’ Generate password for "Mail" +# 3. 
Use the 16-character password below (no spaces) +SMTP_HOST=smtp.gmail.com +SMTP_PORT=587 +SMTP_USERNAME=your-email@gmail.com +SMTP_PASSWORD=your-app-password-here +SMTP_USE_TLS=true + +# Email sender information +FROM_EMAIL=noreply@chronicle.ai +FROM_NAME=Chronicle AI +PARAKEET_ASR_URL='http://host.docker.internal:8767' +SPEAKER_SERVICE_URL='http://speaker-service:8085' +BACKEND_PUBLIC_PORT='8000' +WEBUI_PORT='5173' +HTTPS_ENABLED='false' diff --git a/backends/advanced/Docs/README.md b/backends/advanced/Docs/README.md index e58f94ee..c79c29c6 100644 --- a/backends/advanced/Docs/README.md +++ b/backends/advanced/Docs/README.md @@ -15,7 +15,7 @@ Welcome to chronicle! This guide provides the optimal reading sequence to unders - Basic setup and configuration - **Code References**: `src/advanced_omi_backend/main.py`, `config/config.yml`, `docker-compose.yml` -### 2. **[System Architecture](./architecture.md)** +### 2. **[System Architecture](./architecture.md)** **Read second** - Complete technical architecture with diagrams - Component relationships and data flow - Authentication and security architecture @@ -32,7 +32,7 @@ Welcome to chronicle! This guide provides the optimal reading sequence to unders - How conversations become memories - Mem0 integration and vector storage - Configuration and customization options -- **Code References**: +- **Code References**: - `src/advanced_omi_backend/memory/memory_service.py` (main processing) - `src/advanced_omi_backend/transcript_coordinator.py` (event coordination) - `src/advanced_omi_backend/conversation_repository.py` (data access) @@ -78,7 +78,7 @@ Welcome to chronicle! This guide provides the optimal reading sequence to unders ### **"I want to understand the system quickly"** (30 min) 1. [quickstart.md](./quickstart.md) - System overview -2. [architecture.md](./architecture.md) - Technical architecture +2. [architecture.md](./architecture.md) - Technical architecture 3. `src/advanced_omi_backend/main.py` - Core imports and setup 4. `config/config.yml` - Configuration overview @@ -147,7 +147,7 @@ backends/advanced-backend/ ### **Authentication** - **Setup**: `src/advanced_omi_backend/auth.py` -- **Users**: `src/advanced_omi_backend/users.py` +- **Users**: `src/advanced_omi_backend/users.py` - **Integration**: `src/advanced_omi_backend/routers/api_router.py` --- diff --git a/backends/advanced/Docs/README_speaker_enrollment.md b/backends/advanced/Docs/README_speaker_enrollment.md index 6f705d67..de3fa736 100644 --- a/backends/advanced/Docs/README_speaker_enrollment.md +++ b/backends/advanced/Docs/README_speaker_enrollment.md @@ -7,7 +7,7 @@ The advanced backend now includes sophisticated speaker recognition functionalit The speaker recognition system provides: 1. **Speaker Diarization**: Automatically detect and separate different speakers in audio -2. **Speaker Enrollment**: Register known speakers with audio samples +2. **Speaker Enrollment**: Register known speakers with audio samples 3. **Speaker Identification**: Identify enrolled speakers in new audio 4. **API Endpoints**: RESTful API for all speaker operations 5. **Command Line Tools**: Easy-to-use scripts for speaker management @@ -32,7 +32,7 @@ The speaker recognition system requires additional packages. 
Install them with: # For audio recording (optional) pip install sounddevice soundfile -# For API calls +# For API calls pip install aiohttp requests # Core dependencies (should already be installed) @@ -75,7 +75,7 @@ curl -X POST "http://localhost:8000/api/speakers/enroll" \ -H "Content-Type: application/json" \ -d '{ "speaker_id": "alice", - "speaker_name": "Alice Smith", + "speaker_name": "Alice Smith", "audio_file_path": "audio_chunk_file.wav" }' @@ -151,7 +151,7 @@ python enroll_speaker.py --identify "audio_chunk_test_recognition_67890.wav" 4. **FAISS Storage**: Add embedding to FAISS index for fast similarity search 5. **Database Storage**: Store speaker metadata in MongoDB -### Identification Process +### Identification Process 1. **Embedding Extraction**: Generate embedding from unknown audio 2. **Similarity Search**: Search FAISS index for most similar enrolled speaker @@ -211,7 +211,7 @@ The system supports: 4. **Poor Recognition Accuracy** ``` Issue: Speakers not being identified correctly - Solutions: + Solutions: - Use cleaner audio for enrollment (less background noise) - Enroll with longer audio segments (5-10 seconds) - Lower similarity threshold if needed @@ -248,10 +248,10 @@ from enroll_speaker import enroll_speaker_api async def batch_enroll(): speakers = [ ("alice", "Alice Smith", "alice.wav"), - ("bob", "Bob Jones", "bob.wav"), + ("bob", "Bob Jones", "bob.wav"), ("charlie", "Charlie Brown", "charlie.wav") ] - + for speaker_id, name, file in speakers: await enroll_speaker_api("localhost", 8000, speaker_id, name, file) @@ -286,7 +286,7 @@ speakers = list_enrolled_speakers() ## Next Steps 1. **Improve Accuracy**: Collect more training data for your specific use case -2. **Real-time Processing**: Implement streaming speaker recognition +2. **Real-time Processing**: Implement streaming speaker recognition 3. **Speaker Adaptation**: Fine-tune models on your specific speakers 4. **Multi-language Support**: Add support for different languages -5. **Speaker Verification**: Add 1:1 verification in addition to 1:N identification \ No newline at end of file +5. **Speaker Verification**: Add 1:1 verification in addition to 1:N identification diff --git a/backends/advanced/Docs/UI.md b/backends/advanced/Docs/UI.md index 02bdf943..66d9e316 100644 --- a/backends/advanced/Docs/UI.md +++ b/backends/advanced/Docs/UI.md @@ -30,7 +30,7 @@ The Chronicle web dashboard provides a comprehensive interface for managing conv **Features**: - Real-time conversation listing with metadata -- Audio playback and transcript viewing +- Audio playback and transcript viewing - Conversation status tracking (open/closed) - Speaker identification and timing information - Audio file upload for processing existing recordings @@ -39,7 +39,7 @@ The Chronicle web dashboard provides a comprehensive interface for managing conv - View all users' conversations - Advanced filtering and search capabilities -### 2. Memories Tab +### 2. 
Memories Tab **Purpose**: Browse and search extracted conversation memories with advanced filtering **Core Search Features**: @@ -97,7 +97,7 @@ The Chronicle web dashboard provides a comprehensive interface for managing conv #### System Overview (Click "πŸ“ˆ Load Debug Stats") - **Processing Metrics**: Total memory sessions, success rates, processing times -- **Failure Analysis**: Failed extractions and error tracking +- **Failure Analysis**: Failed extractions and error tracking - **Performance Monitoring**: Average processing times and bottlenecks - **Live Statistics**: Real-time system performance data @@ -199,7 +199,7 @@ ADMIN_PASSWORD=your-admin-password 4. **User Management**: Create/manage user accounts as needed 5. **Troubleshooting**: Use debug tools to investigate issues -### User Workflow +### User Workflow 1. **Authentication**: Login via sidebar 2. **View Conversations**: Browse recent audio sessions 3. **Search Memories**: Find relevant conversation insights @@ -243,9 +243,9 @@ ADMIN_PASSWORD=your-admin-password ### Debug Steps 1. **Check Logs**: `./logs/streamlit.log` for frontend issues -2. **Backend Health**: Use `/health` endpoint to verify backend status +2. **Backend Health**: Use `/health` endpoint to verify backend status 3. **API Testing**: Test endpoints directly with admin token 4. **Service Status**: Use debug tab to check component health 5. **Configuration**: Verify all environment variables are set correctly -This dashboard provides comprehensive system management capabilities with particular strength in debugging and monitoring the audio processing pipeline and memory extraction systems. \ No newline at end of file +This dashboard provides comprehensive system management capabilities with particular strength in debugging and monitoring the audio processing pipeline and memory extraction systems. diff --git a/backends/advanced/Docs/architecture.md b/backends/advanced/Docs/architecture.md index 739f0ed7..2d7474f2 100644 --- a/backends/advanced/Docs/architecture.md +++ b/backends/advanced/Docs/architecture.md @@ -46,21 +46,21 @@ graph TB subgraph "ProcessorManager" direction TB PM[Manager Coordinator] - + subgraph "Global Queues" AQ[Audio Queue] TQ[Transcription Queue] MQ[Memory Queue] CQ[Cropping Queue] end - + subgraph "Processors" AP[Audio Processor] TP[Transcription Processor] MP[Memory Processor] CP[Cropping Processor] end - + subgraph "Event Coordination" TC[TranscriptCoordinator
AsyncIO Events] end @@ -99,33 +99,33 @@ graph TB APP --> WS WS --> AUTH AUTH --> WP - + %% Wyoming Protocol Flow WP --> AS WP --> AC WP --> AST - + AS --> CM AC --> CM AST --> CM - + %% Client State Management CM --> CS CS --> CT - + %% Audio Processing Flow CS -->|queue audio| AQ AQ --> AP AP -->|save| FS AP -->|queue| TQ - + %% Transcription Flow TQ --> TP TP --> DG TP --> WY TP -->|save transcript| AC_COL TP -->|signal completion| TC - + %% Conversation Closure Flow (Critical Path) AST -->|close_conversation| CS CS -->|delegate to| CONV_MGR @@ -133,7 +133,7 @@ graph TB TC -->|transcript ready| CONV_MGR CONV_MGR -->|queue memory| MQ CONV_MGR -->|if enabled| CQ - + %% Memory Processing Flow MQ --> MP MP -->|read via| CONV_REPO @@ -142,13 +142,13 @@ graph TB MP --> OLL MP -->|orchestrate by mem0| MEM MP -->|update via| CONV_REPO - + %% Cropping Flow CQ --> CP CP -->|read| FS CP -->|save cropped| FS CP -->|update| AC_COL - + %% Management PM -.-> AQ PM -.-> TQ @@ -385,7 +385,7 @@ BackgroundTaskManager( ### Conversation Closure and Memory Processing -**Recent Architecture Changes**: +**Recent Architecture Changes**: 1. Memory processing decoupled from transcription to prevent duplicates 2. **Event-driven coordination** eliminates polling/retry race conditions @@ -410,14 +410,14 @@ The system now uses **TranscriptCoordinator** for proper async coordination: class TranscriptCoordinator: async def wait_for_transcript_completion(audio_uuid: str) -> bool: # Wait for asyncio.Event (no polling!) - + def signal_transcript_ready(audio_uuid: str): # Signal completion from TranscriptionManager ``` **Benefits:** - βœ… **Zero polling** - Uses asyncio events instead of sleep/retry loops -- βœ… **Immediate processing** - Memory processor starts as soon as transcript is ready +- βœ… **Immediate processing** - Memory processor starts as soon as transcript is ready - βœ… **No race conditions** - Proper async coordination prevents timing issues - βœ… **Better performance** - No artificial delays or timeout-based waiting @@ -450,7 +450,7 @@ The `close_conversation()` method in ClientState now delegates to ConversationMa - Most reliable trigger - Immediate processing -2. **Client Disconnect** +2. **Client Disconnect** - WebSocket connection closed - Cleanup handler ensures closure - Prevents orphaned conversations @@ -473,10 +473,10 @@ async def close_current_conversation(self): # 1. Check if conversation has content if not self.has_transcripts(): return # No memory processing needed - + # 2. Save conversation to MongoDB audio_uuid = await self.save_conversation() - + # 3. Queue memory processing (ONLY place this happens) if audio_uuid and self.has_required_data(): processor_manager.queue_memory_processing({ @@ -484,11 +484,11 @@ async def close_current_conversation(self): 'user_id': self.user_id, 'transcript': self.get_transcript() }) - + # 4. Queue audio cropping if enabled if self.audio_cropping_enabled: processor_manager.queue_audio_cropping(audio_path) - + # 5. Reset state for next conversation self.reset_conversation_state() ``` @@ -591,21 +591,21 @@ graph LR Mongo[mongo:4.4.18
Primary Database] Qdrant[qdrant
Vector Store] end - + subgraph "External Services" Ollama[ollama
LLM Service] ASRService[ASR Services
extras/asr-services] end - + subgraph "Client Access" WebBrowser[Web Browser
Dashboard] AudioClient[Audio Client
Mobile/Desktop] end - + WebBrowser -->|Port 5173 (dev)| WebUI WebBrowser -->|Port 80 (prod)| Proxy AudioClient -->|Port 8000| Backend - + Proxy --> Backend Proxy --> WebUI Backend --> Mongo @@ -639,7 +639,7 @@ graph LR ## Detailed Data Flow Architecture -> πŸ“– **Reference Documentation**: +> πŸ“– **Reference Documentation**: > - [Authentication Details](./auth.md) - Complete authentication system documentation ### Complete System Data Flow Diagram @@ -661,13 +661,13 @@ flowchart TB subgraph "🎡 Audio Processing Pipeline" WSAuth[WebSocket Auth
πŸ• Connection timeout: 30s] OpusDecoder[Opus/PCM Decoder
Real-time Processing] - + subgraph "⏱️ Per-Client State Management" ClientState[Client State
πŸ• Conversation timeout: 1.5min] AudioQueue[Audio Chunk Queue
60s segments] ConversationTimer[Conversation Timer
πŸ”„ Auto-timeout tracking] end - + subgraph "πŸŽ™οΈ Transcription Layer" ASRManager[Transcription Manager
πŸ• Init timeout: 60s] DeepgramWS[Deepgram WebSocket
Nova-3 Model, Smart Format
πŸ”Œ Auto-reconnect on disconnect] @@ -685,7 +685,7 @@ flowchart TB LLMProcessor[Ollama LLM
πŸ”„ Circuit breaker protection] VectorStore[Qdrant Vector Store
πŸ” Semantic search] end - + end @@ -701,7 +701,7 @@ flowchart TB AuthGW -->|❌ 401 Unauthorized
⏱️ Invalid/expired token| Client AuthGW -->|βœ… Validated| ClientGen ClientGen -->|🏷️ Generate client_id| WSAuth - + %% Audio Processing Flow Client -->|🎡 Opus/PCM Stream
πŸ• 30s connection timeout| WSAuth WSAuth -->|❌ 1008 Policy Violation
πŸ” Auth required| Client @@ -709,7 +709,7 @@ flowchart TB OpusDecoder -->|πŸ“¦ Audio chunks| ClientState ClientState -->|⏱️ 1.5min timeout check| ConversationTimer ConversationTimer -->|πŸ”„ Timeout exceeded| ClientState - + %% Transcription Flow with Failure Points ClientState -->|🎡 Audio data| ASRManager ASRManager -->|πŸ”Œ Primary connection| DeepgramWS @@ -727,7 +727,7 @@ flowchart TB LLMProcessor -->|❌ Empty response
πŸ”„ Fallback memory| MemoryService LLMProcessor -->|βœ… Memory extracted| VectorStore MemoryService -->|πŸ“Š Track processing| QueueTracker - + %% Disconnect and Cleanup Flow Client -->|πŸ”Œ Disconnect| ClientState @@ -765,7 +765,7 @@ flowchart TB #### πŸ”Œ **Disconnection Scenarios** 1. **Client Disconnect**: Graceful cleanup with conversation finalization -2. **Network Interruption**: Auto-reconnection with exponential backoff +2. **Network Interruption**: Auto-reconnection with exponential backoff 3. **Service Failure**: Circuit breaker protection and alternative routing 4. **Authentication Expiry**: Forced re-authentication with clear error codes @@ -792,7 +792,7 @@ flowchart TB **Solution**: A 2-second delay is added before calling `close_client_audio()` to ensure the transcription manager is created by the background processor. Without this delay, the flush call fails silently and transcription never completes. **File Upload Flow**: -1. Audio chunks queued to `transcription_queue` +1. Audio chunks queued to `transcription_queue` 2. Background transcription processor creates `TranscriptionManager` on first chunk 3. 2-second delay ensures manager exists before flush 4. Client audio closure triggers transcript completion diff --git a/backends/advanced/Docs/auth.md b/backends/advanced/Docs/auth.md index b1b9c273..fd019c90 100644 --- a/backends/advanced/Docs/auth.md +++ b/backends/advanced/Docs/auth.md @@ -16,11 +16,11 @@ class User(BeanieBaseUser, Document): is_active: bool = True is_superuser: bool = False is_verified: bool = False - + # Custom fields display_name: Optional[str] = None registered_clients: dict[str, dict] = Field(default_factory=dict) - + @property def user_id(self) -> str: """Return string representation of MongoDB ObjectId for backward compatibility.""" @@ -41,7 +41,7 @@ class UserManager(BaseUserManager[User, PydanticObjectId]): """Authenticate with email+password""" username = credentials.get("username") # Email-based authentication only - + async def get_by_email(self, email: str) -> Optional[User]: """Get user by email address""" ``` @@ -287,7 +287,7 @@ curl -H "Authorization: Bearer $TOKEN" http://localhost:8000/users/me ```bash # Old AUTH_USERNAME=abc123 # Custom user_id (deprecated) - + # New AUTH_USERNAME=user@example.com # Email address only ``` @@ -296,7 +296,7 @@ curl -H "Authorization: Bearer $TOKEN" http://localhost:8000/users/me ```python # Old username = AUTH_USERNAME # Could be email or user_id - + # New username = AUTH_USERNAME # Email address only ``` @@ -336,4 +336,4 @@ async def get_all_data(user: User = Depends(current_superuser)): # Unified user dashboard ``` -This authentication system provides enterprise-grade security with developer-friendly APIs, supporting email/password authentication and modern OAuth flows while maintaining proper data isolation and user management capabilities using MongoDB's robust ObjectId system. \ No newline at end of file +This authentication system provides enterprise-grade security with developer-friendly APIs, supporting email/password authentication and modern OAuth flows while maintaining proper data isolation and user management capabilities using MongoDB's robust ObjectId system. diff --git a/backends/advanced/Docs/memories.md b/backends/advanced/Docs/memories.md index 08ae393e..3da4e188 100644 --- a/backends/advanced/Docs/memories.md +++ b/backends/advanced/Docs/memories.md @@ -4,7 +4,7 @@ This document explains how to configure and customize the memory service in the chronicle backend. 
-**Code References**: +**Code References**: - **Main Implementation**: `src/memory/memory_service.py` - **Event Coordination**: `src/advanced_omi_backend/transcript_coordinator.py` (zero-polling async events) - **Repository Layer**: `src/advanced_omi_backend/conversation_repository.py` (clean data access) @@ -16,7 +16,7 @@ This document explains how to configure and customize the memory service in the The memory service uses [Mem0](https://mem0.ai/) to store, retrieve, and search conversation memories. It integrates with Ollama for embeddings and LLM processing, and Qdrant for vector storage. -**Key Architecture Changes**: +**Key Architecture Changes**: 1. **Event-Driven Processing**: Memories use asyncio events instead of polling/retry mechanisms 2. **Repository Pattern**: Clean data access through ConversationRepository 3. **User-Centric Storage**: All memories keyed by user_id instead of client_id @@ -47,7 +47,7 @@ The memory service uses [Mem0](https://mem0.ai/) to store, retrieve, and search **Key Flow:** 1. **Audio** β†’ **TranscriptCoordinator** signals completion -2. **ConversationManager** waits for event, queues memory processing +2. **ConversationManager** waits for event, queues memory processing 3. **MemoryProcessor** uses **ConversationRepository** for data access 4. **Mem0 + Ollama** extract and store memories in **Qdrant** @@ -88,7 +88,7 @@ MEM0_CONFIG = { }, }, "embedder": { - "provider": "ollama", + "provider": "ollama", "config": { "model": "nomic-embed-text:latest", "embedding_dims": 768, @@ -311,10 +311,10 @@ MEM0_CONFIG["vector_store"]["config"].update({ ```python def search_memories_with_filters(self, query: str, user_id: str, topic: str = None): filters = {} - + if topic: filters["metadata.topics"] = {"$in": [topic]} - + return self.memory.search( query=query, user_id=user_id, @@ -329,7 +329,7 @@ def search_memories_with_filters(self, query: str, user_id: str, topic: str = No def get_important_memories(self, user_id: str): """Get memories sorted by importance/frequency""" memories = self.memory.get_all(user_id=user_id) - + # Custom scoring logic for memory in memories: score = 0 @@ -338,7 +338,7 @@ def get_important_memories(self, user_id: str): if "deadline" in memory.get('memory', '').lower(): score += 3 memory['importance_score'] = score - + return sorted(memories, key=lambda x: x.get('importance_score', 0), reverse=True) ``` @@ -409,13 +409,13 @@ Create a custom processing function: def custom_memory_processor(transcript: str, client_id: str, audio_uuid: str, user_id: str, user_email: str): # Extract entities entities = extract_named_entities(transcript) - + # Classify conversation type conv_type = classify_conversation(transcript) - + # Generate custom summary summary = generate_custom_summary(transcript, conv_type) - + # Store with enriched metadata process_memory.add( summary, @@ -440,12 +440,12 @@ def init_specialized_memory_services(): # Personal memories personal_config = MEM0_CONFIG.copy() personal_config["vector_store"]["config"]["collection_name"] = "personal_memories" - - # Work memories + + # Work memories work_config = MEM0_CONFIG.copy() work_config["vector_store"]["config"]["collection_name"] = "work_memories" work_config["custom_prompt"] = "Focus on work-related tasks, meetings, and projects" - + return { "personal": Memory.from_config(personal_config), "work": Memory.from_config(work_config) @@ -460,7 +460,7 @@ Implement automatic memory cleanup: def cleanup_old_memories(self, user_id: str, days_old: int = 365): """Remove memories older than 
specified days""" cutoff_timestamp = int(time.time()) - (days_old * 24 * 60 * 60) - + memories = self.get_all_memories(user_id) for memory in memories: if memory.get('metadata', {}).get('timestamp', 0) < cutoff_timestamp: @@ -533,14 +533,14 @@ async def batch_add_memories(self, transcripts_data: List[Dict]): tasks = [] for data in transcripts_data: task = self.add_memory( - data['transcript'], - data['client_id'], + data['transcript'], + data['client_id'], data['audio_uuid'], data['user_id'], # Database user_id data['user_email'] # User email ) tasks.append(task) - + results = await asyncio.gather(*tasks, return_exceptions=True) return results ``` @@ -553,14 +553,14 @@ Implement memory consolidation: def consolidate_memories(self, user_id: str, time_window_hours: int = 24): """Consolidate related memories from the same time period""" recent_memories = self.get_recent_memories(user_id, time_window_hours) - + if len(recent_memories) > 5: # If many memories in short time consolidated = self.summarize_memories(recent_memories) - + # Delete individual memories and store consolidated version for memory in recent_memories: self.delete_memory(memory['id']) - + return self.add_consolidated_memory(consolidated, user_id) ``` @@ -569,7 +569,7 @@ def consolidate_memories(self, user_id: str, time_window_hours: int = 24): The memory service exposes these endpoints with enhanced search capabilities: - `GET /api/memories` - Get user memories with total count support (keyed by database user_id) -- `GET /api/memories/search?query={query}&limit={limit}` - **Semantic memory search** with relevance scoring (user-scoped) +- `GET /api/memories/search?query={query}&limit={limit}` - **Semantic memory search** with relevance scoring (user-scoped) - `GET /api/memories/unfiltered` - User's memories without filtering for debugging - `DELETE /api/memories/{memory_id}` - Delete specific memory (requires authentication) - `GET /api/memories/admin` - Admin view of all memories across all users (superuser only) @@ -578,11 +578,11 @@ The memory service exposes these endpoints with enhanced search capabilities: **Semantic Search (`/api/memories/search`)**: - **Relevance Scoring**: Returns similarity scores from vector database (0.0-1.0 range) -- **Configurable Limits**: Supports `limit` parameter for result count control +- **Configurable Limits**: Supports `limit` parameter for result count control - **User Scoped**: Results automatically filtered by authenticated user - **Vector-based**: Uses embeddings for contextual understanding beyond keyword matching -**Memory Count API**: +**Memory Count API**: - **Chronicle Provider**: Native Qdrant count API provides accurate total counts - **OpenMemory MCP Provider**: Count support varies by OpenMemory implementation - **Response Format**: `{"memories": [...], "total_count": 42}` when supported @@ -604,7 +604,7 @@ Returns all memories across all users in a clean, searchable format: "user_id": "abc123", "created_at": "2025-07-10T14:30:00Z", "owner_user_id": "abc123", - "owner_email": "user@example.com", + "owner_email": "user@example.com", "owner_display_name": "John Doe", "metadata": { "client_id": "abc123-laptop", diff --git a/backends/advanced/README.md b/backends/advanced/README.md index 104137b3..e85636a7 100644 --- a/backends/advanced/README.md +++ b/backends/advanced/README.md @@ -38,7 +38,7 @@ Modern React-based web dashboard located in `./webui/` with: - **Optional Services**: Speaker Recognition, network configuration - **API Keys**: Prompts for all required keys with 
helpful links -#### 2. Start Services +#### 2. Start Services **HTTP Mode (Default - No SSL required):** ```bash diff --git a/backends/advanced/scripts/create_plugin.py b/backends/advanced/scripts/create_plugin.py index f24427ad..9668c14d 100755 --- a/backends/advanced/scripts/create_plugin.py +++ b/backends/advanced/scripts/create_plugin.py @@ -7,6 +7,7 @@ Usage: uv run python scripts/create_plugin.py my_awesome_plugin """ + import argparse import os import shutil @@ -16,7 +17,7 @@ def snake_to_pascal(snake_str: str) -> str: """Convert snake_case to PascalCase.""" - return ''.join(word.capitalize() for word in snake_str.split('_')) + return "".join(word.capitalize() for word in snake_str.split("_")) def create_plugin(plugin_name: str, force: bool = False): @@ -28,19 +29,19 @@ def create_plugin(plugin_name: str, force: bool = False): force: Overwrite existing plugin if True """ # Validate plugin name - if not plugin_name.replace('_', '').isalnum(): + if not plugin_name.replace("_", "").isalnum(): print(f"❌ Error: Plugin name must be alphanumeric with underscores") print(f" Got: {plugin_name}") print(f" Example: my_awesome_plugin") sys.exit(1) # Convert to class name - class_name = snake_to_pascal(plugin_name) + 'Plugin' + class_name = snake_to_pascal(plugin_name) + "Plugin" # Get plugins directory (repo root plugins/) script_dir = Path(__file__).parent backend_dir = script_dir.parent - plugins_dir = backend_dir.parent.parent / 'plugins' + plugins_dir = backend_dir.parent.parent / "plugins" plugin_dir = plugins_dir / plugin_name # Check if plugin already exists @@ -70,7 +71,7 @@ def create_plugin(plugin_name: str, force: bool = False): __all__ = ['{class_name}'] ''' - init_file = plugin_dir / '__init__.py' + init_file = plugin_dir / "__init__.py" print(f"πŸ“ Creating {init_file}") init_file.write_text(init_content, encoding="utf-8") @@ -271,12 +272,12 @@ async def _my_helper_method(self, data: Any) -> Any: pass ''' - plugin_file = plugin_dir / 'plugin.py' + plugin_file = plugin_dir / "plugin.py" print(f"πŸ“ Creating {plugin_file}") - plugin_file.write_text(plugin_content,encoding="utf-8") + plugin_file.write_text(plugin_content, encoding="utf-8") # Create README.md - readme_content = f'''# {class_name} + readme_content = f"""# {class_name} [Brief description of what your plugin does] @@ -363,9 +364,9 @@ async def _my_helper_method(self, data: Any) -> Any: ## License MIT License - see project LICENSE file for details. 
-''' +""" - readme_file = plugin_dir / 'README.md' + readme_file = plugin_dir / "README.md" print(f"πŸ“ Creating {readme_file}") readme_file.write_text(readme_content, encoding="utf-8") @@ -402,23 +403,23 @@ async def _my_helper_method(self, data: Any) -> Any: def main(): parser = argparse.ArgumentParser( - description='Create a new Chronicle plugin with boilerplate structure', + description="Create a new Chronicle plugin with boilerplate structure", formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=''' + epilog=""" Examples: uv run python scripts/create_plugin.py my_awesome_plugin uv run python scripts/create_plugin.py slack_notifier uv run python scripts/create_plugin.py todo_extractor --force - ''' + """, ) parser.add_argument( - 'plugin_name', - help='Plugin name in snake_case (e.g., my_awesome_plugin)' + "plugin_name", help="Plugin name in snake_case (e.g., my_awesome_plugin)" ) parser.add_argument( - '--force', '-f', - action='store_true', - help='Overwrite existing plugin if it exists' + "--force", + "-f", + action="store_true", + help="Overwrite existing plugin if it exists", ) args = parser.parse_args() @@ -433,5 +434,5 @@ def main(): sys.exit(1) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/backends/advanced/scripts/delete_all_conversations_api.py b/backends/advanced/scripts/delete_all_conversations_api.py index 2c47f263..25cd2b3b 100755 --- a/backends/advanced/scripts/delete_all_conversations_api.py +++ b/backends/advanced/scripts/delete_all_conversations_api.py @@ -4,11 +4,12 @@ This uses the proper API authentication and endpoints. """ +import argparse import asyncio import os import sys -import argparse from pathlib import Path + import aiohttp from dotenv import load_dotenv @@ -20,104 +21,100 @@ async def get_auth_token(): """Get admin authentication token.""" admin_email = os.getenv("ADMIN_EMAIL") admin_password = os.getenv("ADMIN_PASSWORD") - + if not admin_email or not admin_password: print("Error: ADMIN_EMAIL and ADMIN_PASSWORD must be set in .env file") sys.exit(1) - + base_url = "http://localhost:8000" - + async with aiohttp.ClientSession() as session: # Login to get token - login_data = { - "username": admin_email, - "password": admin_password - } - + login_data = {"username": admin_email, "password": admin_password} + async with session.post( - f"{base_url}/auth/jwt/login", - data=login_data + f"{base_url}/auth/jwt/login", data=login_data ) as response: if response.status != 200: print(f"Failed to login: {response.status}") text = await response.text() print(f"Response: {text}") sys.exit(1) - + result = await response.json() return result["access_token"] async def delete_all_conversations(skip_prompt=False): """Delete all conversations using the API.""" - + base_url = "http://localhost:8000" - + # Get auth token print("Getting admin authentication token...") token = await get_auth_token() - - headers = { - "Authorization": f"Bearer {token}" - } - + + headers = {"Authorization": f"Bearer {token}"} + async with aiohttp.ClientSession() as session: # First, get all conversations print("Fetching all conversations...") async with session.get( - f"{base_url}/api/conversations", - headers=headers + f"{base_url}/api/conversations", headers=headers ) as response: if response.status != 200: print(f"Failed to fetch conversations: {response.status}") text = await response.text() print(f"Response: {text}") return - + data = await response.json() - + # Extract conversations from nested structure conversations_dict = 
data.get("conversations", {}) conversations = [] for client_id, client_conversations in conversations_dict.items(): conversations.extend(client_conversations) - + print(f"Found {len(conversations)} conversations") - + if len(conversations) == 0: print("No conversations to delete") return - + # Confirm deletion unless --yes flag is used if not skip_prompt: - response = input(f"Are you sure you want to delete ALL {len(conversations)} conversations? (yes/no): ") + response = input( + f"Are you sure you want to delete ALL {len(conversations)} conversations? (yes/no): " + ) if response.lower() != "yes": print("Deletion cancelled") return - + # Delete each conversation deleted_count = 0 failed_count = 0 - + for conv in conversations: audio_uuid = conv.get("audio_uuid") if not audio_uuid: print(f"Skipping conversation without audio_uuid: {conv.get('_id')}") continue - + # Delete the conversation async with session.delete( - f"{base_url}/api/conversations/{audio_uuid}", - headers=headers + f"{base_url}/api/conversations/{audio_uuid}", headers=headers ) as response: if response.status == 200: deleted_count += 1 - print(f"Deleted conversation {audio_uuid} ({deleted_count}/{len(conversations)})") + print( + f"Deleted conversation {audio_uuid} ({deleted_count}/{len(conversations)})" + ) else: failed_count += 1 text = await response.text() print(f"Failed to delete {audio_uuid}: {response.status} - {text}") - + print(f"\nDeletion complete:") print(f" Successfully deleted: {deleted_count}") print(f" Failed: {failed_count}") @@ -125,8 +122,10 @@ async def delete_all_conversations(skip_prompt=False): if __name__ == "__main__": # Parse command line arguments - parser = argparse.ArgumentParser(description='Delete all conversations') - parser.add_argument('--yes', '-y', action='store_true', help='Skip confirmation prompt') + parser = argparse.ArgumentParser(description="Delete all conversations") + parser.add_argument( + "--yes", "-y", action="store_true", help="Skip confirmation prompt" + ) args = parser.parse_args() - - asyncio.run(delete_all_conversations(skip_prompt=args.yes)) \ No newline at end of file + + asyncio.run(delete_all_conversations(skip_prompt=args.yes)) diff --git a/backends/advanced/scripts/laptop_client.py b/backends/advanced/scripts/laptop_client.py index a0047f3b..a848469f 100644 --- a/backends/advanced/scripts/laptop_client.py +++ b/backends/advanced/scripts/laptop_client.py @@ -21,7 +21,11 @@ def build_websocket_uri( - host: str, port: int, endpoint: str, token: str | None = None, device_name: str = "laptop" + host: str, + port: int, + endpoint: str, + token: str | None = None, + device_name: str = "laptop", ) -> str: """Build WebSocket URI with JWT token authentication.""" base_uri = f"ws://{host}:{port}{endpoint}" @@ -36,7 +40,9 @@ def build_websocket_uri( return base_uri -async def authenticate_with_credentials(host: str, port: int, username: str, password: str) -> str: +async def authenticate_with_credentials( + host: str, port: int, username: str, password: str +) -> str: """Authenticate with username/password and return JWT token.""" auth_url = f"http://{host}:{port}/auth/jwt/login" @@ -58,7 +64,9 @@ async def authenticate_with_credentials(host: str, port: int, username: str, pas raise Exception("No access token received from server") elif response.status == 400: error_detail = await response.text() - raise Exception(f"Authentication failed: Invalid credentials - {error_detail}") + raise Exception( + f"Authentication failed: Invalid credentials - {error_detail}" + ) else: 
error_detail = await response.text() raise Exception( @@ -96,29 +104,31 @@ def validate_auth_args(args): async def send_wyoming_event(websocket, wyoming_event): """Send a Wyoming protocol event over WebSocket. - + Based on how the backend processes Wyoming events, they expect: 1. JSON header line ending with \n 2. Optional binary payload if payload_length > 0 - + This replicates Wyoming's async_write_event behavior for WebSocket transport. """ # Get the event data from Wyoming event event_data = wyoming_event.event() - + # Build event dict like Wyoming's async_write_event does event_dict = event_data.to_dict() event_dict["version"] = "1.0.0" # Wyoming adds version - + # Add payload_length if payload exists (critical for audio chunks!) if event_data.payload: event_dict["payload_length"] = len(event_data.payload) - + # Send JSON header - json_header = json.dumps(event_dict) + '\n' + json_header = json.dumps(event_dict) + "\n" await websocket.send(json_header) - logger.debug(f"Sent Wyoming event: {event_data.type} (payload_length: {event_dict.get('payload_length', 0)})") - + logger.debug( + f"Sent Wyoming event: {event_data.type} (payload_length: {event_dict.get('payload_length', 0)})" + ) + # Send binary payload if exists if event_data.payload: await websocket.send(event_data.payload) @@ -131,11 +141,17 @@ async def main(): description="Laptop audio client for OMI backend with dual authentication modes" ) parser.add_argument("--host", default=DEFAULT_HOST, help="WebSocket server host") - parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="WebSocket server port") - parser.add_argument("--endpoint", default=DEFAULT_ENDPOINT, help="WebSocket endpoint") + parser.add_argument( + "--port", type=int, default=DEFAULT_PORT, help="WebSocket server port" + ) + parser.add_argument( + "--endpoint", default=DEFAULT_ENDPOINT, help="WebSocket endpoint" + ) # Authentication options (mutually exclusive) - auth_group = parser.add_argument_group("authentication", "Choose one authentication method") + auth_group = parser.add_argument_group( + "authentication", "Choose one authentication method" + ) auth_group.add_argument("--token", help="JWT authentication token") auth_group.add_argument("--username", help="Username for login authentication") auth_group.add_argument("--password", help="Password for login authentication") @@ -174,7 +190,9 @@ async def main(): return # Build WebSocket URI - ws_uri = build_websocket_uri(args.host, args.port, args.endpoint, token, args.device_name) + ws_uri = build_websocket_uri( + args.host, args.port, args.endpoint, token, args.device_name + ) print(f"Connecting to {ws_uri}") print(f"Using device name: {args.device_name}") @@ -190,10 +208,12 @@ async def send_audio(): audio_start = AudioStart( rate=stream.sample_rate, width=stream.sample_width, - channels=stream.channels + channels=stream.channels, ) await send_wyoming_event(websocket, audio_start) - logger.info(f"Sent audio-start event (rate={stream.sample_rate}, width={stream.sample_width}, channels={stream.channels})") + logger.info( + f"Sent audio-start event (rate={stream.sample_rate}, width={stream.sample_width}, channels={stream.channels})" + ) while True: try: data = await stream.read() @@ -203,18 +223,24 @@ async def send_audio(): audio=data.audio, rate=stream.sample_rate, width=stream.sample_width, - channels=stream.channels + channels=stream.channels, ) await send_wyoming_event(websocket, audio_chunk) - logger.debug(f"Sent audio chunk: {len(data.audio)} bytes") - await asyncio.sleep(0.01) # Small delay 
to prevent overwhelming + logger.debug( + f"Sent audio chunk: {len(data.audio)} bytes" + ) + await asyncio.sleep( + 0.01 + ) # Small delay to prevent overwhelming except websockets.exceptions.ConnectionClosed: - logger.info("WebSocket connection closed during audio sending") + logger.info( + "WebSocket connection closed during audio sending" + ) break except Exception as e: logger.error(f"Error sending audio: {e}") break - + except Exception as e: logger.error(f"Error in audio session: {e}") finally: diff --git a/backends/advanced/src/advanced_omi_backend/app_config.py b/backends/advanced/src/advanced_omi_backend/app_config.py index 5ed50618..d2f50d8e 100644 --- a/backends/advanced/src/advanced_omi_backend/app_config.py +++ b/backends/advanced/src/advanced_omi_backend/app_config.py @@ -57,7 +57,9 @@ def __init__(self): f"βœ… Using {self.transcription_provider.name} transcription provider ({self.transcription_provider.mode})" ) else: - logger.warning("⚠️ No transcription provider configured - speech-to-text will not be available") + logger.warning( + "⚠️ No transcription provider configured - speech-to-text will not be available" + ) # External Services Configuration self.qdrant_base_url = os.getenv("QDRANT_BASE_URL", "qdrant") @@ -73,7 +75,9 @@ def __init__(self): # CORS Configuration default_origins = "http://localhost:3000,http://localhost:3001,http://127.0.0.1:3000,http://127.0.0.1:3002" self.cors_origins = os.getenv("CORS_ORIGINS", default_origins) - self.allowed_origins = [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()] + self.allowed_origins = [ + origin.strip() for origin in self.cors_origins.split(",") if origin.strip() + ] # Tailscale support self.tailscale_regex = r"http://100\.\d{1,3}\.\d{1,3}\.\d{1,3}:3000" @@ -105,15 +109,11 @@ def get_audio_chunk_dir() -> Path: def get_mongo_collections(): """Get MongoDB collections.""" return { - 'users': app_config.users_col, - 'speakers': app_config.speakers_col, + "users": app_config.users_col, + "speakers": app_config.speakers_col, } def get_redis_config(): """Get Redis configuration.""" - return { - 'url': app_config.redis_url, - 'encoding': "utf-8", - 'decode_responses': False - } + return {"url": app_config.redis_url, "encoding": "utf-8", "decode_responses": False} diff --git a/backends/advanced/src/advanced_omi_backend/auth.py b/backends/advanced/src/advanced_omi_backend/auth.py index c0d0a7b5..3bd6ea88 100644 --- a/backends/advanced/src/advanced_omi_backend/auth.py +++ b/backends/advanced/src/advanced_omi_backend/auth.py @@ -53,12 +53,13 @@ def _verify_configured(var_name: str, *, optional: bool = False) -> Optional[str # Accepted token issuers - comma-separated list of services whose tokens we accept # Default: "chronicle,ushadow" (accept tokens from both chronicle and ushadow) ACCEPTED_ISSUERS = [ - iss.strip() - for iss in os.getenv("ACCEPTED_TOKEN_ISSUERS", "chronicle,ushadow").split(",") + iss.strip() + for iss in os.getenv("ACCEPTED_TOKEN_ISSUERS", "chronicle,ushadow").split(",") if iss.strip() ] logger.info(f"Accepting tokens from issuers: {ACCEPTED_ISSUERS}") + class UserManager(BaseUserManager[User, PydanticObjectId]): """User manager with minimal customization for fastapi-users.""" @@ -108,8 +109,9 @@ async def get_user_manager(user_db=Depends(get_user_db)): def get_jwt_strategy() -> JWTStrategy: """Get JWT strategy for token generation and validation.""" return JWTStrategy( - secret=SECRET_KEY, lifetime_seconds=JWT_LIFETIME_SECONDS, - token_audience=["fastapi-users:auth"] + ACCEPTED_ISSUERS + 
secret=SECRET_KEY, + lifetime_seconds=JWT_LIFETIME_SECONDS, + token_audience=["fastapi-users:auth"] + ACCEPTED_ISSUERS, ) @@ -220,7 +222,9 @@ async def create_admin_user_if_needed(): existing_admin = await user_db.get_by_email(ADMIN_EMAIL) if existing_admin: - logger.debug(f"existing_admin.id = {existing_admin.id}, type = {type(existing_admin.id)}") + logger.debug( + f"existing_admin.id = {existing_admin.id}, type = {type(existing_admin.id)}" + ) logger.debug(f"str(existing_admin.id) = {str(existing_admin.id)}") logger.debug(f"existing_admin.user_id = {existing_admin.user_id}") logger.info( @@ -258,25 +262,39 @@ async def websocket_auth(websocket, token: Optional[str] = None) -> Optional[Use # Try JWT token from query parameter first if token: - logger.info(f"Attempting WebSocket auth with query token (first 20 chars): {token[:20]}...") + logger.info( + f"Attempting WebSocket auth with query token (first 20 chars): {token[:20]}..." + ) try: user_db_gen = get_user_db() user_db = await user_db_gen.__anext__() user_manager = UserManager(user_db) user = await strategy.read_token(token, user_manager) if user and user.is_active: - logger.info(f"WebSocket auth successful for user {user.user_id} using query token.") + logger.info( + f"WebSocket auth successful for user {user.user_id} using query token." + ) return user else: - logger.warning(f"Token validated but user inactive or not found: user={user}") + logger.warning( + f"Token validated but user inactive or not found: user={user}" + ) except Exception as e: - logger.error(f"WebSocket auth with query token failed: {type(e).__name__}: {e}", exc_info=True) + logger.error( + f"WebSocket auth with query token failed: {type(e).__name__}: {e}", + exc_info=True, + ) # Try cookie authentication logger.debug("Attempting WebSocket auth with cookie.") try: cookie_header = next( - (v.decode() for k, v in websocket.headers.items() if k.lower() == b"cookie"), None + ( + v.decode() + for k, v in websocket.headers.items() + if k.lower() == b"cookie" + ), + None, ) if cookie_header: match = re.search(r"fastapiusersauth=([^;]+)", cookie_header) @@ -286,7 +304,9 @@ async def websocket_auth(websocket, token: Optional[str] = None) -> Optional[Use user_manager = UserManager(user_db) user = await strategy.read_token(match.group(1), user_manager) if user and user.is_active: - logger.info(f"WebSocket auth successful for user {user.user_id} using cookie.") + logger.info( + f"WebSocket auth successful for user {user.user_id} using cookie." 
+ ) return user except Exception as e: logger.warning(f"WebSocket auth with cookie failed: {e}") diff --git a/backends/advanced/src/advanced_omi_backend/chat_service.py b/backends/advanced/src/advanced_omi_backend/chat_service.py index 46b734a9..d887156e 100644 --- a/backends/advanced/src/advanced_omi_backend/chat_service.py +++ b/backends/advanced/src/advanced_omi_backend/chat_service.py @@ -40,7 +40,7 @@ class ChatMessage: """Represents a chat message.""" - + def __init__( self, message_id: str, @@ -91,7 +91,7 @@ def from_dict(cls, data: Dict) -> "ChatMessage": class ChatSession: """Represents a chat session.""" - + def __init__( self, session_id: str, @@ -152,11 +152,13 @@ async def _get_system_prompt(self) -> str: """ try: reg = get_models_registry() - if reg and hasattr(reg, 'chat'): + if reg and hasattr(reg, "chat"): chat_config = reg.chat - prompt = chat_config.get('system_prompt') + prompt = chat_config.get("system_prompt") if prompt: - logger.info(f"βœ… Loaded chat system prompt from config (length: {len(prompt)} chars)") + logger.info( + f"βœ… Loaded chat system prompt from config (length: {len(prompt)} chars)" + ) logger.debug(f"System prompt: {prompt[:100]}...") return prompt except Exception as e: @@ -193,9 +195,15 @@ async def initialize(self): self.messages_collection = self.db["chat_messages"] # Create indexes for better performance - await self.sessions_collection.create_index([("user_id", 1), ("updated_at", -1)]) - await self.messages_collection.create_index([("session_id", 1), ("timestamp", 1)]) - await self.messages_collection.create_index([("user_id", 1), ("timestamp", -1)]) + await self.sessions_collection.create_index( + [("user_id", 1), ("updated_at", -1)] + ) + await self.messages_collection.create_index( + [("session_id", 1), ("timestamp", 1)] + ) + await self.messages_collection.create_index( + [("user_id", 1), ("timestamp", -1)] + ) # Initialize LLM client and memory service self.llm_client = get_llm_client() @@ -214,23 +222,25 @@ async def create_session(self, user_id: str, title: str = None) -> ChatSession: await self.initialize() session = ChatSession( - session_id=str(uuid4()), - user_id=user_id, - title=title or "New Chat" + session_id=str(uuid4()), user_id=user_id, title=title or "New Chat" ) await self.sessions_collection.insert_one(session.to_dict()) logger.info(f"Created new chat session {session.session_id} for user {user_id}") return session - async def get_user_sessions(self, user_id: str, limit: int = 50) -> List[ChatSession]: + async def get_user_sessions( + self, user_id: str, limit: int = 50 + ) -> List[ChatSession]: """Get all chat sessions for a user.""" if not self._initialized: await self.initialize() - cursor = self.sessions_collection.find( - {"user_id": user_id} - ).sort("updated_at", -1).limit(limit) + cursor = ( + self.sessions_collection.find({"user_id": user_id}) + .sort("updated_at", -1) + .limit(limit) + ) sessions = [] async for doc in cursor: @@ -243,10 +253,9 @@ async def get_session(self, session_id: str, user_id: str) -> Optional[ChatSessi if not self._initialized: await self.initialize() - doc = await self.sessions_collection.find_one({ - "session_id": session_id, - "user_id": user_id - }) + doc = await self.sessions_collection.find_one( + {"session_id": session_id, "user_id": user_id} + ) if doc: return ChatSession.from_dict(doc) @@ -258,16 +267,14 @@ async def delete_session(self, session_id: str, user_id: str) -> bool: await self.initialize() # Delete all messages in the session - await self.messages_collection.delete_many({ 
- "session_id": session_id, - "user_id": user_id - }) + await self.messages_collection.delete_many( + {"session_id": session_id, "user_id": user_id} + ) # Delete the session - result = await self.sessions_collection.delete_one({ - "session_id": session_id, - "user_id": user_id - }) + result = await self.sessions_collection.delete_one( + {"session_id": session_id, "user_id": user_id} + ) success = result.deleted_count > 0 if success: @@ -281,10 +288,13 @@ async def get_session_messages( if not self._initialized: await self.initialize() - cursor = self.messages_collection.find({ - "session_id": session_id, - "user_id": user_id - }).sort("timestamp", 1).limit(limit) + cursor = ( + self.messages_collection.find( + {"session_id": session_id, "user_id": user_id} + ) + .sort("timestamp", 1) + .limit(limit) + ) messages = [] async for doc in cursor: @@ -299,10 +309,10 @@ async def add_message(self, message: ChatMessage) -> bool: try: await self.messages_collection.insert_one(message.to_dict()) - + # Update session timestamp and title if needed update_data = {"updated_at": message.timestamp} - + # Auto-generate title from first user message if session has default title if message.role == "user": session = await self.get_session(message.session_id, message.user_id) @@ -315,35 +325,43 @@ async def add_message(self, message: ChatMessage) -> bool: await self.sessions_collection.update_one( {"session_id": message.session_id, "user_id": message.user_id}, - {"$set": update_data} + {"$set": update_data}, ) - + return True except Exception as e: logger.error(f"Failed to add message to session {message.session_id}: {e}") return False - async def get_relevant_memories(self, query: str, user_id: str) -> List[MemoryEntry]: + async def get_relevant_memories( + self, query: str, user_id: str + ) -> List[MemoryEntry]: """Get relevant memories for the user's query.""" try: memories = await self.memory_service.search_memories( - query=query, - user_id=user_id, - limit=MAX_MEMORY_CONTEXT + query=query, user_id=user_id, limit=MAX_MEMORY_CONTEXT + ) + logger.info( + f"Retrieved {len(memories)} relevant memories for query: {query[:50]}..." 
) - logger.info(f"Retrieved {len(memories)} relevant memories for query: {query[:50]}...") return memories except Exception as e: logger.error(f"Failed to retrieve memories for user {user_id}: {e}") return [] async def format_conversation_context( - self, session_id: str, user_id: str, current_message: str, include_obsidian_memory: bool = False + self, + session_id: str, + user_id: str, + current_message: str, + include_obsidian_memory: bool = False, ) -> Tuple[str, List[str]]: """Format conversation context with memory integration.""" # Get recent conversation history - messages = await self.get_session_messages(session_id, user_id, MAX_CONVERSATION_HISTORY) - + messages = await self.get_session_messages( + session_id, user_id, MAX_CONVERSATION_HISTORY + ) + # Get relevant memories memories = await self.get_relevant_memories(current_message, user_id) memory_ids = [memory.id for memory in memories if memory.id] @@ -364,14 +382,18 @@ async def format_conversation_context( if include_obsidian_memory: try: obsidian_service = get_obsidian_service() - obsidian_result = await obsidian_service.search_obsidian(current_message) + obsidian_result = await obsidian_service.search_obsidian( + current_message + ) obsidian_context = obsidian_result["results"] if obsidian_context: context_parts.append("# Relevant Obsidian Notes:") for entry in obsidian_context: context_parts.append(entry) context_parts.append("") - logger.info(f"Added {len(obsidian_context)} Obsidian notes to context") + logger.info( + f"Added {len(obsidian_context)} Obsidian notes to context" + ) except ObsidianSearchError as exc: logger.error( "Failed to get Obsidian context (%s stage): %s", @@ -399,7 +421,11 @@ async def format_conversation_context( return context, memory_ids async def generate_response_stream( - self, session_id: str, user_id: str, message_content: str, include_obsidian_memory: bool = False + self, + session_id: str, + user_id: str, + message_content: str, + include_obsidian_memory: bool = False, ) -> AsyncGenerator[Dict, None]: """Generate streaming response with memory context.""" if not self._initialized: @@ -412,23 +438,23 @@ async def generate_response_stream( session_id=session_id, user_id=user_id, role="user", - content=message_content + content=message_content, ) await self.add_message(user_message) # Format context with memories context, memory_ids = await self.format_conversation_context( - session_id, user_id, message_content, include_obsidian_memory=include_obsidian_memory + session_id, + user_id, + message_content, + include_obsidian_memory=include_obsidian_memory, ) # Send memory context used yield { "type": "memory_context", - "data": { - "memory_ids": memory_ids, - "memory_count": len(memory_ids) - }, - "timestamp": time.time() + "data": {"memory_ids": memory_ids, "memory_count": len(memory_ids)}, + "timestamp": time.time(), } # Get system prompt from config @@ -438,8 +464,10 @@ async def generate_response_stream( full_prompt = f"{system_prompt}\n\n{context}" # Generate streaming response - logger.info(f"Generating response for session {session_id} with {len(memory_ids)} memories") - + logger.info( + f"Generating response for session {session_id} with {len(memory_ids)} memories" + ) + # Resolve chat operation temperature from config chat_temp = None registry = get_models_registry() @@ -457,16 +485,16 @@ async def generate_response_stream( # Simulate streaming by yielding chunks words = response_content.split() current_text = "" - + for i, word in enumerate(words): current_text += word + " " - + # Yield 
every few words to simulate streaming if i % 3 == 0 or i == len(words) - 1: yield { "type": "token", "data": current_text.strip(), - "timestamp": time.time() + "timestamp": time.time(), } await asyncio.sleep(0.05) # Small delay for realistic streaming @@ -477,7 +505,7 @@ async def generate_response_stream( user_id=user_id, role="assistant", content=response_content.strip(), - memories_used=memory_ids + memories_used=memory_ids, ) await self.add_message(assistant_message) @@ -486,20 +514,18 @@ async def generate_response_stream( "type": "complete", "data": { "message_id": assistant_message.message_id, - "memories_used": memory_ids + "memories_used": memory_ids, }, - "timestamp": time.time() + "timestamp": time.time(), } except Exception as e: logger.error(f"Error generating response for session {session_id}: {e}") - yield { - "type": "error", - "data": {"error": str(e)}, - "timestamp": time.time() - } + yield {"type": "error", "data": {"error": str(e)}, "timestamp": time.time()} - async def update_session_title(self, session_id: str, user_id: str, title: str) -> bool: + async def update_session_title( + self, session_id: str, user_id: str, title: str + ) -> bool: """Update a session's title.""" if not self._initialized: await self.initialize() @@ -507,7 +533,7 @@ async def update_session_title(self, session_id: str, user_id: str, title: str) try: result = await self.sessions_collection.update_one( {"session_id": session_id, "user_id": user_id}, - {"$set": {"title": title, "updated_at": datetime.utcnow()}} + {"$set": {"title": title, "updated_at": datetime.utcnow()}}, ) return result.modified_count > 0 except Exception as e: @@ -521,33 +547,38 @@ async def get_chat_statistics(self, user_id: str) -> Dict: try: # Count sessions - session_count = await self.sessions_collection.count_documents({"user_id": user_id}) - + session_count = await self.sessions_collection.count_documents( + {"user_id": user_id} + ) + # Count messages - message_count = await self.messages_collection.count_documents({"user_id": user_id}) - + message_count = await self.messages_collection.count_documents( + {"user_id": user_id} + ) + # Get most recent session latest_session = await self.sessions_collection.find_one( - {"user_id": user_id}, - sort=[("updated_at", -1)] + {"user_id": user_id}, sort=[("updated_at", -1)] ) - + return { "total_sessions": session_count, "total_messages": message_count, - "last_chat": latest_session["updated_at"] if latest_session else None + "last_chat": latest_session["updated_at"] if latest_session else None, } except Exception as e: logger.error(f"Failed to get chat statistics for user {user_id}: {e}") return {"total_sessions": 0, "total_messages": 0, "last_chat": None} - async def extract_memories_from_session(self, session_id: str, user_id: str) -> Tuple[bool, List[str], int]: + async def extract_memories_from_session( + self, session_id: str, user_id: str + ) -> Tuple[bool, List[str], int]: """Extract and store memories from a chat session. 
- + Args: session_id: ID of the chat session to extract memories from user_id: User ID for authorization and memory scoping - + Returns: Tuple of (success: bool, memory_ids: List[str], memory_count: int) """ @@ -556,20 +587,23 @@ async def extract_memories_from_session(self, session_id: str, user_id: str) -> try: # Verify session belongs to user - session = await self.sessions_collection.find_one({ - "session_id": session_id, - "user_id": user_id - }) - + session = await self.sessions_collection.find_one( + {"session_id": session_id, "user_id": user_id} + ) + if not session: logger.error(f"Session {session_id} not found for user {user_id}") return False, [], 0 # Get all messages from the session messages = await self.get_session_messages(session_id, user_id) - - if not messages or len(messages) < 2: # Need at least user + assistant message - logger.info(f"Not enough messages in session {session_id} for memory extraction") + + if ( + not messages or len(messages) < 2 + ): # Need at least user + assistant message + logger.info( + f"Not enough messages in session {session_id} for memory extraction" + ) return True, [], 0 # Format messages as a transcript @@ -577,12 +611,12 @@ async def extract_memories_from_session(self, session_id: str, user_id: str) -> for message in messages: role = "User" if message.role == "user" else "Assistant" transcript_parts.append(f"{role}: {message.content}") - + transcript = "\n".join(transcript_parts) - + # Get user email for memory service user_email = session.get("user_email", f"user_{user_id}") - + # Extract memories using the memory service success, memory_ids = await self.memory_service.add_memory( transcript=transcript, @@ -590,16 +624,20 @@ async def extract_memories_from_session(self, session_id: str, user_id: str) -> source_id=f"chat_{session_id}", user_id=user_id, user_email=user_email, - allow_update=True # Allow deduplication and updates + allow_update=True, # Allow deduplication and updates ) - + if success: - logger.info(f"βœ… Extracted {len(memory_ids)} memories from chat session {session_id}") + logger.info( + f"βœ… Extracted {len(memory_ids)} memories from chat session {session_id}" + ) return True, memory_ids, len(memory_ids) else: - logger.error(f"❌ Failed to extract memories from chat session {session_id}") + logger.error( + f"❌ Failed to extract memories from chat session {session_id}" + ) return False, [], 0 - + except Exception as e: logger.error(f"Failed to extract memories from session {session_id}: {e}") return False, [], 0 diff --git a/backends/advanced/src/advanced_omi_backend/client.py b/backends/advanced/src/advanced_omi_backend/client.py index 79ee2957..30eed2bc 100644 --- a/backends/advanced/src/advanced_omi_backend/client.py +++ b/backends/advanced/src/advanced_omi_backend/client.py @@ -18,7 +18,9 @@ audio_logger = logging.getLogger("audio_processing") # Configuration constants -NEW_CONVERSATION_TIMEOUT_MINUTES = float(os.getenv("NEW_CONVERSATION_TIMEOUT_MINUTES", "1.5")) +NEW_CONVERSATION_TIMEOUT_MINUTES = float( + os.getenv("NEW_CONVERSATION_TIMEOUT_MINUTES", "1.5") +) class ClientState: @@ -99,7 +101,9 @@ def record_speech_end(self, audio_uuid: str, timestamp: float): f"(duration: {duration:.3f}s)" ) else: - audio_logger.warning(f"Speech end recorded for {audio_uuid} but no start time found") + audio_logger.warning( + f"Speech end recorded for {audio_uuid} but no start time found" + ) def update_transcript_received(self): """Update timestamp when transcript is received (for timeout detection).""" diff --git 
a/backends/advanced/src/advanced_omi_backend/client_manager.py b/backends/advanced/src/advanced_omi_backend/client_manager.py index 68fd6ef8..2e90a4de 100644 --- a/backends/advanced/src/advanced_omi_backend/client_manager.py +++ b/backends/advanced/src/advanced_omi_backend/client_manager.py @@ -98,7 +98,9 @@ def get_client_count(self) -> int: """ return len(self._active_clients) - def create_client(self, client_id: str, chunk_dir, user_id: str, user_email: Optional[str] = None) -> "ClientState": + def create_client( + self, client_id: str, chunk_dir, user_id: str, user_email: Optional[str] = None + ) -> "ClientState": """ Atomically create and register a new client. @@ -327,12 +329,16 @@ def unregister_client_user_mapping(client_id: str): if client_id in _client_to_user_mapping: user_id = _client_to_user_mapping.pop(client_id) logger.info(f"❌ Unregistered active client {client_id} from user {user_id}") - logger.info(f"πŸ“Š Active client mappings: {len(_client_to_user_mapping)} remaining") + logger.info( + f"πŸ“Š Active client mappings: {len(_client_to_user_mapping)} remaining" + ) else: logger.warning(f"⚠️ Attempted to unregister non-existent client {client_id}") -async def track_client_user_relationship_async(client_id: str, user_id: str, ttl: int = 86400): +async def track_client_user_relationship_async( + client_id: str, user_id: str, ttl: int = 86400 +): """ Track that a client belongs to a user (async, writes to Redis for cross-container support). @@ -346,11 +352,15 @@ async def track_client_user_relationship_async(client_id: str, user_id: str, ttl if _redis_client: try: await _redis_client.setex(f"client:owner:{client_id}", ttl, user_id) - logger.debug(f"βœ… Tracked client {client_id} β†’ user {user_id} in Redis (TTL: {ttl}s)") + logger.debug( + f"βœ… Tracked client {client_id} β†’ user {user_id} in Redis (TTL: {ttl}s)" + ) except Exception as e: logger.warning(f"Failed to track client in Redis: {e}") else: - logger.debug(f"Tracked client {client_id} relationship to user {user_id} (in-memory only)") + logger.debug( + f"Tracked client {client_id} relationship to user {user_id} (in-memory only)" + ) def track_client_user_relationship(client_id: str, user_id: str): @@ -424,7 +434,9 @@ def get_user_clients_active(user_id: str) -> list[str]: if mapped_user_id == user_id ] - logger.debug(f"πŸ” Found {len(user_clients)} active clients for user {user_id}: {user_clients}") + logger.debug( + f"πŸ” Found {len(user_clients)} active clients for user {user_id}: {user_clients}" + ) return user_clients @@ -515,7 +527,9 @@ def generate_client_id(user: "User", device_name: Optional[str] = None) -> str: if device_name: # Sanitize device name: lowercase, alphanumeric + hyphens only, max 10 chars - sanitized_device = "".join(c for c in device_name.lower() if c.isalnum() or c == "-")[:10] + sanitized_device = "".join( + c for c in device_name.lower() if c.isalnum() or c == "-" + )[:10] base_client_id = f"{user_id_suffix}-{sanitized_device}" # Check for existing client IDs in database diff --git a/backends/advanced/src/advanced_omi_backend/clients/gdrive_audio_client.py b/backends/advanced/src/advanced_omi_backend/clients/gdrive_audio_client.py index 9d93d884..39b63f87 100644 --- a/backends/advanced/src/advanced_omi_backend/clients/gdrive_audio_client.py +++ b/backends/advanced/src/advanced_omi_backend/clients/gdrive_audio_client.py @@ -7,6 +7,7 @@ _drive_client_cache = None + def get_google_drive_client(): """Singleton Google Drive client.""" global _drive_client_cache @@ -22,8 +23,7 @@ def 
get_google_drive_client(): ) creds = Credentials.from_service_account_file( - config.gdrive_credentials_path, - scopes=config.gdrive_scopes + config.gdrive_credentials_path, scopes=config.gdrive_scopes ) _drive_client_cache = build("drive", "v3", credentials=creds) diff --git a/backends/advanced/src/advanced_omi_backend/config.py b/backends/advanced/src/advanced_omi_backend/config.py index 4286492a..4d96eb6f 100644 --- a/backends/advanced/src/advanced_omi_backend/config.py +++ b/backends/advanced/src/advanced_omi_backend/config.py @@ -19,9 +19,7 @@ load_config, ) from advanced_omi_backend.config_loader import reload_config as reload_omegaconf_config -from advanced_omi_backend.config_loader import ( - save_config_section, -) +from advanced_omi_backend.config_loader import save_config_section logger = logging.getLogger(__name__) @@ -34,6 +32,7 @@ # Configuration Functions (OmegaConf-based) # ============================================================================ + def get_config_yml_path() -> Path: """ Get path to config.yml file. @@ -43,6 +42,7 @@ def get_config_yml_path() -> Path: """ return get_config_dir() / "config.yml" + def get_config(force_reload: bool = False) -> dict: """ Get merged configuration using OmegaConf. @@ -68,6 +68,7 @@ def reload_config(): # Diarization Settings (OmegaConf-based) # ============================================================================ + def get_diarization_settings() -> dict: """ Get diarization settings using OmegaConf. @@ -75,7 +76,7 @@ def get_diarization_settings() -> dict: Returns: Dict with diarization configuration (resolved from YAML + env vars) """ - cfg = get_backend_config('diarization') + cfg = get_backend_config("diarization") return OmegaConf.to_container(cfg, resolve=True) @@ -89,16 +90,18 @@ def save_diarization_settings(settings: dict) -> bool: Returns: True if saved successfully, False otherwise """ - return save_config_section('backend.diarization', settings) + return save_config_section("backend.diarization", settings) # ============================================================================ # Cleanup Settings (OmegaConf-based) # ============================================================================ + @dataclass class CleanupSettings: """Cleanup configuration for soft-deleted conversations.""" + auto_cleanup_enabled: bool = False retention_days: int = 30 @@ -110,7 +113,7 @@ def get_cleanup_settings() -> dict: Returns: Dict with auto_cleanup_enabled and retention_days """ - cfg = get_backend_config('cleanup') + cfg = get_backend_config("cleanup") return OmegaConf.to_container(cfg, resolve=True) @@ -125,13 +128,15 @@ def save_cleanup_settings(settings: CleanupSettings) -> bool: True if saved successfully, False otherwise """ from dataclasses import asdict - return save_config_section('backend.cleanup', asdict(settings)) + + return save_config_section("backend.cleanup", asdict(settings)) # ============================================================================ # Speech Detection Settings (OmegaConf-based) # ============================================================================ + def get_speech_detection_settings() -> dict: """ Get speech detection settings using OmegaConf. 
@@ -139,7 +144,7 @@ def get_speech_detection_settings() -> dict: Returns: Dict with min_words, min_confidence, min_duration """ - cfg = get_backend_config('speech_detection') + cfg = get_backend_config("speech_detection") return OmegaConf.to_container(cfg, resolve=True) @@ -147,6 +152,7 @@ def get_speech_detection_settings() -> dict: # Conversation Stop Settings (OmegaConf-based) # ============================================================================ + def get_conversation_stop_settings() -> dict: """ Get conversation stop settings using OmegaConf. @@ -154,12 +160,14 @@ def get_conversation_stop_settings() -> dict: Returns: Dict with transcription_buffer_seconds, speech_inactivity_threshold """ - cfg = get_backend_config('conversation_stop') + cfg = get_backend_config("conversation_stop") settings = OmegaConf.to_container(cfg, resolve=True) # Add min_word_confidence from speech_detection for backward compatibility - speech_cfg = get_backend_config('speech_detection') - settings['min_word_confidence'] = OmegaConf.to_container(speech_cfg, resolve=True).get('min_confidence', 0.7) + speech_cfg = get_backend_config("speech_detection") + settings["min_word_confidence"] = OmegaConf.to_container( + speech_cfg, resolve=True + ).get("min_confidence", 0.7) return settings @@ -168,6 +176,7 @@ def get_conversation_stop_settings() -> dict: # Audio Storage Settings (OmegaConf-based) # ============================================================================ + def get_audio_storage_settings() -> dict: """ Get audio storage settings using OmegaConf. @@ -175,7 +184,7 @@ def get_audio_storage_settings() -> dict: Returns: Dict with audio_base_path, audio_chunks_path """ - cfg = get_backend_config('audio_storage') + cfg = get_backend_config("audio_storage") return OmegaConf.to_container(cfg, resolve=True) @@ -183,6 +192,7 @@ def get_audio_storage_settings() -> dict: # Transcription Job Timeout (OmegaConf-based) # ============================================================================ + def get_transcription_job_timeout() -> int: """ Get transcription job timeout in seconds from config. @@ -190,15 +200,16 @@ def get_transcription_job_timeout() -> int: Returns: Job timeout in seconds (default 900 = 15 minutes) """ - cfg = get_backend_config('transcription') + cfg = get_backend_config("transcription") settings = OmegaConf.to_container(cfg, resolve=True) if cfg else {} - return int(settings.get('job_timeout_seconds', 900)) + return int(settings.get("job_timeout_seconds", 900)) # ============================================================================ # Miscellaneous Settings (OmegaConf-based) # ============================================================================ + def get_misc_settings() -> dict: """ Get miscellaneous configuration settings using OmegaConf. 
@@ -207,23 +218,37 @@ def get_misc_settings() -> dict: Dict with always_persist_enabled and use_provider_segments """ # Get audio settings for always_persist_enabled - audio_cfg = get_backend_config('audio') - audio_settings = OmegaConf.to_container(audio_cfg, resolve=True) if audio_cfg else {} + audio_cfg = get_backend_config("audio") + audio_settings = ( + OmegaConf.to_container(audio_cfg, resolve=True) if audio_cfg else {} + ) # Get transcription settings for use_provider_segments - transcription_cfg = get_backend_config('transcription') - transcription_settings = OmegaConf.to_container(transcription_cfg, resolve=True) if transcription_cfg else {} + transcription_cfg = get_backend_config("transcription") + transcription_settings = ( + OmegaConf.to_container(transcription_cfg, resolve=True) + if transcription_cfg + else {} + ) # Get speaker recognition settings for per_segment_speaker_id - speaker_cfg = get_backend_config('speaker_recognition') - speaker_settings = OmegaConf.to_container(speaker_cfg, resolve=True) if speaker_cfg else {} + speaker_cfg = get_backend_config("speaker_recognition") + speaker_settings = ( + OmegaConf.to_container(speaker_cfg, resolve=True) if speaker_cfg else {} + ) return { - 'always_persist_enabled': audio_settings.get('always_persist_enabled', False), - 'use_provider_segments': transcription_settings.get('use_provider_segments', False), - 'per_segment_speaker_id': speaker_settings.get('per_segment_speaker_id', False), - 'transcription_job_timeout_seconds': int(transcription_settings.get('job_timeout_seconds', 900)), - 'always_batch_retranscribe': transcription_settings.get('always_batch_retranscribe', False), + "always_persist_enabled": audio_settings.get("always_persist_enabled", False), + "use_provider_segments": transcription_settings.get( + "use_provider_segments", False + ), + "per_segment_speaker_id": speaker_settings.get("per_segment_speaker_id", False), + "transcription_job_timeout_seconds": int( + transcription_settings.get("job_timeout_seconds", 900) + ), + "always_batch_retranscribe": transcription_settings.get( + "always_batch_retranscribe", False + ), } @@ -240,33 +265,41 @@ def save_misc_settings(settings: dict) -> bool: success = True # Save audio settings if always_persist_enabled is provided - if 'always_persist_enabled' in settings: - audio_settings = {'always_persist_enabled': settings['always_persist_enabled']} - if not save_config_section('backend.audio', audio_settings): + if "always_persist_enabled" in settings: + audio_settings = {"always_persist_enabled": settings["always_persist_enabled"]} + if not save_config_section("backend.audio", audio_settings): success = False # Save transcription settings if use_provider_segments is provided - if 'use_provider_segments' in settings: - transcription_settings = {'use_provider_segments': settings['use_provider_segments']} - if not save_config_section('backend.transcription', transcription_settings): + if "use_provider_segments" in settings: + transcription_settings = { + "use_provider_segments": settings["use_provider_segments"] + } + if not save_config_section("backend.transcription", transcription_settings): success = False # Save speaker recognition settings if per_segment_speaker_id is provided - if 'per_segment_speaker_id' in settings: - speaker_settings = {'per_segment_speaker_id': settings['per_segment_speaker_id']} - if not save_config_section('backend.speaker_recognition', speaker_settings): + if "per_segment_speaker_id" in settings: + speaker_settings = { + "per_segment_speaker_id": 
settings["per_segment_speaker_id"] + } + if not save_config_section("backend.speaker_recognition", speaker_settings): success = False # Save transcription job timeout if provided - if 'transcription_job_timeout_seconds' in settings: - timeout_settings = {'job_timeout_seconds': settings['transcription_job_timeout_seconds']} - if not save_config_section('backend.transcription', timeout_settings): + if "transcription_job_timeout_seconds" in settings: + timeout_settings = { + "job_timeout_seconds": settings["transcription_job_timeout_seconds"] + } + if not save_config_section("backend.transcription", timeout_settings): success = False # Save always_batch_retranscribe if provided - if 'always_batch_retranscribe' in settings: - batch_settings = {'always_batch_retranscribe': settings['always_batch_retranscribe']} - if not save_config_section('backend.transcription', batch_settings): + if "always_batch_retranscribe" in settings: + batch_settings = { + "always_batch_retranscribe": settings["always_batch_retranscribe"] + } + if not save_config_section("backend.transcription", batch_settings): success = False - return success \ No newline at end of file + return success diff --git a/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py index ba434229..f94ff320 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py @@ -12,6 +12,9 @@ import time import uuid +from fastapi import UploadFile +from fastapi.responses import JSONResponse + from advanced_omi_backend.config import get_transcription_job_timeout from advanced_omi_backend.controllers.queue_controller import ( JOB_RESULT_TTL, @@ -29,11 +32,7 @@ convert_any_to_wav, validate_and_prepare_audio, ) -from advanced_omi_backend.workers.transcription_jobs import ( - transcribe_full_audio_job, -) -from fastapi import UploadFile -from fastapi.responses import JSONResponse +from advanced_omi_backend.workers.transcription_jobs import transcribe_full_audio_job logger = logging.getLogger(__name__) audio_logger = logging.getLogger("audio_processing") @@ -50,7 +49,7 @@ async def upload_and_process_audio_files( user: User, files: list[UploadFile], device_name: str = "upload", - source: str = "upload" + source: str = "upload", ) -> dict: """ Upload audio files and process them directly. @@ -80,11 +79,13 @@ async def upload_and_process_audio_files( _, ext = os.path.splitext(filename.lower()) if not ext or ext not in SUPPORTED_AUDIO_EXTENSIONS: supported = ", ".join(sorted(SUPPORTED_AUDIO_EXTENSIONS)) - processed_files.append({ - "filename": filename, - "status": "error", - "error": f"Unsupported format '{ext}'. Supported: {supported}", - }) + processed_files.append( + { + "filename": filename, + "status": "error", + "error": f"Unsupported format '{ext}'. Supported: {supported}", + } + ) continue is_video_source = ext in VIDEO_EXTENSIONS @@ -101,37 +102,47 @@ async def upload_and_process_audio_files( try: content = await convert_any_to_wav(content, ext) except AudioValidationError as e: - processed_files.append({ - "filename": filename, - "status": "error", - "error": str(e), - }) + processed_files.append( + { + "filename": filename, + "status": "error", + "error": str(e), + } + ) continue # Track external source for deduplication (Google Drive, etc.) 
external_source_id = None external_source_type = None if source == "gdrive": - external_source_id = getattr(file, "file_id", None) or getattr(file, "audio_uuid", None) + external_source_id = getattr(file, "file_id", None) or getattr( + file, "audio_uuid", None + ) external_source_type = "gdrive" if not external_source_id: - audio_logger.warning(f"Missing file_id for gdrive file: {filename}") + audio_logger.warning( + f"Missing file_id for gdrive file: {filename}" + ) timestamp = int(time.time() * 1000) # Validate and prepare audio (read format from WAV file) try: - audio_data, sample_rate, sample_width, channels, duration = await validate_and_prepare_audio( - audio_data=content, - expected_sample_rate=16000, # Expecting 16kHz - convert_to_mono=True, # Convert stereo to mono - auto_resample=True # Auto-resample if sample rate doesn't match + audio_data, sample_rate, sample_width, channels, duration = ( + await validate_and_prepare_audio( + audio_data=content, + expected_sample_rate=16000, # Expecting 16kHz + convert_to_mono=True, # Convert stereo to mono + auto_resample=True, # Auto-resample if sample rate doesn't match + ) ) except AudioValidationError as e: - processed_files.append({ - "filename": filename, - "status": "error", - "error": str(e), - }) + processed_files.append( + { + "filename": filename, + "status": "error", + "error": str(e), + } + ) continue audio_logger.info( @@ -139,7 +150,11 @@ async def upload_and_process_audio_files( ) # Generate title from filename - title = filename.rsplit('.', 1)[0][:50] if filename != "unknown" else "Uploaded Audio" + title = ( + filename.rsplit(".", 1)[0][:50] + if filename != "unknown" + else "Uploaded Audio" + ) conversation = create_conversation( user_id=user.user_id, @@ -150,9 +165,13 @@ async def upload_and_process_audio_files( external_source_type=external_source_type, ) await conversation.insert() - conversation_id = conversation.conversation_id # Get the auto-generated ID + conversation_id = ( + conversation.conversation_id + ) # Get the auto-generated ID - audio_logger.info(f"πŸ“ Created conversation {conversation_id} for uploaded file") + audio_logger.info( + f"πŸ“ Created conversation {conversation_id} for uploaded file" + ) # Convert audio directly to MongoDB chunks try: @@ -170,24 +189,28 @@ async def upload_and_process_audio_files( except ValueError as val_error: # Handle validation errors (e.g., file too long) audio_logger.error(f"Audio validation failed: {val_error}") - processed_files.append({ - "filename": filename, - "status": "error", - "error": str(val_error), - }) + processed_files.append( + { + "filename": filename, + "status": "error", + "error": str(val_error), + } + ) # Delete the conversation since it won't have audio chunks await conversation.delete() continue except Exception as chunk_error: audio_logger.error( f"Failed to convert uploaded file to chunks: {chunk_error}", - exc_info=True + exc_info=True, + ) + processed_files.append( + { + "filename": filename, + "status": "error", + "error": f"Audio conversion failed: {str(chunk_error)}", + } ) - processed_files.append({ - "filename": filename, - "status": "error", - "error": f"Audio conversion failed: {str(chunk_error)}", - }) # Delete the conversation since it won't have audio chunks await conversation.delete() continue @@ -208,9 +231,14 @@ async def upload_and_process_audio_files( result_ttl=JOB_RESULT_TTL, job_id=transcribe_job_id, description=f"Transcribe uploaded file {conversation_id[:8]}", - meta={'conversation_id': conversation_id, 'client_id': 
client_id} + meta={ + "conversation_id": conversation_id, + "client_id": client_id, + }, + ) + audio_logger.info( + f"πŸ“₯ Enqueued transcription job {transcription_job.id} for uploaded file" ) - audio_logger.info(f"πŸ“₯ Enqueued transcription job {transcription_job.id} for uploaded file") else: audio_logger.warning( f"⚠️ Skipping transcription for conversation {conversation_id}: " @@ -223,16 +251,18 @@ async def upload_and_process_audio_files( user_id=user.user_id, transcript_version_id=version_id, # Pass the version_id from transcription job depends_on_job=transcription_job, # Wait for transcription to complete (or None) - client_id=client_id # Pass client_id for UI tracking + client_id=client_id, # Pass client_id for UI tracking ) file_result = { "filename": filename, "status": "started", # RQ standard: job has been enqueued "conversation_id": conversation_id, - "transcript_job_id": transcription_job.id if transcription_job else None, - "speaker_job_id": job_ids['speaker_recognition'], - "memory_job_id": job_ids['memory'], + "transcript_job_id": ( + transcription_job.id if transcription_job else None + ), + "speaker_job_id": job_ids["speaker_recognition"], + "memory_job_id": job_ids["memory"], "duration_seconds": round(duration, 2), } if is_video_source: @@ -243,10 +273,10 @@ async def upload_and_process_audio_files( job_chain = [] if transcription_job: job_chain.append(transcription_job.id) - if job_ids['speaker_recognition']: - job_chain.append(job_ids['speaker_recognition']) - if job_ids['memory']: - job_chain.append(job_ids['memory']) + if job_ids["speaker_recognition"]: + job_chain.append(job_ids["speaker_recognition"]) + if job_ids["memory"]: + job_chain.append(job_ids["memory"]) audio_logger.info( f"βœ… Processed {filename} β†’ conversation {conversation_id}, " @@ -256,19 +286,23 @@ async def upload_and_process_audio_files( except (OSError, IOError) as e: # File I/O errors during audio processing audio_logger.exception(f"File I/O error processing {filename}") - processed_files.append({ - "filename": filename, - "status": "error", - "error": str(e), - }) + processed_files.append( + { + "filename": filename, + "status": "error", + "error": str(e), + } + ) except Exception as e: # Unexpected errors during file processing audio_logger.exception(f"Unexpected error processing file {filename}") - processed_files.append({ - "filename": filename, - "status": "error", - "error": str(e), - }) + processed_files.append( + { + "filename": filename, + "status": "error", + "error": str(e), + } + ) successful_files = [f for f in processed_files if f.get("status") == "started"] failed_files = [f for f in processed_files if f.get("status") == "error"] @@ -291,7 +325,9 @@ async def upload_and_process_audio_files( return JSONResponse(status_code=400, content=response_body) elif len(failed_files) > 0: # SOME files failed (partial success) - return 207 Multi-Status - audio_logger.warning(f"Partial upload: {len(successful_files)} succeeded, {len(failed_files)} failed") + audio_logger.warning( + f"Partial upload: {len(successful_files)} succeeded, {len(failed_files)} failed" + ) return JSONResponse(status_code=207, content=response_body) else: # All files succeeded - return 200 OK diff --git a/backends/advanced/src/advanced_omi_backend/controllers/client_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/client_controller.py index b400d3ed..d603e7ea 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/client_controller.py +++ 
b/backends/advanced/src/advanced_omi_backend/controllers/client_controller.py @@ -6,10 +6,7 @@ from fastapi.responses import JSONResponse -from advanced_omi_backend.client_manager import ( - ClientManager, - get_user_clients_active, -) +from advanced_omi_backend.client_manager import ClientManager, get_user_clients_active from advanced_omi_backend.users import User logger = logging.getLogger(__name__) @@ -37,7 +34,9 @@ async def get_active_clients(user: User, client_manager: ClientManager): # Filter to only the user's clients user_clients = [ - client for client in all_clients if client["client_id"] in user_active_clients + client + for client in all_clients + if client["client_id"] in user_active_clients ] return { diff --git a/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py index d2cfc7df..316c3299 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py @@ -73,6 +73,7 @@ def get_job_status_from_rq(job: Job) -> str: return status_str + # Queue name constants TRANSCRIPTION_QUEUE = "transcription" MEMORY_QUEUE = "memory" @@ -86,9 +87,13 @@ def get_job_status_from_rq(job: Job) -> str: JOB_RESULT_TTL = int(os.getenv("RQ_RESULT_TTL", 86400)) # 24 hour default # Create queues with custom result TTL -transcription_queue = Queue(TRANSCRIPTION_QUEUE, connection=redis_conn, default_timeout=86400) # 24 hours for streaming jobs +transcription_queue = Queue( + TRANSCRIPTION_QUEUE, connection=redis_conn, default_timeout=86400 +) # 24 hours for streaming jobs memory_queue = Queue(MEMORY_QUEUE, connection=redis_conn, default_timeout=300) -audio_queue = Queue(AUDIO_QUEUE, connection=redis_conn, default_timeout=86400) # 24 hours for all-day sessions +audio_queue = Queue( + AUDIO_QUEUE, connection=redis_conn, default_timeout=86400 +) # 24 hours for all-day sessions default_queue = Queue(DEFAULT_QUEUE, connection=redis_conn, default_timeout=300) @@ -123,7 +128,14 @@ def get_job_stats() -> Dict[str, Any]: canceled_jobs += len(queue.canceled_job_registry) deferred_jobs += len(queue.deferred_job_registry) - total_jobs = queued_jobs + started_jobs + finished_jobs + failed_jobs + canceled_jobs + deferred_jobs + total_jobs = ( + queued_jobs + + started_jobs + + finished_jobs + + failed_jobs + + canceled_jobs + + deferred_jobs + ) return { "total_jobs": total_jobs, @@ -133,7 +145,7 @@ def get_job_stats() -> Dict[str, Any]: "failed_jobs": failed_jobs, "canceled_jobs": canceled_jobs, "deferred_jobs": deferred_jobs, - "timestamp": datetime.utcnow().isoformat() + "timestamp": datetime.utcnow().isoformat(), } @@ -142,7 +154,7 @@ def get_jobs( offset: int = 0, queue_name: str = None, job_type: str = None, - client_id: str = None + client_id: str = None, ) -> Dict[str, Any]: """ Get jobs from a specific queue or all queues with optional filtering. 
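# A minimal, self-contained sketch of the RQ queue/registry layout that the
# queue_controller changes above lean on.  RQ keeps one registry per job state,
# so per-queue stats are just the queue's own count of waiting jobs plus the
# lengths of those registries.  The REDIS_URL default and the 300s timeout here
# are illustrative assumptions, not the project's actual configuration.
import os

from redis import Redis
from rq import Queue

redis_conn = Redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
queues = {
    name: Queue(name, connection=redis_conn, default_timeout=300)
    for name in ("transcription", "memory", "audio", "default")
}


def queue_stats(queue: Queue) -> dict:
    # Waiting jobs sit on the queue itself; every other state has its own registry.
    return {
        "queued": queue.count,
        "started": len(queue.started_job_registry),
        "finished": len(queue.finished_job_registry),
        "failed": len(queue.failed_job_registry),
        "canceled": len(queue.canceled_job_registry),
        "deferred": len(queue.deferred_job_registry),
    }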
@@ -157,9 +169,13 @@ def get_jobs( Returns: Dict with jobs list and pagination metadata matching frontend expectations """ - logger.info(f"πŸ” DEBUG get_jobs: Filtering - queue_name={queue_name}, job_type={job_type}, client_id={client_id}") + logger.info( + f"πŸ” DEBUG get_jobs: Filtering - queue_name={queue_name}, job_type={job_type}, client_id={client_id}" + ) all_jobs = [] - seen_job_ids = set() # Track which job IDs we've already processed to avoid duplicates + seen_job_ids = ( + set() + ) # Track which job IDs we've already processed to avoid duplicates queues_to_check = [queue_name] if queue_name else QUEUE_NAMES logger.info(f"πŸ” DEBUG get_jobs: Checking queues: {queues_to_check}") @@ -170,10 +186,19 @@ def get_jobs( # Collect jobs from all registries (using RQ standard status names) registries = [ (queue.job_ids, "queued"), - (queue.started_job_registry.get_job_ids(), "started"), # RQ standard, not "processing" - (queue.finished_job_registry.get_job_ids(), "finished"), # RQ standard, not "completed" + ( + queue.started_job_registry.get_job_ids(), + "started", + ), # RQ standard, not "processing" + ( + queue.finished_job_registry.get_job_ids(), + "finished", + ), # RQ standard, not "completed" (queue.failed_job_registry.get_job_ids(), "failed"), - (queue.deferred_job_registry.get_job_ids(), "deferred"), # Jobs waiting for dependencies + ( + queue.deferred_job_registry.get_job_ids(), + "deferred", + ), # Jobs waiting for dependencies ] for job_ids, status in registries: @@ -190,46 +215,76 @@ def get_jobs( user_id = job.kwargs.get("user_id", "") if job.kwargs else "" # Extract just the function name (e.g., "listen_for_speech_job" from "module.listen_for_speech_job") - func_name = job.func_name.split('.')[-1] if job.func_name else "unknown" + func_name = ( + job.func_name.split(".")[-1] if job.func_name else "unknown" + ) # Debug: Log job details before filtering - logger.debug(f"πŸ” DEBUG get_jobs: Job {job_id} - func_name={func_name}, full_func_name={job.func_name}, meta_client_id={job.meta.get('client_id', '') if job.meta else ''}, status={status}") + logger.debug( + f"πŸ” DEBUG get_jobs: Job {job_id} - func_name={func_name}, full_func_name={job.func_name}, meta_client_id={job.meta.get('client_id', '') if job.meta else ''}, status={status}" + ) # Apply job_type filter if job_type and job_type not in func_name: - logger.debug(f"πŸ” DEBUG get_jobs: Filtered out {job_id} - job_type '{job_type}' not in func_name '{func_name}'") + logger.debug( + f"πŸ” DEBUG get_jobs: Filtered out {job_id} - job_type '{job_type}' not in func_name '{func_name}'" + ) continue # Apply client_id filter (partial match in meta) if client_id: - job_client_id = job.meta.get("client_id", "") if job.meta else "" + job_client_id = ( + job.meta.get("client_id", "") if job.meta else "" + ) if client_id not in job_client_id: - logger.debug(f"πŸ” DEBUG get_jobs: Filtered out {job_id} - client_id '{client_id}' not in job_client_id '{job_client_id}'") + logger.debug( + f"πŸ” DEBUG get_jobs: Filtered out {job_id} - client_id '{client_id}' not in job_client_id '{job_client_id}'" + ) continue - logger.debug(f"πŸ” DEBUG get_jobs: Including job {job_id} in results") - - all_jobs.append({ - "job_id": job.id, - "job_type": func_name, - "user_id": user_id, - "status": status, - "priority": "normal", # RQ doesn't track priority in metadata - "data": { - "description": job.description or "", - "queue": qname, - }, - "result": job.result if hasattr(job, 'result') else None, - "meta": job.meta if job.meta else {}, # 
Include job metadata - "error_message": str(job.exc_info) if job.exc_info else None, - "created_at": job.created_at.isoformat() if job.created_at else None, - "started_at": job.started_at.isoformat() if job.started_at else None, - "completed_at": job.ended_at.isoformat() if job.ended_at else None, - "retry_count": job.retries_left if hasattr(job, 'retries_left') else 0, - "max_retries": 3, # Default max retries - "progress_percent": (job.meta or {}).get("batch_progress", {}).get("percent", 0), - "progress_message": (job.meta or {}).get("batch_progress", {}).get("message", ""), - }) + logger.debug( + f"πŸ” DEBUG get_jobs: Including job {job_id} in results" + ) + + all_jobs.append( + { + "job_id": job.id, + "job_type": func_name, + "user_id": user_id, + "status": status, + "priority": "normal", # RQ doesn't track priority in metadata + "data": { + "description": job.description or "", + "queue": qname, + }, + "result": job.result if hasattr(job, "result") else None, + "meta": ( + job.meta if job.meta else {} + ), # Include job metadata + "error_message": ( + str(job.exc_info) if job.exc_info else None + ), + "created_at": ( + job.created_at.isoformat() if job.created_at else None + ), + "started_at": ( + job.started_at.isoformat() if job.started_at else None + ), + "completed_at": ( + job.ended_at.isoformat() if job.ended_at else None + ), + "retry_count": ( + job.retries_left if hasattr(job, "retries_left") else 0 + ), + "max_retries": 3, # Default max retries + "progress_percent": (job.meta or {}) + .get("batch_progress", {}) + .get("percent", 0), + "progress_message": (job.meta or {}) + .get("batch_progress", {}) + .get("message", ""), + } + ) except Exception as e: logger.error(f"Error fetching job {job_id}: {e}") @@ -238,10 +293,12 @@ def get_jobs( # Paginate total_jobs = len(all_jobs) - paginated_jobs = all_jobs[offset:offset + limit] + paginated_jobs = all_jobs[offset : offset + limit] has_more = (offset + limit) < total_jobs - logger.info(f"πŸ” DEBUG get_jobs: Found {total_jobs} matching jobs (returning {len(paginated_jobs)} after pagination)") + logger.info( + f"πŸ” DEBUG get_jobs: Found {total_jobs} matching jobs (returning {len(paginated_jobs)} after pagination)" + ) return { "jobs": paginated_jobs, @@ -250,7 +307,7 @@ def get_jobs( "limit": limit, "offset": offset, "has_more": has_more, - } + }, } @@ -281,7 +338,7 @@ def is_job_complete(job): return False # Check dependent jobs - for dep_id in (job.dependent_ids or []): + for dep_id in job.dependent_ids or []: try: dep_job = Job.fetch(dep_id, connection=redis_conn) if not is_job_complete(dep_job): @@ -310,7 +367,7 @@ def is_job_complete(job): job = Job.fetch(job_id, connection=redis_conn) # Only check jobs with client_id in meta - if job.meta and job.meta.get('client_id') == client_id: + if job.meta and job.meta.get("client_id") == client_id: if not is_job_complete(job): return False except Exception as e: @@ -320,9 +377,7 @@ def is_job_complete(job): def start_streaming_jobs( - session_id: str, - user_id: str, - client_id: str + session_id: str, user_id: str, client_id: str ) -> Dict[str, str]: """ Enqueue jobs for streaming audio session (initial session setup). 
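# A hedged sketch of the recursive completeness check that is_job_complete
# performs above: a chain is only done once the job and every job fetched via
# its dependent_ids have reached a terminal state.  The exact terminal-state
# test shown here (finished / failed / canceled) is an assumption for the
# sketch; the controller's own status mapping is not fully visible in the diff.
from redis import Redis
from rq.job import Job


def chain_is_complete(job: Job, connection: Redis) -> bool:
    # Terminal RQ states only; queued, started, and deferred jobs keep the chain open.
    if not (job.is_finished or job.is_failed or job.is_canceled):
        return False
    for dep_id in job.dependent_ids or []:
        try:
            dependent = Job.fetch(dep_id, connection=connection)
        except Exception:
            continue  # the dependent job has already expired from Redis
        if not chain_is_complete(dependent, connection):
            return False
    return True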
@@ -351,7 +406,7 @@ def start_streaming_jobs( # Read always_persist from global config NOW (backend process has fresh config) misc_settings = get_misc_settings() - always_persist = misc_settings.get('always_persist_enabled', False) + always_persist = misc_settings.get("always_persist_enabled", False) # Enqueue speech detection job speech_job = transcription_queue.enqueue( @@ -365,7 +420,7 @@ def start_streaming_jobs( failure_ttl=86400, # Cleanup failed jobs after 24h job_id=f"speech-detect_{session_id[:12]}", description=f"Listening for speech...", - meta={'client_id': client_id, 'session_level': True} + meta={"client_id": client_id, "session_level": True}, ) # Log job enqueue with TTL information for debugging actual_ttl = redis_conn.ttl(f"rq:job:{speech_job.id}") @@ -379,7 +434,9 @@ def start_streaming_jobs( # Store job ID for cleanup (keyed by client_id for easy WebSocket cleanup) try: - redis_conn.set(f"speech_detection_job:{client_id}", speech_job.id, ex=86400) # 24 hour TTL + redis_conn.set( + f"speech_detection_job:{client_id}", speech_job.id, ex=86400 + ) # 24 hour TTL logger.info(f"πŸ“Œ Stored speech detection job ID for client {client_id}") except Exception as e: logger.warning(f"⚠️ Failed to store job ID for {client_id}: {e}") @@ -399,7 +456,10 @@ def start_streaming_jobs( failure_ttl=86400, # Cleanup failed jobs after 24h job_id=f"audio-persist_{session_id[:12]}", description=f"Audio persistence for session {session_id[:12]}", - meta={'client_id': client_id, 'session_level': True} # Mark as session-level job + meta={ + "client_id": client_id, + "session_level": True, + }, # Mark as session-level job ) # Log job enqueue with TTL information for debugging actual_ttl = redis_conn.ttl(f"rq:job:{audio_job.id}") @@ -411,19 +471,16 @@ def start_streaming_jobs( f"queue_length={audio_queue.count}, client_id={client_id}" ) - return { - 'speech_detection': speech_job.id, - 'audio_persistence': audio_job.id - } + return {"speech_detection": speech_job.id, "audio_persistence": audio_job.id} def start_post_conversation_jobs( conversation_id: str, user_id: str, transcript_version_id: Optional[str] = None, - depends_on_job = None, + depends_on_job=None, client_id: Optional[str] = None, - end_reason: str = "file_upload" + end_reason: str = "file_upload", ) -> Dict[str, str]: """ Start post-conversation processing jobs after conversation is created. 
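# A minimal sketch of the session-job enqueue pattern used in
# start_streaming_jobs above: a long job_timeout for streaming work, TTLs so
# Redis eventually cleans up results and failures, client metadata for UI
# tracking, and the job id stashed under a per-client key so WebSocket teardown
# can find the job later.  The worker function and the key prefix here are
# placeholders, not the project's real names.
import os

from redis import Redis
from rq import Queue


def listen_for_speech(session_id: str) -> None:
    """Placeholder worker body; the real job reads the audio stream and detects speech."""


def start_speech_detection(session_id: str, client_id: str) -> str:
    redis_conn = Redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
    queue = Queue("transcription", connection=redis_conn, default_timeout=86400)
    job = queue.enqueue(
        listen_for_speech,
        session_id,
        job_timeout=86400,  # streaming sessions may run all day
        result_ttl=86400,
        failure_ttl=86400,
        job_id=f"speech-detect_{session_id[:12]}",
        description="Listening for speech...",
        meta={"client_id": client_id, "session_level": True},
    )
    # Keyed by client_id so the WebSocket close handler can look the job up and cancel it.
    redis_conn.set(f"speech_detection_job:{client_id}", job.id, ex=86400)
    return job.id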
@@ -458,21 +515,27 @@ def start_post_conversation_jobs( version_id = transcript_version_id or str(uuid.uuid4()) # Build job metadata (include client_id if provided for UI tracking) - job_meta = {'conversation_id': conversation_id} + job_meta = {"conversation_id": conversation_id} if client_id: - job_meta['client_id'] = client_id + job_meta["client_id"] = client_id # Check if speaker recognition is enabled - speaker_config = get_service_config('speaker_recognition') - speaker_enabled = speaker_config.get('enabled', True) # Default to True for backward compatibility + speaker_config = get_service_config("speaker_recognition") + speaker_enabled = speaker_config.get( + "enabled", True + ) # Default to True for backward compatibility # Step 1: Speaker recognition job (conditional - only if enabled) - speaker_dependency = depends_on_job # Start with upstream dependency (transcription if file upload) + speaker_dependency = ( + depends_on_job # Start with upstream dependency (transcription if file upload) + ) speaker_job = None if speaker_enabled: speaker_job_id = f"speaker_{conversation_id[:12]}" - logger.info(f"πŸ” DEBUG: Creating speaker job with job_id={speaker_job_id}, conversation_id={conversation_id[:12]}") + logger.info( + f"πŸ” DEBUG: Creating speaker job with job_id={speaker_job_id}, conversation_id={conversation_id[:12]}" + ) speaker_job = transcription_queue.enqueue( recognise_speakers_job, @@ -483,26 +546,36 @@ def start_post_conversation_jobs( depends_on=speaker_dependency, job_id=speaker_job_id, description=f"Speaker recognition for conversation {conversation_id[:8]}", - meta=job_meta + meta=job_meta, ) speaker_dependency = speaker_job # Chain for next jobs if depends_on_job: - logger.info(f"πŸ“₯ RQ: Enqueued speaker recognition job {speaker_job.id}, meta={speaker_job.meta} (depends on {depends_on_job.id})") + logger.info( + f"πŸ“₯ RQ: Enqueued speaker recognition job {speaker_job.id}, meta={speaker_job.meta} (depends on {depends_on_job.id})" + ) else: - logger.info(f"πŸ“₯ RQ: Enqueued speaker recognition job {speaker_job.id}, meta={speaker_job.meta} (no dependencies, starts immediately)") + logger.info( + f"πŸ“₯ RQ: Enqueued speaker recognition job {speaker_job.id}, meta={speaker_job.meta} (no dependencies, starts immediately)" + ) else: - logger.info(f"⏭️ Speaker recognition disabled, skipping speaker job for conversation {conversation_id[:8]}") + logger.info( + f"⏭️ Speaker recognition disabled, skipping speaker job for conversation {conversation_id[:8]}" + ) # Step 2: Memory extraction job (conditional - only if enabled) # Check if memory extraction is enabled - memory_config = get_service_config('memory.extraction') - memory_enabled = memory_config.get('enabled', True) # Default to True for backward compatibility + memory_config = get_service_config("memory.extraction") + memory_enabled = memory_config.get( + "enabled", True + ) # Default to True for backward compatibility memory_job = None if memory_enabled: # Depends on speaker job if it was created, otherwise depends on upstream (transcription or nothing) memory_job_id = f"memory_{conversation_id[:12]}" - logger.info(f"πŸ” DEBUG: Creating memory job with job_id={memory_job_id}, conversation_id={conversation_id[:12]}") + logger.info( + f"πŸ” DEBUG: Creating memory job with job_id={memory_job_id}, conversation_id={conversation_id[:12]}" + ) memory_job = memory_queue.enqueue( process_memory_job, @@ -512,23 +585,33 @@ def start_post_conversation_jobs( depends_on=speaker_dependency, # Either speaker_job or upstream dependency 
job_id=memory_job_id, description=f"Memory extraction for conversation {conversation_id[:8]}", - meta=job_meta + meta=job_meta, ) if speaker_job: - logger.info(f"πŸ“₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (depends on speaker job {speaker_job.id})") + logger.info( + f"πŸ“₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (depends on speaker job {speaker_job.id})" + ) elif depends_on_job: - logger.info(f"πŸ“₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (depends on {depends_on_job.id})") + logger.info( + f"πŸ“₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (depends on {depends_on_job.id})" + ) else: - logger.info(f"πŸ“₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (no dependencies, starts immediately)") + logger.info( + f"πŸ“₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (no dependencies, starts immediately)" + ) else: - logger.info(f"⏭️ Memory extraction disabled, skipping memory job for conversation {conversation_id[:8]}") + logger.info( + f"⏭️ Memory extraction disabled, skipping memory job for conversation {conversation_id[:8]}" + ) # Step 3: Title/summary generation job # Depends on memory job to avoid race condition (both jobs save the conversation document) # and to ensure fresh memories are available for context-enriched summaries title_dependency = memory_job if memory_job else speaker_dependency title_job_id = f"title_summary_{conversation_id[:12]}" - logger.info(f"πŸ” DEBUG: Creating title/summary job with job_id={title_job_id}, conversation_id={conversation_id[:12]}") + logger.info( + f"πŸ” DEBUG: Creating title/summary job with job_id={title_job_id}, conversation_id={conversation_id[:12]}" + ) title_summary_job = default_queue.enqueue( generate_title_summary_job, @@ -538,21 +621,31 @@ def start_post_conversation_jobs( depends_on=title_dependency, job_id=title_job_id, description=f"Generate title and summary for conversation {conversation_id[:8]}", - meta=job_meta + meta=job_meta, ) if memory_job: - logger.info(f"πŸ“₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on memory job {memory_job.id})") + logger.info( + f"πŸ“₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on memory job {memory_job.id})" + ) elif speaker_job: - logger.info(f"πŸ“₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on speaker job {speaker_job.id})") + logger.info( + f"πŸ“₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on speaker job {speaker_job.id})" + ) elif depends_on_job: - logger.info(f"πŸ“₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on {depends_on_job.id})") + logger.info( + f"πŸ“₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on {depends_on_job.id})" + ) else: - logger.info(f"πŸ“₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (no dependencies, starts immediately)") + logger.info( + f"πŸ“₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (no dependencies, starts immediately)" + ) # Step 5: Dispatch conversation.complete event (runs after both memory and title/summary complete) # This ensures plugins receive the event after all processing is done event_job_id = 
f"event_complete_{conversation_id[:12]}" - logger.info(f"πŸ” DEBUG: Creating conversation complete event job with job_id={event_job_id}, conversation_id={conversation_id[:12]}") + logger.info( + f"πŸ” DEBUG: Creating conversation complete event job with job_id={event_job_id}, conversation_id={conversation_id[:12]}" + ) # Event job depends on memory and title/summary jobs that were actually enqueued # Build dependency list excluding None values @@ -571,29 +664,33 @@ def start_post_conversation_jobs( end_reason, # Use the end_reason parameter (defaults to 'file_upload' for backward compatibility) job_timeout=120, # 2 minutes result_ttl=JOB_RESULT_TTL, - depends_on=event_dependencies if event_dependencies else None, # Wait for jobs that were enqueued + depends_on=( + event_dependencies if event_dependencies else None + ), # Wait for jobs that were enqueued job_id=event_job_id, description=f"Dispatch conversation complete event ({end_reason}) for {conversation_id[:8]}", - meta=job_meta + meta=job_meta, ) # Log event dispatch dependencies if event_dependencies: dep_ids = [job.id for job in event_dependencies] - logger.info(f"πŸ“₯ RQ: Enqueued conversation complete event job {event_dispatch_job.id}, meta={event_dispatch_job.meta} (depends on {', '.join(dep_ids)})") + logger.info( + f"πŸ“₯ RQ: Enqueued conversation complete event job {event_dispatch_job.id}, meta={event_dispatch_job.meta} (depends on {', '.join(dep_ids)})" + ) else: - logger.info(f"πŸ“₯ RQ: Enqueued conversation complete event job {event_dispatch_job.id}, meta={event_dispatch_job.meta} (no dependencies, starts immediately)") + logger.info( + f"πŸ“₯ RQ: Enqueued conversation complete event job {event_dispatch_job.id}, meta={event_dispatch_job.meta} (no dependencies, starts immediately)" + ) return { - 'speaker_recognition': speaker_job.id if speaker_job else None, - 'memory': memory_job.id if memory_job else None, - 'title_summary': title_summary_job.id, - 'event_dispatch': event_dispatch_job.id + "speaker_recognition": speaker_job.id if speaker_job else None, + "memory": memory_job.id if memory_job else None, + "title_summary": title_summary_job.id, + "event_dispatch": event_dispatch_job.id, } - - def get_queue_health() -> Dict[str, Any]: """Get health status of all queues and workers.""" health = { @@ -637,15 +734,18 @@ def get_queue_health() -> Dict[str, Any]: else: health["idle_workers"] += 1 - health["workers"].append({ - "name": worker.name, - "state": state, - "queues": [q.name for q in worker.queues], - "current_job": current_job, - }) + health["workers"].append( + { + "name": worker.name, + "state": state, + "queues": [q.name for q in worker.queues], + "current_job": current_job, + } + ) return health + # needs tidying but works for now async def cleanup_stuck_stream_workers(request): """Clean up stuck Redis Stream consumers and pending messages from all active streams.""" @@ -660,7 +760,7 @@ async def cleanup_stuck_stream_workers(request): if not redis_client: return JSONResponse( status_code=503, - content={"error": "Redis client for audio streaming not initialized"} + content={"error": "Redis client for audio streaming not initialized"}, ) cleanup_results = {} @@ -673,17 +773,25 @@ async def cleanup_stuck_stream_workers(request): stream_keys = await redis_client.keys("audio:stream:*") for stream_key in stream_keys: - stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key + stream_name = ( + stream_key.decode() if isinstance(stream_key, bytes) else stream_key + ) try: # First check stream 
age - delete old streams (>1 hour) immediately - stream_info = await redis_client.execute_command('XINFO', 'STREAM', stream_name) + stream_info = await redis_client.execute_command( + "XINFO", "STREAM", stream_name + ) # Parse stream info info_dict = {} for i in range(0, len(stream_info), 2): - key_name = stream_info[i].decode() if isinstance(stream_info[i], bytes) else str(stream_info[i]) - info_dict[key_name] = stream_info[i+1] + key_name = ( + stream_info[i].decode() + if isinstance(stream_info[i], bytes) + else str(stream_info[i]) + ) + info_dict[key_name] = stream_info[i + 1] stream_length = int(info_dict.get("length", 0)) last_entry = info_dict.get("last-entry") @@ -695,12 +803,14 @@ async def cleanup_stuck_stream_workers(request): if stream_length == 0: should_delete_stream = True stream_age = 0 - elif last_entry and isinstance(last_entry, list) and len(last_entry) > 0: + elif ( + last_entry and isinstance(last_entry, list) and len(last_entry) > 0 + ): try: last_id = last_entry[0] if isinstance(last_id, bytes): last_id = last_id.decode() - last_timestamp_ms = int(last_id.split('-')[0]) + last_timestamp_ms = int(last_id.split("-")[0]) last_timestamp_s = last_timestamp_ms / 1000 stream_age = current_time - last_timestamp_s @@ -718,23 +828,33 @@ async def cleanup_stuck_stream_workers(request): "cleaned": 0, "deleted_consumers": 0, "deleted_stream": True, - "stream_age": stream_age + "stream_age": stream_age, } continue # Get consumer groups - groups = await redis_client.execute_command('XINFO', 'GROUPS', stream_name) + groups = await redis_client.execute_command( + "XINFO", "GROUPS", stream_name + ) if not groups: - cleanup_results[stream_name] = {"message": "No consumer groups found", "cleaned": 0, "deleted_stream": False} + cleanup_results[stream_name] = { + "message": "No consumer groups found", + "cleaned": 0, + "deleted_stream": False, + } continue # Parse first group group_dict = {} group = groups[0] for i in range(0, len(group), 2): - key = group[i].decode() if isinstance(group[i], bytes) else str(group[i]) - value = group[i+1] + key = ( + group[i].decode() + if isinstance(group[i], bytes) + else str(group[i]) + ) + value = group[i + 1] if isinstance(value, bytes): try: value = value.decode() @@ -749,7 +869,9 @@ async def cleanup_stuck_stream_workers(request): pending_count = int(group_dict.get("pending", 0)) # Get consumers for this group to check per-consumer pending - consumers = await redis_client.execute_command('XINFO', 'CONSUMERS', stream_name, group_name) + consumers = await redis_client.execute_command( + "XINFO", "CONSUMERS", stream_name, group_name + ) cleaned_count = 0 total_consumer_pending = 0 @@ -759,8 +881,12 @@ async def cleanup_stuck_stream_workers(request): for consumer in consumers: consumer_dict = {} for i in range(0, len(consumer), 2): - key = consumer[i].decode() if isinstance(consumer[i], bytes) else str(consumer[i]) - value = consumer[i+1] + key = ( + consumer[i].decode() + if isinstance(consumer[i], bytes) + else str(consumer[i]) + ) + value = consumer[i + 1] if isinstance(value, bytes): try: value = value.decode() @@ -780,12 +906,20 @@ async def cleanup_stuck_stream_workers(request): is_dead = consumer_idle_ms > 300000 if consumer_pending > 0: - logger.info(f"Found {consumer_pending} pending messages for consumer {consumer_name} (idle: {consumer_idle_ms}ms)") + logger.info( + f"Found {consumer_pending} pending messages for consumer {consumer_name} (idle: {consumer_idle_ms}ms)" + ) # Get pending messages for this specific consumer try: pending_messages 
= await redis_client.execute_command( - 'XPENDING', stream_name, group_name, '-', '+', str(consumer_pending), consumer_name + "XPENDING", + stream_name, + group_name, + "-", + "+", + str(consumer_pending), + consumer_name, ) # XPENDING returns flat list: [msg_id, consumer, idle_ms, delivery_count, msg_id, ...] @@ -799,31 +933,55 @@ async def cleanup_stuck_stream_workers(request): # Claim the message to a cleanup worker try: await redis_client.execute_command( - 'XCLAIM', stream_name, group_name, 'cleanup-worker', '0', msg_id + "XCLAIM", + stream_name, + group_name, + "cleanup-worker", + "0", + msg_id, ) # Acknowledge it immediately - await redis_client.xack(stream_name, group_name, msg_id) + await redis_client.xack( + stream_name, group_name, msg_id + ) cleaned_count += 1 except Exception as claim_error: - logger.warning(f"Failed to claim/ack message {msg_id}: {claim_error}") + logger.warning( + f"Failed to claim/ack message {msg_id}: {claim_error}" + ) except Exception as consumer_error: - logger.error(f"Error processing consumer {consumer_name}: {consumer_error}") + logger.error( + f"Error processing consumer {consumer_name}: {consumer_error}" + ) # Delete dead consumers (idle > 5 minutes with no pending messages) if is_dead and consumer_pending == 0: try: await redis_client.execute_command( - 'XGROUP', 'DELCONSUMER', stream_name, group_name, consumer_name + "XGROUP", + "DELCONSUMER", + stream_name, + group_name, + consumer_name, ) deleted_consumers += 1 - logger.info(f"🧹 Deleted dead consumer {consumer_name} (idle: {consumer_idle_ms}ms)") + logger.info( + f"🧹 Deleted dead consumer {consumer_name} (idle: {consumer_idle_ms}ms)" + ) except Exception as delete_error: - logger.warning(f"Failed to delete consumer {consumer_name}: {delete_error}") + logger.warning( + f"Failed to delete consumer {consumer_name}: {delete_error}" + ) if total_consumer_pending == 0 and deleted_consumers == 0: - cleanup_results[stream_name] = {"message": "No pending messages or dead consumers", "cleaned": 0, "deleted_consumers": 0, "deleted_stream": False} + cleanup_results[stream_name] = { + "message": "No pending messages or dead consumers", + "cleaned": 0, + "deleted_consumers": 0, + "deleted_stream": False, + } continue total_cleaned += cleaned_count @@ -833,14 +991,11 @@ async def cleanup_stuck_stream_workers(request): "cleaned": cleaned_count, "deleted_consumers": deleted_consumers, "deleted_stream": False, - "original_pending": pending_count + "original_pending": pending_count, } except Exception as e: - cleanup_results[stream_name] = { - "error": str(e), - "cleaned": 0 - } + cleanup_results[stream_name] = {"error": str(e), "cleaned": 0} return { "success": True, @@ -849,11 +1004,12 @@ async def cleanup_stuck_stream_workers(request): "total_deleted_streams": total_deleted_streams, "streams": cleanup_results, # New key for per-stream results "providers": cleanup_results, # Keep for backward compatibility with frontend - "timestamp": time.time() + "timestamp": time.time(), } except Exception as e: logger.error(f"Error cleaning up stuck workers: {e}", exc_info=True) return JSONResponse( - status_code=500, content={"error": f"Failed to cleanup stuck workers: {str(e)}"} + status_code=500, + content={"error": f"Failed to cleanup stuck workers: {str(e)}"}, ) diff --git a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py index 9b3a2de9..f30401ec 100644 --- 
a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py @@ -24,7 +24,7 @@ async def mark_session_complete( "user_stopped", "inactivity_timeout", "max_duration", - "all_jobs_complete" + "all_jobs_complete", ], ) -> None: """ @@ -57,12 +57,17 @@ async def mark_session_complete( """ session_key = f"audio:session:{session_id}" mark_time = time.time() - await redis_client.hset(session_key, mapping={ - "status": "finished", - "completed_at": str(mark_time), - "completion_reason": reason - }) - logger.info(f"βœ… Session {session_id[:12]} marked finished: {reason} [TIME: {mark_time:.3f}]") + await redis_client.hset( + session_key, + mapping={ + "status": "finished", + "completed_at": str(mark_time), + "completion_reason": reason, + }, + ) + logger.info( + f"βœ… Session {session_id[:12]} marked finished: {reason} [TIME: {mark_time:.3f}]" + ) async def request_conversation_close( @@ -92,7 +97,9 @@ async def request_conversation_close( if not await redis_client.exists(session_key): return False await redis_client.hset(session_key, "conversation_close_requested", reason) - logger.info(f"πŸ”’ Conversation close requested for session {session_id[:12]}: {reason}") + logger.info( + f"πŸ”’ Conversation close requested for session {session_id[:12]}: {reason}" + ) return True @@ -117,7 +124,9 @@ async def get_session_info(redis_client, session_id: str) -> Optional[Dict]: # Get conversation count for this session conversation_count_key = f"session:conversation_count:{session_id}" conversation_count_bytes = await redis_client.get(conversation_count_key) - conversation_count = int(conversation_count_bytes.decode()) if conversation_count_bytes else 0 + conversation_count = ( + int(conversation_count_bytes.decode()) if conversation_count_bytes else 0 + ) started_at = float(session_data.get(b"started_at", b"0")) last_chunk_at = float(session_data.get(b"last_chunk_at", b"0")) @@ -138,8 +147,12 @@ async def get_session_info(redis_client, session_id: str) -> Optional[Dict]: # Speech detection events "last_event": session_data.get(b"last_event", b"").decode(), "speech_detected_at": session_data.get(b"speech_detected_at", b"").decode(), - "speaker_check_status": session_data.get(b"speaker_check_status", b"").decode(), - "identified_speakers": session_data.get(b"identified_speakers", b"").decode() + "speaker_check_status": session_data.get( + b"speaker_check_status", b"" + ).decode(), + "identified_speakers": session_data.get( + b"identified_speakers", b"" + ).decode(), } except Exception as e: @@ -166,7 +179,7 @@ async def get_all_sessions(redis_client, limit: int = 100) -> List[Dict]: cursor, keys = await redis_client.scan( cursor, match="audio:session:*", count=limit ) - session_keys.extend(keys[:limit - len(session_keys)]) + session_keys.extend(keys[: limit - len(session_keys)]) # Get info for each session sessions = [] @@ -221,7 +234,9 @@ async def increment_session_conversation_count(redis_client, session_id: str) -> logger.info(f"πŸ“Š Conversation count for session {session_id}: {count}") return count except Exception as e: - logger.error(f"Error incrementing conversation count for session {session_id}: {e}") + logger.error( + f"Error incrementing conversation count for session {session_id}: {e}" + ) return 0 @@ -241,7 +256,7 @@ async def get_streaming_status(request): if not redis_client: return JSONResponse( status_code=503, - content={"error": "Redis client for audio streaming not initialized"} + 
content={"error": "Redis client for audio streaming not initialized"}, ) # Get all sessions (both active and completed) @@ -273,23 +288,48 @@ async def get_streaming_status(request): # All jobs finished - this is truly a finished session # Update Redis status if it wasn't already marked finished if status != "finished": - await mark_session_complete(redis_client, session_id, "all_jobs_complete") + await mark_session_complete( + redis_client, session_id, "all_jobs_complete" + ) # Get additional session data for completed sessions session_key = f"audio:session:{session_id}" session_data = await redis_client.hgetall(session_key) - completed_sessions_from_redis.append({ - "session_id": session_id, - "client_id": session_obj.get("client_id", ""), - "conversation_id": session_data.get(b"conversation_id", b"").decode() if session_data and b"conversation_id" in session_data else None, - "has_conversation": bool(session_data and session_data.get(b"conversation_id", b"")), - "action": session_data.get(b"action", b"finished").decode() if session_data and b"action" in session_data else "finished", - "reason": session_data.get(b"reason", b"").decode() if session_data and b"reason" in session_data else "", - "completed_at": session_obj.get("last_chunk_at", 0), - "audio_file": session_data.get(b"audio_file", b"").decode() if session_data and b"audio_file" in session_data else "", - "conversation_count": session_obj.get("conversation_count", 0) - }) + completed_sessions_from_redis.append( + { + "session_id": session_id, + "client_id": session_obj.get("client_id", ""), + "conversation_id": ( + session_data.get(b"conversation_id", b"").decode() + if session_data and b"conversation_id" in session_data + else None + ), + "has_conversation": bool( + session_data + and session_data.get(b"conversation_id", b"") + ), + "action": ( + session_data.get(b"action", b"finished").decode() + if session_data and b"action" in session_data + else "finished" + ), + "reason": ( + session_data.get(b"reason", b"").decode() + if session_data and b"reason" in session_data + else "" + ), + "completed_at": session_obj.get("last_chunk_at", 0), + "audio_file": ( + session_data.get(b"audio_file", b"").decode() + if session_data and b"audio_file" in session_data + else "" + ), + "conversation_count": session_obj.get( + "conversation_count", 0 + ), + } + ) else: # Status says complete but jobs still processing - keep in active active_sessions.append(session_obj) @@ -314,16 +354,24 @@ async def get_streaming_status(request): current_time = time.time() for stream_key in stream_keys: - stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key + stream_name = ( + stream_key.decode() if isinstance(stream_key, bytes) else stream_key + ) try: # Check if stream exists - stream_info = await redis_client.execute_command('XINFO', 'STREAM', stream_name) + stream_info = await redis_client.execute_command( + "XINFO", "STREAM", stream_name + ) # Parse stream info (returns flat list of key-value pairs) info_dict = {} for i in range(0, len(stream_info), 2): - key = stream_info[i].decode() if isinstance(stream_info[i], bytes) else str(stream_info[i]) - value = stream_info[i+1] + key = ( + stream_info[i].decode() + if isinstance(stream_info[i], bytes) + else str(stream_info[i]) + ) + value = stream_info[i + 1] # Skip complex binary structures like first-entry and last-entry # which contain message data that can't be JSON serialized @@ -351,7 +399,7 @@ async def get_streaming_status(request): if last_entry_id: try: # Redis Stream 
IDs format: "milliseconds-sequence" - last_timestamp_ms = int(last_entry_id.split('-')[0]) + last_timestamp_ms = int(last_entry_id.split("-")[0]) last_timestamp_s = last_timestamp_ms / 1000 stream_age_seconds = current_time - last_timestamp_s except (ValueError, IndexError, AttributeError): @@ -369,7 +417,9 @@ async def get_streaming_status(request): session_idle_seconds = session_data.get("idle_seconds", 0) # Get consumer groups - groups = await redis_client.execute_command('XINFO', 'GROUPS', stream_name) + groups = await redis_client.execute_command( + "XINFO", "GROUPS", stream_name + ) stream_data = { "stream_length": info_dict.get("length", 0), @@ -378,19 +428,23 @@ async def get_streaming_status(request): "session_age_seconds": session_age_seconds, # Age since session started "session_idle_seconds": session_idle_seconds, # Time since last audio chunk "client_id": client_id, # Include client_id for reference - "consumer_groups": [] + "consumer_groups": [], } # Track if stream has any active consumers has_active_consumer = False - min_consumer_idle_ms = float('inf') + min_consumer_idle_ms = float("inf") # Parse consumer groups for group in groups: group_dict = {} for i in range(0, len(group), 2): - key = group[i].decode() if isinstance(group[i], bytes) else str(group[i]) - value = group[i+1] + key = ( + group[i].decode() + if isinstance(group[i], bytes) + else str(group[i]) + ) + value = group[i + 1] if isinstance(value, bytes): try: value = value.decode() @@ -403,15 +457,21 @@ async def get_streaming_status(request): group_name = group_name.decode() # Get consumers for this group - consumers = await redis_client.execute_command('XINFO', 'CONSUMERS', stream_name, group_name) + consumers = await redis_client.execute_command( + "XINFO", "CONSUMERS", stream_name, group_name + ) consumer_list = [] consumer_pending_total = 0 for consumer in consumers: consumer_dict = {} for i in range(0, len(consumer), 2): - key = consumer[i].decode() if isinstance(consumer[i], bytes) else str(consumer[i]) - value = consumer[i+1] + key = ( + consumer[i].decode() + if isinstance(consumer[i], bytes) + else str(consumer[i]) + ) + value = consumer[i + 1] if isinstance(value, bytes): try: value = value.decode() @@ -428,17 +488,21 @@ async def get_streaming_status(request): consumer_pending_total += consumer_pending # Track minimum idle time - min_consumer_idle_ms = min(min_consumer_idle_ms, consumer_idle_ms) + min_consumer_idle_ms = min( + min_consumer_idle_ms, consumer_idle_ms + ) # Consumer is active if idle < 5 minutes (300000ms) if consumer_idle_ms < 300000: has_active_consumer = True - consumer_list.append({ - "name": consumer_name, - "pending": consumer_pending, - "idle_ms": consumer_idle_ms - }) + consumer_list.append( + { + "name": consumer_name, + "pending": consumer_pending, + "idle_ms": consumer_idle_ms, + } + ) # Get group-level pending count (may be 0 even if consumers have pending) try: @@ -451,20 +515,24 @@ async def get_streaming_status(request): # (Sometimes group pending is 0 but consumers still have pending messages) effective_pending = max(group_pending_count, consumer_pending_total) - stream_data["consumer_groups"].append({ - "name": str(group_name), - "consumers": consumer_list, - "pending": int(effective_pending) - }) + stream_data["consumer_groups"].append( + { + "name": str(group_name), + "consumers": consumer_list, + "pending": int(effective_pending), + } + ) # Determine if stream is active or completed # Active: has active consumers OR pending messages OR recent activity (< 5 min) # 
Completed: no active consumers and idle > 5 minutes but < 1 hour - total_pending = sum(group["pending"] for group in stream_data["consumer_groups"]) + total_pending = sum( + group["pending"] for group in stream_data["consumer_groups"] + ) is_active = ( - has_active_consumer or - total_pending > 0 or - stream_age_seconds < 300 # Less than 5 minutes old + has_active_consumer + or total_pending > 0 + or stream_age_seconds < 300 # Less than 5 minutes old ) if is_active: @@ -487,7 +555,7 @@ async def get_streaming_status(request): "finished": len(transcription_queue.finished_job_registry), "failed": len(transcription_queue.failed_job_registry), "canceled": len(transcription_queue.canceled_job_registry), - "deferred": len(transcription_queue.deferred_job_registry) + "deferred": len(transcription_queue.deferred_job_registry), }, "memory_queue": { "queued": memory_queue.count, @@ -495,7 +563,7 @@ async def get_streaming_status(request): "finished": len(memory_queue.finished_job_registry), "failed": len(memory_queue.failed_job_registry), "canceled": len(memory_queue.canceled_job_registry), - "deferred": len(memory_queue.deferred_job_registry) + "deferred": len(memory_queue.deferred_job_registry), }, "default_queue": { "queued": default_queue.count, @@ -503,8 +571,8 @@ async def get_streaming_status(request): "finished": len(default_queue.finished_job_registry), "failed": len(default_queue.failed_job_registry), "canceled": len(default_queue.canceled_job_registry), - "deferred": len(default_queue.deferred_job_registry) - } + "deferred": len(default_queue.deferred_job_registry), + }, } return { @@ -514,14 +582,14 @@ async def get_streaming_status(request): "completed_streams": completed_streams, "stream_health": active_streams, # Backward compatibility - use active_streams "rq_queues": rq_stats, - "timestamp": time.time() + "timestamp": time.time(), } except Exception as e: logger.error(f"Error getting streaming status: {e}", exc_info=True) return JSONResponse( status_code=500, - content={"error": f"Failed to get streaming status: {str(e)}"} + content={"error": f"Failed to get streaming status: {str(e)}"}, ) @@ -538,7 +606,7 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): if not redis_client: return JSONResponse( status_code=503, - content={"error": "Redis client for audio streaming not initialized"} + content={"error": "Redis client for audio streaming not initialized"}, ) # Get all session keys @@ -560,17 +628,18 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): age_seconds = current_time - started_at # Clean up sessions older than max_age or stuck in "finalizing" - should_clean = ( - age_seconds > max_age_seconds or - (status == "finalizing" and age_seconds > 300) # Finalizing for more than 5 minutes - ) + should_clean = age_seconds > max_age_seconds or ( + status == "finalizing" and age_seconds > 300 + ) # Finalizing for more than 5 minutes if should_clean: - old_sessions.append({ - "session_id": session_id, - "age_seconds": age_seconds, - "status": status - }) + old_sessions.append( + { + "session_id": session_id, + "age_seconds": age_seconds, + "status": status, + } + ) await redis_client.delete(key) cleaned_sessions += 1 @@ -580,17 +649,25 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): old_streams = [] for stream_key in stream_keys: - stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key + stream_name = ( + stream_key.decode() if isinstance(stream_key, bytes) else stream_key + ) try: # Check 
stream info to get last activity - stream_info = await redis_client.execute_command('XINFO', 'STREAM', stream_name) + stream_info = await redis_client.execute_command( + "XINFO", "STREAM", stream_name + ) # Parse stream info info_dict = {} for i in range(0, len(stream_info), 2): - key_name = stream_info[i].decode() if isinstance(stream_info[i], bytes) else str(stream_info[i]) - info_dict[key_name] = stream_info[i+1] + key_name = ( + stream_info[i].decode() + if isinstance(stream_info[i], bytes) + else str(stream_info[i]) + ) + info_dict[key_name] = stream_info[i + 1] stream_length = int(info_dict.get("length", 0)) last_entry = info_dict.get("last-entry") @@ -603,7 +680,9 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): # Empty stream - safe to delete should_delete = True reason = "empty" - elif last_entry and isinstance(last_entry, list) and len(last_entry) > 0: + elif ( + last_entry and isinstance(last_entry, list) and len(last_entry) > 0 + ): # Extract timestamp from last entry ID last_id = last_entry[0] if isinstance(last_id, bytes): @@ -611,7 +690,7 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): # Redis Stream IDs format: "milliseconds-sequence" try: - last_timestamp_ms = int(last_id.split('-')[0]) + last_timestamp_ms = int(last_id.split("-")[0]) last_timestamp_s = last_timestamp_ms / 1000 age_seconds = current_time - last_timestamp_s @@ -622,12 +701,16 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): except (ValueError, IndexError): # If we can't parse timestamp, check if first entry is old first_entry = info_dict.get("first-entry") - if first_entry and isinstance(first_entry, list) and len(first_entry) > 0: + if ( + first_entry + and isinstance(first_entry, list) + and len(first_entry) > 0 + ): try: first_id = first_entry[0] if isinstance(first_id, bytes): first_id = first_id.decode() - first_timestamp_ms = int(first_id.split('-')[0]) + first_timestamp_ms = int(first_id.split("-")[0]) first_timestamp_s = first_timestamp_ms / 1000 age_seconds = current_time - first_timestamp_s @@ -640,12 +723,14 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): if should_delete: await redis_client.delete(stream_name) cleaned_streams += 1 - old_streams.append({ - "stream_name": stream_name, - "reason": reason, - "age_seconds": age_seconds, - "length": stream_length - }) + old_streams.append( + { + "stream_name": stream_name, + "reason": reason, + "age_seconds": age_seconds, + "length": stream_length, + } + ) except Exception as e: logger.debug(f"Error checking stream {stream_name}: {e}") @@ -657,11 +742,12 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600): "cleaned_streams": cleaned_streams, "cleaned_session_details": old_sessions, "cleaned_stream_details": old_streams, - "timestamp": time.time() + "timestamp": time.time(), } except Exception as e: logger.error(f"Error cleaning up old sessions: {e}", exc_info=True) return JSONResponse( - status_code=500, content={"error": f"Failed to cleanup old sessions: {str(e)}"} + status_code=500, + content={"error": f"Failed to cleanup old sessions: {str(e)}"}, ) diff --git a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py index bf3ce1b1..4a6f83c2 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py @@ -7,30 +7,29 @@ import 
logging import os import re -import signal import shutil +import signal import time import warnings from datetime import UTC, datetime +from io import StringIO from pathlib import Path from typing import Optional -from io import StringIO - -from ruamel.yaml import YAML from fastapi import HTTPException +from ruamel.yaml import YAML from advanced_omi_backend.config import ( get_diarization_settings as load_diarization_settings, ) from advanced_omi_backend.config import get_misc_settings as load_misc_settings -from advanced_omi_backend.config import ( - save_diarization_settings, - save_misc_settings, +from advanced_omi_backend.config import save_diarization_settings, save_misc_settings +from advanced_omi_backend.config_loader import get_plugins_yml_path, save_config_section +from advanced_omi_backend.model_registry import ( + _find_config_path, + get_models_registry, + load_models_config, ) -from advanced_omi_backend.config_loader import get_plugins_yml_path -from advanced_omi_backend.config_loader import save_config_section -from advanced_omi_backend.model_registry import _find_config_path, get_models_registry, load_models_config from advanced_omi_backend.models.user import User logger = logging.getLogger(__name__) @@ -43,7 +42,7 @@ async def get_config_diagnostics(): """ Get comprehensive configuration diagnostics. - + Returns warnings, errors, and status for all configuration components. """ diagnostics = { @@ -52,9 +51,9 @@ async def get_config_diagnostics(): "issues": [], "warnings": [], "info": [], - "components": {} + "components": {}, } - + # Test OmegaConf configuration loading try: from advanced_omi_backend.config_loader import load_config @@ -63,7 +62,7 @@ async def get_config_diagnostics(): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") config = load_config(force_reload=True) - + # Check for OmegaConf warnings for warning in w: warning_msg = str(warning.message) @@ -71,148 +70,174 @@ async def get_config_diagnostics(): # Extract the variable name from warning if "variable '" in warning_msg.lower(): var_name = warning_msg.split("'")[1] - diagnostics["warnings"].append({ - "component": "OmegaConf", - "severity": "warning", - "message": f"Environment variable '{var_name}' not set (using empty default)", - "resolution": f"Set {var_name} in .env file if needed" - }) - + diagnostics["warnings"].append( + { + "component": "OmegaConf", + "severity": "warning", + "message": f"Environment variable '{var_name}' not set (using empty default)", + "resolution": f"Set {var_name} in .env file if needed", + } + ) + diagnostics["components"]["omegaconf"] = { "status": "healthy", - "message": "Configuration loaded successfully" + "message": "Configuration loaded successfully", } except Exception as e: diagnostics["overall_status"] = "unhealthy" - diagnostics["issues"].append({ - "component": "OmegaConf", - "severity": "error", - "message": f"Failed to load configuration: {str(e)}", - "resolution": "Check config/defaults.yml and config/config.yml syntax" - }) + diagnostics["issues"].append( + { + "component": "OmegaConf", + "severity": "error", + "message": f"Failed to load configuration: {str(e)}", + "resolution": "Check config/defaults.yml and config/config.yml syntax", + } + ) diagnostics["components"]["omegaconf"] = { "status": "unhealthy", - "message": str(e) + "message": str(e), } - + # Test model registry try: from advanced_omi_backend.model_registry import get_models_registry - + with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") 
registry = get_models_registry() - + # Capture model loading warnings for warning in w: warning_msg = str(warning.message) - diagnostics["warnings"].append({ - "component": "Model Registry", - "severity": "warning", - "message": warning_msg, - "resolution": "Check model definitions in config/defaults.yml" - }) - + diagnostics["warnings"].append( + { + "component": "Model Registry", + "severity": "warning", + "message": warning_msg, + "resolution": "Check model definitions in config/defaults.yml", + } + ) + if registry: diagnostics["components"]["model_registry"] = { "status": "healthy", "message": f"Loaded {len(registry.models)} models", "details": { "total_models": len(registry.models), - "defaults": dict(registry.defaults) if registry.defaults else {} - } + "defaults": dict(registry.defaults) if registry.defaults else {}, + }, } - + # Check critical models stt = registry.get_default("stt") stt_stream = registry.get_default("stt_stream") llm = registry.get_default("llm") - + # STT check if stt: if stt.api_key: - diagnostics["info"].append({ - "component": "STT (Batch)", - "message": f"Configured: {stt.name} ({stt.model_provider}) - API key present" - }) + diagnostics["info"].append( + { + "component": "STT (Batch)", + "message": f"Configured: {stt.name} ({stt.model_provider}) - API key present", + } + ) else: - diagnostics["warnings"].append({ - "component": "STT (Batch)", - "severity": "warning", - "message": f"{stt.name} ({stt.model_provider}) - No API key configured", - "resolution": "Transcription can fail without API key" - }) + diagnostics["warnings"].append( + { + "component": "STT (Batch)", + "severity": "warning", + "message": f"{stt.name} ({stt.model_provider}) - No API key configured", + "resolution": "Transcription can fail without API key", + } + ) else: - diagnostics["issues"].append({ - "component": "STT (Batch)", - "severity": "error", - "message": "No batch STT model configured", - "resolution": "Set defaults.stt in config.yml" - }) + diagnostics["issues"].append( + { + "component": "STT (Batch)", + "severity": "error", + "message": "No batch STT model configured", + "resolution": "Set defaults.stt in config.yml", + } + ) diagnostics["overall_status"] = "partial" - + # Streaming STT check if stt_stream: if stt_stream.api_key: - diagnostics["info"].append({ - "component": "STT (Streaming)", - "message": f"Configured: {stt_stream.name} ({stt_stream.model_provider}) - API key present" - }) + diagnostics["info"].append( + { + "component": "STT (Streaming)", + "message": f"Configured: {stt_stream.name} ({stt_stream.model_provider}) - API key present", + } + ) else: - diagnostics["warnings"].append({ + diagnostics["warnings"].append( + { + "component": "STT (Streaming)", + "severity": "warning", + "message": f"{stt_stream.name} ({stt_stream.model_provider}) - No API key configured", + "resolution": "Real-time transcription can fail without API key", + } + ) + else: + diagnostics["warnings"].append( + { "component": "STT (Streaming)", "severity": "warning", - "message": f"{stt_stream.name} ({stt_stream.model_provider}) - No API key configured", - "resolution": "Real-time transcription can fail without API key" - }) - else: - diagnostics["warnings"].append({ - "component": "STT (Streaming)", - "severity": "warning", - "message": "No streaming STT model configured - streaming worker disabled", - "resolution": "Set defaults.stt_stream in config.yml for WebSocket transcription" - }) - + "message": "No streaming STT model configured - streaming worker disabled", + "resolution": "Set 
defaults.stt_stream in config.yml for WebSocket transcription", + } + ) + # LLM check if llm: if llm.api_key: - diagnostics["info"].append({ - "component": "LLM", - "message": f"Configured: {llm.name} ({llm.model_provider}) - API key present" - }) + diagnostics["info"].append( + { + "component": "LLM", + "message": f"Configured: {llm.name} ({llm.model_provider}) - API key present", + } + ) else: - diagnostics["warnings"].append({ - "component": "LLM", - "severity": "warning", - "message": f"{llm.name} ({llm.model_provider}) - No API key configured", - "resolution": "Memory extraction can fail without API key" - }) - + diagnostics["warnings"].append( + { + "component": "LLM", + "severity": "warning", + "message": f"{llm.name} ({llm.model_provider}) - No API key configured", + "resolution": "Memory extraction can fail without API key", + } + ) + else: diagnostics["overall_status"] = "unhealthy" - diagnostics["issues"].append({ - "component": "Model Registry", - "severity": "error", - "message": "Failed to load model registry", - "resolution": "Check config/defaults.yml for syntax errors" - }) + diagnostics["issues"].append( + { + "component": "Model Registry", + "severity": "error", + "message": "Failed to load model registry", + "resolution": "Check config/defaults.yml for syntax errors", + } + ) diagnostics["components"]["model_registry"] = { "status": "unhealthy", - "message": "Registry failed to load" + "message": "Registry failed to load", } except Exception as e: diagnostics["overall_status"] = "partial" - diagnostics["issues"].append({ - "component": "Model Registry", - "severity": "error", - "message": f"Error loading registry: {str(e)}", - "resolution": "Check logs for detailed error information" - }) + diagnostics["issues"].append( + { + "component": "Model Registry", + "severity": "error", + "message": f"Error loading registry: {str(e)}", + "resolution": "Check logs for detailed error information", + } + ) diagnostics["components"]["model_registry"] = { "status": "unhealthy", - "message": str(e) + "message": str(e), } - + # Check environment variables (only warn about keys relevant to configured providers) env_checks = [ ("AUTH_SECRET_KEY", "Required for authentication"), @@ -224,7 +249,9 @@ async def get_config_diagnostics(): # Add LLM API key check based on active provider llm_model = registry.get_default("llm") if llm_model and llm_model.model_provider == "openai": - env_checks.append(("OPENAI_API_KEY", "Required for OpenAI LLM and embeddings")) + env_checks.append( + ("OPENAI_API_KEY", "Required for OpenAI LLM and embeddings") + ) elif llm_model and llm_model.model_provider == "groq": env_checks.append(("GROQ_API_KEY", "Required for Groq LLM")) @@ -233,20 +260,26 @@ async def get_config_diagnostics(): if stt_model: provider = stt_model.model_provider if provider == "deepgram": - env_checks.append(("DEEPGRAM_API_KEY", "Required for Deepgram transcription")) + env_checks.append( + ("DEEPGRAM_API_KEY", "Required for Deepgram transcription") + ) elif provider == "smallest": - env_checks.append(("SMALLEST_API_KEY", "Required for Smallest.ai Pulse transcription")) - + env_checks.append( + ("SMALLEST_API_KEY", "Required for Smallest.ai Pulse transcription") + ) + for env_var, description in env_checks: value = os.getenv(env_var) if not value or value == "": - diagnostics["warnings"].append({ - "component": "Environment Variables", - "severity": "warning", - "message": f"{env_var} not set - {description}", - "resolution": f"Set {env_var} in .env file" - }) - + 
diagnostics["warnings"].append( + { + "component": "Environment Variables", + "severity": "warning", + "message": f"{env_var} not set - {description}", + "resolution": f"Set {env_var} in .env file", + } + ) + return diagnostics @@ -297,7 +330,9 @@ async def get_observability_config(): from advanced_omi_backend.config_loader import load_config cfg = load_config() - public_url = cfg.get("observability", {}).get("langfuse", {}).get("public_url", "") + public_url = ( + cfg.get("observability", {}).get("langfuse", {}).get("public_url", "") + ) if public_url: # Strip trailing slash and build session URL session_base_url = f"{public_url.rstrip('/')}/project/chronicle/sessions" @@ -321,10 +356,7 @@ async def get_diarization_settings(): try: # Get settings using OmegaConf settings = load_diarization_settings() - return { - "settings": settings, - "status": "success" - } + return {"settings": settings, "status": "success"} except Exception as e: logger.exception("Error getting diarization settings") raise e @@ -335,8 +367,13 @@ async def save_diarization_settings_controller(settings: dict): try: # Validate settings valid_keys = { - "diarization_source", "similarity_threshold", "min_duration", "collar", - "min_duration_off", "min_speakers", "max_speakers" + "diarization_source", + "similarity_threshold", + "min_duration", + "collar", + "min_duration_off", + "min_speakers", + "max_speakers", } # Filter to only valid keys (allow round-trip GET→POST) @@ -348,19 +385,30 @@ async def save_diarization_settings_controller(settings: dict): # Type validation for known keys only if key in ["min_speakers", "max_speakers"]: if not isinstance(value, int) or value < 1 or value > 20: - raise HTTPException(status_code=400, detail=f"Invalid value for {key}: must be integer 1-20") + raise HTTPException( + status_code=400, + detail=f"Invalid value for {key}: must be integer 1-20", + ) elif key == "diarization_source": if not isinstance(value, str) or value not in ["pyannote", "deepgram"]: - raise HTTPException(status_code=400, detail=f"Invalid value for {key}: must be 'pyannote' or 'deepgram'") + raise HTTPException( + status_code=400, + detail=f"Invalid value for {key}: must be 'pyannote' or 'deepgram'", + ) else: if not isinstance(value, (int, float)) or value < 0: - raise HTTPException(status_code=400, detail=f"Invalid value for {key}: must be positive number") + raise HTTPException( + status_code=400, + detail=f"Invalid value for {key}: must be positive number", + ) filtered_settings[key] = value # Reject if NO valid keys provided (completely invalid request) if not filtered_settings: - raise HTTPException(status_code=400, detail="No valid diarization settings provided") + raise HTTPException( + status_code=400, detail="No valid diarization settings provided" + ) # Get current settings and merge with new values current_settings = load_diarization_settings() @@ -373,14 +421,14 @@ async def save_diarization_settings_controller(settings: dict): return { "message": "Diarization settings saved successfully", "settings": current_settings, - "status": "success" + "status": "success", } else: logger.warning("Settings save failed") return { "message": "Settings save failed", "settings": current_settings, - "status": "error" + "status": "error", } except Exception as e: @@ -393,10 +441,7 @@ async def get_misc_settings(): try: # Get settings using OmegaConf settings = load_misc_settings() - return { - "settings": settings, - "status": "success" - } + return {"settings": settings, "status": "success"} except Exception as e:
logger.exception("Error getting misc settings") raise e @@ -406,7 +451,12 @@ async def save_misc_settings_controller(settings: dict): """Save miscellaneous settings.""" try: # Validate settings - boolean_keys = {"always_persist_enabled", "use_provider_segments", "per_segment_speaker_id", "always_batch_retranscribe"} + boolean_keys = { + "always_persist_enabled", + "use_provider_segments", + "per_segment_speaker_id", + "always_batch_retranscribe", + } integer_keys = {"transcription_job_timeout_seconds"} valid_keys = boolean_keys | integer_keys @@ -419,16 +469,24 @@ async def save_misc_settings_controller(settings: dict): # Type validation if key in boolean_keys: if not isinstance(value, bool): - raise HTTPException(status_code=400, detail=f"Invalid value for {key}: must be boolean") + raise HTTPException( + status_code=400, + detail=f"Invalid value for {key}: must be boolean", + ) elif key == "transcription_job_timeout_seconds": if not isinstance(value, int) or value < 60 or value > 7200: - raise HTTPException(status_code=400, detail=f"Invalid value for {key}: must be integer between 60 and 7200") + raise HTTPException( + status_code=400, + detail=f"Invalid value for {key}: must be integer between 60 and 7200", + ) filtered_settings[key] = value # Reject if NO valid keys provided if not filtered_settings: - raise HTTPException(status_code=400, detail="No valid misc settings provided") + raise HTTPException( + status_code=400, detail="No valid misc settings provided" + ) # Save using OmegaConf if save_misc_settings(filtered_settings): @@ -439,14 +497,14 @@ async def save_misc_settings_controller(settings: dict): return { "message": "Miscellaneous settings saved successfully", "settings": updated_settings, - "status": "success" + "status": "success", } else: logger.warning("Settings save failed") return { "message": "Settings save failed", "settings": load_misc_settings(), - "status": "error" + "status": "error", } except HTTPException: @@ -472,9 +530,7 @@ async def get_cleanup_settings_controller(user: User) -> dict: async def save_cleanup_settings_controller( - auto_cleanup_enabled: bool, - retention_days: int, - user: User + auto_cleanup_enabled: bool, retention_days: int, user: User ) -> dict: """ Save cleanup settings (admin only). 
@@ -504,19 +560,20 @@ async def save_cleanup_settings_controller( # Create settings object settings = CleanupSettings( - auto_cleanup_enabled=auto_cleanup_enabled, - retention_days=retention_days + auto_cleanup_enabled=auto_cleanup_enabled, retention_days=retention_days ) # Save using OmegaConf save_cleanup_settings(settings) - logger.info(f"Admin {user.email} updated cleanup settings: auto_cleanup={auto_cleanup_enabled}, retention={retention_days}d") + logger.info( + f"Admin {user.email} updated cleanup settings: auto_cleanup={auto_cleanup_enabled}, retention={retention_days}d" + ) return { "auto_cleanup_enabled": settings.auto_cleanup_enabled, "retention_days": settings.retention_days, - "message": "Cleanup settings saved successfully" + "message": "Cleanup settings saved successfully", } @@ -526,7 +583,7 @@ async def get_speaker_configuration(user: User): return { "primary_speakers": user.primary_speakers, "user_id": user.user_id, - "status": "success" + "status": "success", } except Exception as e: logger.exception(f"Error getting speaker configuration for user {user.user_id}") @@ -540,32 +597,36 @@ async def update_speaker_configuration(user: User, primary_speakers: list[dict]) for speaker in primary_speakers: if not isinstance(speaker, dict): raise ValueError("Each speaker must be a dictionary") - + required_fields = ["speaker_id", "name", "user_id"] for field in required_fields: if field not in speaker: raise ValueError(f"Missing required field: {field}") - + # Enforce server-side user_id and add timestamp to each speaker for speaker in primary_speakers: speaker["user_id"] = user.user_id # Override client-supplied user_id speaker["selected_at"] = datetime.now(UTC).isoformat() - + # Update user model user.primary_speakers = primary_speakers await user.save() - - logger.info(f"Updated primary speakers configuration for user {user.user_id}: {len(primary_speakers)} speakers") - + + logger.info( + f"Updated primary speakers configuration for user {user.user_id}: {len(primary_speakers)} speakers" + ) + return { "message": "Primary speakers configuration updated successfully", "primary_speakers": primary_speakers, "count": len(primary_speakers), - "status": "success" + "status": "success", } - + except Exception as e: - logger.exception(f"Error updating speaker configuration for user {user.user_id}") + logger.exception( + f"Error updating speaker configuration for user {user.user_id}" + ) raise e @@ -578,25 +639,25 @@ async def get_enrolled_speakers(user: User): # Initialize speaker recognition client speaker_client = SpeakerRecognitionClient() - + if not speaker_client.enabled: return { "speakers": [], "service_available": False, "message": "Speaker recognition service is not configured or disabled", - "status": "success" + "status": "success", } - + # Get enrolled speakers - using hardcoded user_id=1 for now (as noted in speaker_recognition_client.py) speakers = await speaker_client.get_enrolled_speakers(user_id="1") - + return { "speakers": speakers.get("speakers", []) if speakers else [], "service_available": True, "message": "Successfully retrieved enrolled speakers", - "status": "success" + "status": "success", } - + except Exception as e: logger.exception(f"Error getting enrolled speakers for user {user.user_id}") raise e @@ -611,25 +672,25 @@ async def get_speaker_service_status(): # Initialize speaker recognition client speaker_client = SpeakerRecognitionClient() - + if not speaker_client.enabled: return { "service_available": False, "healthy": False, "message": "Speaker 
recognition service is not configured or disabled", - "status": "disabled" + "status": "disabled", } - + # Perform health check health_result = await speaker_client.health_check() - + if health_result: return { "service_available": True, "healthy": True, "message": "Speaker recognition service is healthy", "service_url": speaker_client.service_url, - "status": "healthy" + "status": "healthy", } else: return { @@ -637,17 +698,17 @@ async def get_speaker_service_status(): "healthy": False, "message": "Speaker recognition service is not responding", "service_url": speaker_client.service_url, - "status": "unhealthy" + "status": "unhealthy", } - + except Exception as e: logger.exception("Error checking speaker service status") raise e - # Memory Configuration Management Functions + async def get_memory_config_raw(): """Get current memory configuration (memory section of config.yml) as YAML.""" try: @@ -655,7 +716,7 @@ async def get_memory_config_raw(): if not os.path.exists(cfg_path): raise FileNotFoundError(f"Config file not found: {cfg_path}") - with open(cfg_path, 'r') as f: + with open(cfg_path, "r") as f: data = _yaml.load(f) or {} memory_section = data.get("memory", {}) stream = StringIO() @@ -691,10 +752,10 @@ async def update_memory_config_raw(config_yaml: str): shutil.copy2(cfg_path, backup_path) # Update memory section and write file - with open(cfg_path, 'r') as f: + with open(cfg_path, "r") as f: data = _yaml.load(f) or {} data["memory"] = new_mem - with open(cfg_path, 'w') as f: + with open(cfg_path, "w") as f: _yaml.dump(data, f) # Reload registry @@ -717,9 +778,13 @@ async def validate_memory_config(config_yaml: str): try: parsed = _yaml.load(config_yaml) except Exception as e: - raise HTTPException(status_code=400, detail=f"Invalid YAML syntax: {str(e)}") + raise HTTPException( + status_code=400, detail=f"Invalid YAML syntax: {str(e)}" + ) if not isinstance(parsed, dict): - raise HTTPException(status_code=400, detail="Configuration must be a YAML object") + raise HTTPException( + status_code=400, detail="Configuration must be a YAML object" + ) # Minimal checks # provider optional; timeout_seconds optional; extraction enabled/prompt optional return {"message": "Configuration is valid", "status": "success"} @@ -728,7 +793,9 @@ async def validate_memory_config(config_yaml: str): raise except Exception as e: logger.exception("Error validating memory config") - raise HTTPException(status_code=500, detail=f"Error validating memory config: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Error validating memory config: {str(e)}" + ) async def reload_memory_config(): @@ -736,7 +803,11 @@ async def reload_memory_config(): try: cfg_path = _find_config_path() load_models_config(force_reload=True) - return {"message": "Configuration reloaded", "config_path": str(cfg_path), "status": "success"} + return { + "message": "Configuration reloaded", + "config_path": str(cfg_path), + "status": "success", + } except Exception as e: logger.exception("Error reloading config") raise e @@ -758,7 +829,7 @@ async def delete_all_user_memories(user: User): "message": f"Successfully deleted {deleted_count} memories", "deleted_count": deleted_count, "user_id": user.user_id, - "status": "success" + "status": "success", } except Exception as e: @@ -768,6 +839,7 @@ async def delete_all_user_memories(user: User): # Memory Provider Configuration Functions + async def get_memory_provider(): """Get current memory provider configuration.""" try: @@ -782,7 +854,7 @@ async def get_memory_provider(): 
return { "current_provider": current_provider, "available_providers": available_providers, - "status": "success" + "status": "success", } except Exception as e: @@ -798,7 +870,9 @@ async def set_memory_provider(provider: str): valid_providers = ["chronicle", "openmemory_mcp"] if provider not in valid_providers: - raise ValueError(f"Invalid provider '{provider}'. Valid providers: {', '.join(valid_providers)}") + raise ValueError( + f"Invalid provider '{provider}'. Valid providers: {', '.join(valid_providers)}" + ) # Path to .env file (assuming we're running from backends/advanced/) env_path = os.path.join(os.getcwd(), ".env") @@ -807,7 +881,7 @@ async def set_memory_provider(provider: str): raise FileNotFoundError(f".env file not found at {env_path}") # Read current .env file - with open(env_path, 'r') as file: + with open(env_path, "r") as file: lines = file.readlines() # Update or add MEMORY_PROVIDER line @@ -823,7 +897,9 @@ async def set_memory_provider(provider: str): # If MEMORY_PROVIDER wasn't found, add it if not provider_found: - updated_lines.append(f"\n# Memory Provider Configuration\nMEMORY_PROVIDER={provider}\n") + updated_lines.append( + f"\n# Memory Provider Configuration\nMEMORY_PROVIDER={provider}\n" + ) # Create backup backup_path = f"{env_path}.bak" @@ -831,7 +907,7 @@ async def set_memory_provider(provider: str): logger.info(f"Created .env backup at {backup_path}") # Write updated .env file - with open(env_path, 'w') as file: + with open(env_path, "w") as file: file.writelines(updated_lines) # Update environment variable for current process @@ -845,7 +921,7 @@ async def set_memory_provider(provider: str): "env_path": env_path, "backup_created": True, "requires_restart": True, - "status": "success" + "status": "success", } except Exception as e: @@ -855,6 +931,7 @@ async def set_memory_provider(provider: str): # LLM Operations Configuration Functions + async def get_llm_operations(): """Get LLM operation configurations and available models.""" try: @@ -902,29 +979,49 @@ async def save_llm_operations(operations: dict): for op_name, op_value in operations.items(): if not isinstance(op_value, dict): - raise HTTPException(status_code=400, detail=f"Operation '{op_name}' must be a dict") + raise HTTPException( + status_code=400, detail=f"Operation '{op_name}' must be a dict" + ) extra_keys = set(op_value.keys()) - valid_keys if extra_keys: - raise HTTPException(status_code=400, detail=f"Invalid keys for '{op_name}': {extra_keys}") + raise HTTPException( + status_code=400, + detail=f"Invalid keys for '{op_name}': {extra_keys}", + ) if "temperature" in op_value and op_value["temperature"] is not None: t = op_value["temperature"] if not isinstance(t, (int, float)) or t < 0 or t > 2: - raise HTTPException(status_code=400, detail=f"Invalid temperature for '{op_name}': must be 0-2") + raise HTTPException( + status_code=400, + detail=f"Invalid temperature for '{op_name}': must be 0-2", + ) if "max_tokens" in op_value and op_value["max_tokens"] is not None: mt = op_value["max_tokens"] if not isinstance(mt, int) or mt <= 0: - raise HTTPException(status_code=400, detail=f"Invalid max_tokens for '{op_name}': must be positive int") + raise HTTPException( + status_code=400, + detail=f"Invalid max_tokens for '{op_name}': must be positive int", + ) if "model" in op_value and op_value["model"] is not None: if not registry.get_by_name(op_value["model"]): - raise HTTPException(status_code=400, detail=f"Model '{op_value['model']}' not found in registry") - - if "response_format" in op_value and 
op_value["response_format"] is not None: + raise HTTPException( + status_code=400, + detail=f"Model '{op_value['model']}' not found in registry", + ) + + if ( + "response_format" in op_value + and op_value["response_format"] is not None + ): if op_value["response_format"] != "json": - raise HTTPException(status_code=400, detail=f"response_format must be 'json' or null") + raise HTTPException( + status_code=400, + detail=f"response_format must be 'json' or null", + ) if save_config_section("llm_operations", operations): load_models_config(force_reload=True) @@ -958,11 +1055,21 @@ async def test_llm_model(model_name: Optional[str]): if model_name: model_def = registry.get_by_name(model_name) if not model_def: - return {"success": False, "model_name": model_name, "error": f"Model '{model_name}' not found", "status": "error"} + return { + "success": False, + "model_name": model_name, + "error": f"Model '{model_name}' not found", + "status": "error", + } else: model_def = registry.get_default("llm") if not model_def: - return {"success": False, "model_name": None, "error": "No default LLM configured", "status": "error"} + return { + "success": False, + "model_name": None, + "error": "No default LLM configured", + "status": "error", + } client = create_openai_client( api_key=model_def.api_key or "", @@ -998,6 +1105,7 @@ async def test_llm_model(model_name: Optional[str]): # Chat Configuration Management Functions + async def get_chat_config_yaml() -> str: """Get chat system prompt as plain text.""" try: @@ -1012,11 +1120,11 @@ async def get_chat_config_yaml() -> str: if not os.path.exists(config_path): return default_prompt - with open(config_path, 'r') as f: + with open(config_path, "r") as f: full_config = _yaml.load(f) or {} - chat_config = full_config.get('chat', {}) - system_prompt = chat_config.get('system_prompt', default_prompt) + chat_config = full_config.get("chat", {}) + system_prompt = chat_config.get("system_prompt", default_prompt) # Return just the prompt text, not the YAML structure return system_prompt @@ -1042,26 +1150,26 @@ async def save_chat_config_yaml(prompt_text: str) -> dict: raise ValueError("Prompt too long (maximum 10000 characters)") # Create chat config dict - chat_config = {'system_prompt': prompt_text} + chat_config = {"system_prompt": prompt_text} # Load full config if os.path.exists(config_path): - with open(config_path, 'r') as f: + with open(config_path, "r") as f: full_config = _yaml.load(f) or {} else: full_config = {} # Backup existing config if os.path.exists(config_path): - backup_path = str(config_path) + '.backup' + backup_path = str(config_path) + ".backup" shutil.copy2(config_path, backup_path) logger.info(f"Created config backup at {backup_path}") # Update chat section - full_config['chat'] = chat_config + full_config["chat"] = chat_config # Save - with open(config_path, 'w') as f: + with open(config_path, "w") as f: _yaml.dump(full_config, f) # Reload config in memory (hot-reload) @@ -1087,7 +1195,10 @@ async def validate_chat_config_yaml(prompt_text: str) -> dict: if len(prompt_text) < 10: return {"valid": False, "error": "Prompt too short (minimum 10 characters)"} if len(prompt_text) > 10000: - return {"valid": False, "error": "Prompt too long (maximum 10000 characters)"} + return { + "valid": False, + "error": "Prompt too long (maximum 10000 characters)", + } return {"valid": True, "message": "Configuration is valid"} @@ -1098,6 +1209,7 @@ async def validate_chat_config_yaml(prompt_text: str) -> dict: # Plugin Configuration Management Functions + 
async def get_plugins_config_yaml() -> str: """Get plugins configuration as YAML text.""" try: @@ -1120,7 +1232,7 @@ async def get_plugins_config_yaml() -> str: if not plugins_yml_path.exists(): return default_config - with open(plugins_yml_path, 'r') as f: + with open(plugins_yml_path, "r") as f: yaml_content = f.read() return yaml_content @@ -1142,7 +1254,7 @@ async def save_plugins_config_yaml(yaml_content: str) -> dict: raise ValueError("Configuration must be a YAML dictionary") # Validate has 'plugins' key - if 'plugins' not in parsed_config: + if "plugins" not in parsed_config: raise ValueError("Configuration must contain 'plugins' key") except ValueError: @@ -1155,12 +1267,12 @@ async def save_plugins_config_yaml(yaml_content: str) -> dict: # Backup existing config if plugins_yml_path.exists(): - backup_path = str(plugins_yml_path) + '.backup' + backup_path = str(plugins_yml_path) + ".backup" shutil.copy2(plugins_yml_path, backup_path) logger.info(f"Created plugins config backup at {backup_path}") # Save new config - with open(plugins_yml_path, 'w') as f: + with open(plugins_yml_path, "w") as f: f.write(yaml_content) # Hot-reload plugins and signal worker restart @@ -1201,35 +1313,55 @@ async def validate_plugins_config_yaml(yaml_content: str) -> dict: if not isinstance(parsed_config, dict): return {"valid": False, "error": "Configuration must be a YAML dictionary"} - if 'plugins' not in parsed_config: + if "plugins" not in parsed_config: return {"valid": False, "error": "Configuration must contain 'plugins' key"} - plugins = parsed_config['plugins'] + plugins = parsed_config["plugins"] if not isinstance(plugins, dict): return {"valid": False, "error": "'plugins' must be a dictionary"} # Validate each plugin - valid_access_levels = ['transcript', 'conversation', 'memory'] - valid_trigger_types = ['wake_word', 'always', 'conditional'] + valid_access_levels = ["transcript", "conversation", "memory"] + valid_trigger_types = ["wake_word", "always", "conditional"] for plugin_id, plugin_config in plugins.items(): if not isinstance(plugin_config, dict): - return {"valid": False, "error": f"Plugin '{plugin_id}' config must be a dictionary"} + return { + "valid": False, + "error": f"Plugin '{plugin_id}' config must be a dictionary", + } # Check required fields - if 'enabled' in plugin_config and not isinstance(plugin_config['enabled'], bool): - return {"valid": False, "error": f"Plugin '{plugin_id}': 'enabled' must be boolean"} + if "enabled" in plugin_config and not isinstance( + plugin_config["enabled"], bool + ): + return { + "valid": False, + "error": f"Plugin '{plugin_id}': 'enabled' must be boolean", + } - if 'access_level' in plugin_config and plugin_config['access_level'] not in valid_access_levels: - return {"valid": False, "error": f"Plugin '{plugin_id}': invalid access_level (must be one of {valid_access_levels})"} + if ( + "access_level" in plugin_config + and plugin_config["access_level"] not in valid_access_levels + ): + return { + "valid": False, + "error": f"Plugin '{plugin_id}': invalid access_level (must be one of {valid_access_levels})", + } - if 'trigger' in plugin_config: - trigger = plugin_config['trigger'] + if "trigger" in plugin_config: + trigger = plugin_config["trigger"] if not isinstance(trigger, dict): - return {"valid": False, "error": f"Plugin '{plugin_id}': 'trigger' must be a dictionary"} + return { + "valid": False, + "error": f"Plugin '{plugin_id}': 'trigger' must be a dictionary", + } - if 'type' in trigger and trigger['type'] not in valid_trigger_types: - 
return {"valid": False, "error": f"Plugin '{plugin_id}': invalid trigger type (must be one of {valid_trigger_types})"} + if "type" in trigger and trigger["type"] not in valid_trigger_types: + return { + "valid": False, + "error": f"Plugin '{plugin_id}': invalid trigger type (must be one of {valid_trigger_types})", + } return {"valid": True, "message": "Configuration is valid"} @@ -1314,9 +1446,11 @@ async def reload_plugins_controller(app=None) -> dict: return { "success": reload_result.get("success", False), - "message": "Plugins reloaded and worker restart signaled" - if worker_signal_sent - else "Plugins reloaded but worker restart signal failed", + "message": ( + "Plugins reloaded and worker restart signaled" + if worker_signal_sent + else "Plugins reloaded but worker restart signal failed" + ), "reload": reload_result, "worker_signal_sent": worker_signal_sent, } @@ -1324,6 +1458,7 @@ async def reload_plugins_controller(app=None) -> dict: # Structured Plugin Configuration Management Functions (Form-based UI) + async def get_plugins_metadata() -> dict: """Get plugin metadata for form-based configuration UI. @@ -1350,30 +1485,28 @@ async def get_plugins_metadata() -> dict: orchestration_configs = {} if plugins_yml_path.exists(): - with open(plugins_yml_path, 'r') as f: + with open(plugins_yml_path, "r") as f: plugins_data = _yaml.load(f) or {} - orchestration_configs = plugins_data.get('plugins', {}) + orchestration_configs = plugins_data.get("plugins", {}) # Build metadata for each plugin plugins_metadata = [] for plugin_id, plugin_class in discovered_plugins.items(): # Get orchestration config (or empty dict if not configured) - orchestration_config = orchestration_configs.get(plugin_id, { - 'enabled': False, - 'events': [], - 'condition': {'type': 'always'} - }) + orchestration_config = orchestration_configs.get( + plugin_id, + {"enabled": False, "events": [], "condition": {"type": "always"}}, + ) # Get complete metadata including schema - metadata = get_plugin_metadata(plugin_id, plugin_class, orchestration_config) + metadata = get_plugin_metadata( + plugin_id, plugin_class, orchestration_config + ) plugins_metadata.append(metadata) logger.info(f"Retrieved metadata for {len(plugins_metadata)} plugins") - return { - "plugins": plugins_metadata, - "status": "success" - } + return {"plugins": plugins_metadata, "status": "success"} except Exception as e: logger.exception("Error getting plugins metadata") @@ -1396,7 +1529,10 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict: Success message with list of updated files """ try: - from advanced_omi_backend.services.plugin_service import _get_plugins_dir, discover_plugins + from advanced_omi_backend.services.plugin_service import ( + _get_plugins_dir, + discover_plugins, + ) # Validate plugin exists discovered_plugins = discover_plugins() @@ -1406,84 +1542,87 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict: updated_files = [] # 1. 
Update config/plugins.yml (orchestration) - if 'orchestration' in config: + if "orchestration" in config: plugins_yml_path = get_plugins_yml_path() # Load current plugins.yml if plugins_yml_path.exists(): - with open(plugins_yml_path, 'r') as f: + with open(plugins_yml_path, "r") as f: plugins_data = _yaml.load(f) or {} else: plugins_data = {} - if 'plugins' not in plugins_data: - plugins_data['plugins'] = {} + if "plugins" not in plugins_data: + plugins_data["plugins"] = {} # Update orchestration config - orchestration = config['orchestration'] - plugins_data['plugins'][plugin_id] = { - 'enabled': orchestration.get('enabled', False), - 'events': orchestration.get('events', []), - 'condition': orchestration.get('condition', {'type': 'always'}) + orchestration = config["orchestration"] + plugins_data["plugins"][plugin_id] = { + "enabled": orchestration.get("enabled", False), + "events": orchestration.get("events", []), + "condition": orchestration.get("condition", {"type": "always"}), } # Create backup if plugins_yml_path.exists(): - backup_path = str(plugins_yml_path) + '.backup' + backup_path = str(plugins_yml_path) + ".backup" shutil.copy2(plugins_yml_path, backup_path) # Create config directory if needed plugins_yml_path.parent.mkdir(parents=True, exist_ok=True) # Write updated plugins.yml - with open(plugins_yml_path, 'w') as f: + with open(plugins_yml_path, "w") as f: _yaml.dump(plugins_data, f) updated_files.append(str(plugins_yml_path)) - logger.info(f"Updated orchestration config for '{plugin_id}' in {plugins_yml_path}") + logger.info( + f"Updated orchestration config for '{plugin_id}' in {plugins_yml_path}" + ) # 2. Update plugins/{plugin_id}/config.yml (settings with env var references) - if 'settings' in config: + if "settings" in config: plugins_dir = _get_plugins_dir() plugin_config_path = plugins_dir / plugin_id / "config.yml" # Load current config.yml if plugin_config_path.exists(): - with open(plugin_config_path, 'r') as f: + with open(plugin_config_path, "r") as f: plugin_config_data = _yaml.load(f) or {} else: plugin_config_data = {} # Update settings (preserve ${ENV_VAR} references) - settings = config['settings'] + settings = config["settings"] plugin_config_data.update(settings) # Create backup if plugin_config_path.exists(): - backup_path = str(plugin_config_path) + '.backup' + backup_path = str(plugin_config_path) + ".backup" shutil.copy2(plugin_config_path, backup_path) # Write updated config.yml - with open(plugin_config_path, 'w') as f: + with open(plugin_config_path, "w") as f: _yaml.dump(plugin_config_data, f) updated_files.append(str(plugin_config_path)) logger.info(f"Updated settings for '{plugin_id}' in {plugin_config_path}") # 3. 
Update per-plugin .env (only changed env vars) - if 'env_vars' in config and config['env_vars']: + if "env_vars" in config and config["env_vars"]: from advanced_omi_backend.services.plugin_service import save_plugin_env # Filter out masked values (unchanged secrets) changed_vars = { - k: v for k, v in config['env_vars'].items() - if v != '••••••••••••' + k: v for k, v in config["env_vars"].items() if v != "••••••••••••" } if changed_vars: env_path = save_plugin_env(plugin_id, changed_vars) updated_files.append(str(env_path)) - logger.info(f"Saved {len(changed_vars)} env var(s) to per-plugin .env for '{plugin_id}'") + logger.info( + f"Saved {len(changed_vars)} env var(s) to per-plugin .env for '{plugin_id}'" + ) # Update os.environ so hot-reload picks up changes immediately for k, v in changed_vars.items(): @@ -1496,7 +1635,9 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict: except Exception as reload_err: logger.warning(f"Auto-reload failed, manual restart needed: {reload_err}") - message = f"Plugin '{plugin_id}' configuration updated and reloaded successfully." + message = ( + f"Plugin '{plugin_id}' configuration updated and reloaded successfully." + ) if reload_result is None: message = f"Plugin '{plugin_id}' configuration updated. Restart backend for changes to take effect." @@ -1505,7 +1646,7 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict: "message": message, "updated_files": updated_files, "reload": reload_result, - "status": "success" + "status": "success", } except Exception as e: @@ -1541,29 +1682,29 @@ async def test_plugin_connection(plugin_id: str, config: dict) -> dict: plugin_class = discovered_plugins[plugin_id] # Check if plugin supports testing - if not hasattr(plugin_class, 'test_connection'): + if not hasattr(plugin_class, "test_connection"): return { "success": False, "message": f"Plugin '{plugin_id}' does not support connection testing", - "status": "unsupported" + "status": "unsupported", } # Build complete config from provided data test_config = {} # Merge settings - if 'settings' in config: - test_config.update(config['settings']) + if "settings" in config: + test_config.update(config["settings"]) # Load per-plugin env for resolving masked values plugin_env = load_plugin_env(plugin_id) # Add env vars (expand any ${ENV_VAR} references with test values) - if 'env_vars' in config: - for key, value in config['env_vars'].items(): + if "env_vars" in config: + for key, value in config["env_vars"].items(): # For masked values, resolve from per-plugin .env then os.environ - if value == '••••••••••••': - value = plugin_env.get(key) or os.getenv(key, '') + if value == "••••••••••••": + value = plugin_env.get(key) or os.getenv(key, "") test_config[key.lower()] = value # Expand any remaining env var references @@ -1572,7 +1713,9 @@ async def test_plugin_connection(plugin_id: str, config: dict) -> dict: # Call plugin's test_connection static method result = await plugin_class.test_connection(test_config) - logger.info(f"Test connection for '{plugin_id}': {result.get('message', 'No message')}") + logger.info( + f"Test connection for '{plugin_id}': {result.get('message', 'No message')}" + ) return result @@ -1581,12 +1724,13 @@ async def test_plugin_connection(plugin_id: str, config: dict) -> dict: return { "success": False, "message": f"Connection test failed: {str(e)}", - "status": "error" + "status": "error", } #
Plugin Lifecycle Management Functions (create / write-code / delete) + def _snake_to_pascal(snake_str: str) -> str: """Convert snake_case to PascalCase.""" return "".join(word.capitalize() for word in snake_str.split("_")) @@ -1615,25 +1759,40 @@ async def create_plugin( Returns: Success dict with plugin_id and created_files list """ - from advanced_omi_backend.services.plugin_service import _get_plugins_dir, discover_plugins + from advanced_omi_backend.services.plugin_service import ( + _get_plugins_dir, + discover_plugins, + ) # Validate name if not plugin_name.replace("_", "").isalnum(): - return {"success": False, "error": "Plugin name must be alphanumeric with underscores only"} + return { + "success": False, + "error": "Plugin name must be alphanumeric with underscores only", + } if not re.match(r"^[a-z][a-z0-9_]*$", plugin_name): - return {"success": False, "error": "Plugin name must be lowercase snake_case starting with a letter"} + return { + "success": False, + "error": "Plugin name must be lowercase snake_case starting with a letter", + } plugins_dir = _get_plugins_dir() plugin_dir = plugins_dir / plugin_name # Collision check if plugin_dir.exists(): - return {"success": False, "error": f"Plugin '{plugin_name}' already exists at {plugin_dir}"} + return { + "success": False, + "error": f"Plugin '{plugin_name}' already exists at {plugin_dir}", + } discovered = discover_plugins() if plugin_name in discovered: - return {"success": False, "error": f"Plugin '{plugin_name}' is already registered"} + return { + "success": False, + "error": f"Plugin '{plugin_name}' is already registered", + } class_name = _snake_to_pascal(plugin_name) + "Plugin" created_files: list[str] = [] @@ -1650,8 +1809,14 @@ async def create_plugin( (plugin_dir / "plugin.py").write_text(plugin_code, encoding="utf-8") else: # Write standard boilerplate - events_str = ", ".join(f'"{e}"' for e in events) if events else '"conversation.complete"' - boilerplate = inspect.cleandoc(f''' + events_str = ( + ", ".join(f'"{e}"' for e in events) + if events + else '"conversation.complete"' + ) + boilerplate = ( + inspect.cleandoc( + f''' """ {class_name} implementation. 
@@ -1688,7 +1853,10 @@ async def cleanup(self): async def on_conversation_complete(self, context: PluginContext) -> Optional[PluginResult]: logger.info(f"Processing conversation for user: {{context.user_id}}") return PluginResult(success=True, message="OK") - ''') + "\n" + ''' + ) + + "\n" + ) (plugin_dir / "plugin.py").write_text(boilerplate, encoding="utf-8") created_files.append("plugin.py") @@ -1699,7 +1867,7 @@ async def on_conversation_complete(self, context: PluginContext) -> Optional[Plu # config.yml config_yml = {"description": description} - with open(plugin_dir / "config.yml", 'w', encoding="utf-8") as f: + with open(plugin_dir / "config.yml", "w", encoding="utf-8") as f: _yaml.dump(config_yml, f) created_files.append("config.yml") @@ -1765,7 +1933,10 @@ async def write_plugin_code( plugin_dir = plugins_dir / plugin_id if not plugin_dir.exists(): - return {"success": False, "error": f"Plugin '{plugin_id}' not found at {plugin_dir}"} + return { + "success": False, + "error": f"Plugin '{plugin_id}' not found at {plugin_dir}", + } updated_files: list[str] = [] @@ -1848,9 +2019,14 @@ async def delete_plugin(plugin_id: str, remove_files: bool = False) -> dict: logger.info(f"Removed plugin directory: {plugin_dir}") if not removed_from_yml and not files_removed: - return {"success": False, "error": f"Plugin '{plugin_id}' not found in plugins.yml or on disk"} + return { + "success": False, + "error": f"Plugin '{plugin_id}' not found in plugins.yml or on disk", + } - logger.info(f"Deleted plugin '{plugin_id}' (yml={removed_from_yml}, files={files_removed})") + logger.info( + f"Deleted plugin '{plugin_id}' (yml={removed_from_yml}, files={files_removed})" + ) return { "success": True, "plugin_id": plugin_id, diff --git a/backends/advanced/src/advanced_omi_backend/controllers/user_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/user_controller.py index ce801327..a0c10e4f 100644 --- a/backends/advanced/src/advanced_omi_backend/controllers/user_controller.py +++ b/backends/advanced/src/advanced_omi_backend/controllers/user_controller.py @@ -9,11 +9,7 @@ from fastapi import HTTPException from fastapi.responses import JSONResponse -from advanced_omi_backend.auth import ( - ADMIN_EMAIL, - UserManager, - get_user_db, -) +from advanced_omi_backend.auth import ADMIN_EMAIL, UserManager, get_user_db from advanced_omi_backend.client_manager import get_user_clients_all from advanced_omi_backend.database import db, users_col from advanced_omi_backend.models.conversation import Conversation @@ -50,7 +46,9 @@ async def create_user(user_data: UserCreate): # If we get here, user exists return JSONResponse( status_code=409, - content={"message": f"User with email {user_data.email} already exists"}, + content={ + "message": f"User with email {user_data.email} already exists" + }, ) except Exception: # User doesn't exist, continue with creation @@ -61,15 +59,17 @@ async def create_user(user_data: UserCreate): # Return the full user object (serialized via UserRead schema) from advanced_omi_backend.models.user import UserRead + user_read = UserRead.model_validate(user) return JSONResponse( status_code=201, - content=user_read.model_dump(mode='json'), + content=user_read.model_dump(mode="json"), ) except Exception as e: import traceback + error_details = traceback.format_exc() logger.error(f"Error creating user: {e}") logger.error(f"Full traceback: {error_details}") @@ -89,15 +89,16 @@ async def update_user(user_id: str, user_data: UserUpdate): logger.error(f"Invalid ObjectId format for 
user_id {user_id}: {e}") return JSONResponse( status_code=400, - content={"message": f"Invalid user_id format: {user_id}. Must be a valid ObjectId."}, + content={ + "message": f"Invalid user_id format: {user_id}. Must be a valid ObjectId." + }, ) # Check if user exists existing_user = await users_col.find_one({"_id": object_id}) if not existing_user: return JSONResponse( - status_code=404, - content={"message": f"User {user_id} not found"} + status_code=404, content={"message": f"User {user_id} not found"} ) # Get user database and create user manager @@ -114,15 +115,17 @@ async def update_user(user_id: str, user_data: UserUpdate): # Return the full user object (serialized via UserRead schema) from advanced_omi_backend.models.user import UserRead + user_read = UserRead.model_validate(updated_user) return JSONResponse( status_code=200, - content=user_read.model_dump(mode='json'), + content=user_read.model_dump(mode="json"), ) except Exception as e: import traceback + error_details = traceback.format_exc() logger.error(f"Error updating user: {e}") logger.error(f"Full traceback: {error_details}") @@ -154,7 +157,9 @@ async def delete_user( # Check if user exists existing_user = await users_col.find_one({"_id": object_id}) if not existing_user: - return JSONResponse(status_code=404, content={"message": f"User {user_id} not found"}) + return JSONResponse( + status_code=404, content={"message": f"User {user_id} not found"} + ) # Prevent deletion of administrator user user_email = existing_user.get("email", "") @@ -176,7 +181,9 @@ async def delete_user( if delete_conversations: # Delete all conversations for this user - conversations_result = await Conversation.find(Conversation.user_id == user_id).delete() + conversations_result = await Conversation.find( + Conversation.user_id == user_id + ).delete() deleted_data["conversations_deleted"] = conversations_result.deleted_count if delete_memories: @@ -196,7 +203,9 @@ async def delete_user( message = f"User {user_id} deleted successfully" deleted_items = [] if delete_conversations and deleted_data.get("conversations_deleted", 0) > 0: - deleted_items.append(f"{deleted_data['conversations_deleted']} conversations") + deleted_items.append( + f"{deleted_data['conversations_deleted']} conversations" + ) if delete_memories and deleted_data.get("memories_deleted", 0) > 0: deleted_items.append(f"{deleted_data['memories_deleted']} memories") diff --git a/backends/advanced/src/advanced_omi_backend/cron_scheduler.py b/backends/advanced/src/advanced_omi_backend/cron_scheduler.py index a496516f..3e3f6b8b 100644 --- a/backends/advanced/src/advanced_omi_backend/cron_scheduler.py +++ b/backends/advanced/src/advanced_omi_backend/cron_scheduler.py @@ -32,6 +32,7 @@ # Data classes # --------------------------------------------------------------------------- + @dataclass class CronJobConfig: job_id: str @@ -66,6 +67,7 @@ def _get_job_func(job_id: str) -> Optional[JobFunc]: # Scheduler # --------------------------------------------------------------------------- + class CronScheduler: def __init__(self) -> None: self.jobs: Dict[str, CronJobConfig] = {} @@ -79,6 +81,7 @@ def __init__(self) -> None: async def start(self) -> None: """Load config, restore state from Redis, and start the scheduler loop.""" import os + redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") self._redis = aioredis.from_url(redis_url, decode_responses=True) @@ -129,7 +132,9 @@ async def update_job( if not croniter.is_valid(schedule): raise ValueError(f"Invalid cron expression: 
{schedule}") cfg.schedule = schedule - cfg.next_run = croniter(schedule, datetime.now(timezone.utc)).get_next(datetime) + cfg.next_run = croniter(schedule, datetime.now(timezone.utc)).get_next( + datetime + ) if enabled is not None: cfg.enabled = enabled @@ -137,7 +142,11 @@ # Persist changes to config.yml save_config_section( f"cron_jobs.{job_id}", - {"enabled": cfg.enabled, "schedule": cfg.schedule, "description": cfg.description}, + { + "enabled": cfg.enabled, + "schedule": cfg.schedule, + "description": cfg.description, + }, ) # Update next_run in Redis @@ -147,22 +156,26 @@ cfg.next_run.isoformat(), ) - logger.info(f"Updated cron job '{job_id}': enabled={cfg.enabled}, schedule={cfg.schedule}") + logger.info( + f"Updated cron job '{job_id}': enabled={cfg.enabled}, schedule={cfg.schedule}" + ) async def get_all_jobs_status(self) -> List[dict]: """Return status of all registered cron jobs.""" result = [] for job_id, cfg in self.jobs.items(): - result.append({ - "job_id": job_id, - "enabled": cfg.enabled, - "schedule": cfg.schedule, - "description": cfg.description, - "last_run": cfg.last_run.isoformat() if cfg.last_run else None, - "next_run": cfg.next_run.isoformat() if cfg.next_run else None, - "running": cfg.running, - "last_error": cfg.last_error, - }) + result.append( + { + "job_id": job_id, + "enabled": cfg.enabled, + "schedule": cfg.schedule, + "description": cfg.description, + "last_run": cfg.last_run.isoformat() if cfg.last_run else None, + "next_run": cfg.next_run.isoformat() if cfg.next_run else None, + "running": cfg.running, + "last_error": cfg.last_error, + } + ) return result # -- internals ----------------------------------------------------------- @@ -175,7 +188,9 @@ def _load_jobs_from_config(self) -> None: for job_id, job_cfg in cron_section.items(): schedule = str(job_cfg.get("schedule", "0 * * * *")) if not croniter.is_valid(schedule): - logger.warning(f"Invalid cron expression for job '{job_id}': {schedule} — skipping") + logger.warning( + f"Invalid cron expression for job '{job_id}': {schedule} — skipping" + ) continue now = datetime.now(timezone.utc) self.jobs[job_id] = CronJobConfig( diff --git a/backends/advanced/src/advanced_omi_backend/database.py b/backends/advanced/src/advanced_omi_backend/database.py index 1b214b6d..392b22c8 100644 --- a/backends/advanced/src/advanced_omi_backend/database.py +++ b/backends/advanced/src/advanced_omi_backend/database.py @@ -44,5 +44,3 @@ def get_collections(): return { "users_col": users_col, } - - diff --git a/backends/advanced/src/advanced_omi_backend/llm_client.py b/backends/advanced/src/advanced_omi_backend/llm_client.py index 96ccc77b..4aaa223f 100644 --- a/backends/advanced/src/advanced_omi_backend/llm_client.py +++ b/backends/advanced/src/advanced_omi_backend/llm_client.py @@ -11,7 +11,10 @@ from typing import Any, Dict, Optional from advanced_omi_backend.model_registry import get_models_registry -from advanced_omi_backend.openai_factory import create_openai_client, is_langfuse_enabled +from advanced_omi_backend.openai_factory import ( + create_openai_client, + is_langfuse_enabled, +) from advanced_omi_backend.services.memory.config import ( load_config_yml as _load_root_config, ) @@ -62,7 +65,9 @@ def __init__( self.base_url = base_url self.model = model if not self.api_key or not self.base_url or not self.model: - raise ValueError(f"LLM configuration incomplete: api_key={'set' if self.api_key else 'MISSING'}, base_url={'set' if self.base_url else 'MISSING'}, 
model={'set' if self.model else 'MISSING'}") + raise ValueError( + f"LLM configuration incomplete: api_key={'set' if self.api_key else 'MISSING'}, base_url={'set' if self.base_url else 'MISSING'}, model={'set' if self.model else 'MISSING'}" + ) # Initialize OpenAI client with optional Langfuse tracing try: @@ -71,14 +76,19 @@ def __init__( ) self.logger.info(f"OpenAI client initialized, base_url: {self.base_url}") except ImportError: - self.logger.error("OpenAI library not installed. Install with: pip install openai") + self.logger.error( + "OpenAI library not installed. Install with: pip install openai" + ) raise except Exception as e: self.logger.error(f"Failed to initialize OpenAI client: {e}") raise def generate( - self, prompt: str, model: str | None = None, temperature: float | None = None, + self, + prompt: str, + model: str | None = None, + temperature: float | None = None, **langfuse_kwargs, ) -> str: """Generate text completion using OpenAI-compatible API.""" @@ -101,8 +111,12 @@ def generate( raise def chat_with_tools( - self, messages: list, tools: list | None = None, model: str | None = None, - temperature: float | None = None, **langfuse_kwargs, + self, + messages: list, + tools: list | None = None, + model: str | None = None, + temperature: float | None = None, + **langfuse_kwargs, ): """Chat completion with tool/function calling support. Returns raw response object.""" model_name = model or self.model @@ -127,14 +141,18 @@ def health_check(self) -> Dict: "status": "✅ Connected", "base_url": self.base_url, "default_model": self.model, - "api_key_configured": bool(self.api_key and self.api_key != "dummy"), + "api_key_configured": bool( + self.api_key and self.api_key != "dummy" + ), } else: return { "status": "⚠️ Configuration incomplete", "base_url": self.base_url, "default_model": self.model, - "api_key_configured": bool(self.api_key and self.api_key != "dummy"), + "api_key_configured": bool( + self.api_key and self.api_key != "dummy" + ), } except Exception as e: self.logger.error(f"Health check failed: {e}") @@ -157,11 +175,13 @@ class LLMClientFactory: def create_client() -> LLMClient: """Create an LLM client based on model registry configuration (config.yml).""" registry = get_models_registry() - + if registry: llm_def = registry.get_default("llm") if llm_def: - logger.info(f"Creating LLM client from registry: {llm_def.name} ({llm_def.model_provider})") + logger.info( + f"Creating LLM client from registry: {llm_def.name} ({llm_def.model_provider})" + ) params = llm_def.model_params or {} return OpenAILLMClient( api_key=llm_def.api_key, @@ -169,7 +189,7 @@ def create_client() -> LLMClient: model=llm_def.model_name, temperature=params.get("temperature", 0.1), ) - + raise ValueError("No default LLM defined in config.yml") @staticmethod diff --git a/backends/advanced/src/advanced_omi_backend/main.py b/backends/advanced/src/advanced_omi_backend/main.py index ee60696f..5b0f3f05 100644 --- a/backends/advanced/src/advanced_omi_backend/main.py +++ b/backends/advanced/src/advanced_omi_backend/main.py @@ -46,5 +46,5 @@ port=port, reload=False, # Set to True for development access_log=False, # Disabled - using custom RequestLoggingMiddleware instead - log_level="info" + log_level="info", ) diff --git a/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py b/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py index 069d5239..21fa98d9 100644 --- a/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py +++ 
b/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py @@ -26,7 +26,9 @@ def setup_cors_middleware(app: FastAPI) -> None: config = get_app_config() logger.info(f"🌐 CORS configured with origins: {config.allowed_origins}") - logger.info(f"🌐 CORS also allows Tailscale IPs via regex: {config.tailscale_regex}") + logger.info( + f"🌐 CORS also allows Tailscale IPs via regex: {config.tailscale_regex}" + ) app.add_middleware( CORSMiddleware, @@ -157,6 +159,7 @@ async def dispatch(self, request: Request, call_next): # Recreate response with the body we consumed from starlette.responses import Response + return Response( content=response_body, status_code=response.status_code, @@ -191,8 +194,8 @@ async def database_exception_handler(request: Request, exc: Exception): content={ "detail": "Unable to connect to server. Please check your connection and try again.", "error_type": "connection_failure", - "error_category": "database" - } + "error_category": "database", + }, ) @app.exception_handler(ConnectionError) @@ -204,8 +207,8 @@ async def connection_exception_handler(request: Request, exc: ConnectionError): content={ "detail": "Unable to connect to server. Please check your connection and try again.", "error_type": "connection_failure", - "error_category": "network" - } + "error_category": "network", + }, ) @app.exception_handler(HTTPException) @@ -218,15 +221,12 @@ async def http_exception_handler(request: Request, exc: HTTPException): content={ "detail": exc.detail, "error_type": "authentication_failure", - "error_category": "security" - } + "error_category": "security", + }, ) # For other HTTP exceptions, return as normal - return JSONResponse( - status_code=exc.status_code, - content={"detail": exc.detail} - ) + return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) def setup_middleware(app: FastAPI) -> None: @@ -236,4 +236,4 @@ def setup_middleware(app: FastAPI) -> None: logger.info("📝 Request logging middleware enabled") setup_cors_middleware(app) - setup_exception_handlers(app) \ No newline at end of file + setup_exception_handlers(app) diff --git a/backends/advanced/src/advanced_omi_backend/model_registry.py b/backends/advanced/src/advanced_omi_backend/model_registry.py index bc7e5fc5..e7a6ddf5 100644 --- a/backends/advanced/src/advanced_omi_backend/model_registry.py +++ b/backends/advanced/src/advanced_omi_backend/model_registry.py @@ -30,76 +30,95 @@ class ModelDef(BaseModel): """Model definition with validation. - + Represents a single model configuration (LLM, embedding, STT, TTS, etc.) from config.yml with automatic validation and type checking. """ - + model_config = ConfigDict( - extra='allow', # Allow extra fields for extensibility + extra="allow", # Allow extra fields for extensibility validate_assignment=True, # Validate on attribute assignment arbitrary_types_allowed=True, ) - + name: str = Field(..., min_length=1, description="Unique model identifier") - model_type: str = Field(..., description="Model type: llm, embedding, stt, tts, etc.") - model_provider: str = Field(default="unknown", description="Provider name: openai, ollama, deepgram, parakeet, vibevoice, etc.") - api_family: str = Field(default="openai", description="API family: openai, http, websocket, etc.") + model_type: str = Field( + ..., description="Model type: llm, embedding, stt, tts, etc." 
+ ) + model_provider: str = Field( + default="unknown", + description="Provider name: openai, ollama, deepgram, parakeet, vibevoice, etc.", + ) + api_family: str = Field( + default="openai", description="API family: openai, http, websocket, etc." + ) model_name: str = Field(default="", description="Provider-specific model name") model_url: str = Field(default="", description="Base URL for API requests") - api_key: Optional[str] = Field(default=None, description="API key or authentication token") - description: Optional[str] = Field(default=None, description="Human-readable description") - model_params: Dict[str, Any] = Field(default_factory=dict, description="Model-specific parameters") - model_output: Optional[str] = Field(default=None, description="Output format: json, text, vector, etc.") - embedding_dimensions: Optional[int] = Field(default=None, ge=1, description="Embedding vector dimensions") - operations: Dict[str, Any] = Field(default_factory=dict, description="API operation definitions") + api_key: Optional[str] = Field( + default=None, description="API key or authentication token" + ) + description: Optional[str] = Field( + default=None, description="Human-readable description" + ) + model_params: Dict[str, Any] = Field( + default_factory=dict, description="Model-specific parameters" + ) + model_output: Optional[str] = Field( + default=None, description="Output format: json, text, vector, etc." + ) + embedding_dimensions: Optional[int] = Field( + default=None, ge=1, description="Embedding vector dimensions" + ) + operations: Dict[str, Any] = Field( + default_factory=dict, description="API operation definitions" + ) capabilities: List[str] = Field( default_factory=list, - description="Provider capabilities: word_timestamps, segments, diarization (for STT providers)" + description="Provider capabilities: word_timestamps, segments, diarization (for STT providers)", ) - - @field_validator('model_name', mode='before') + + @field_validator("model_name", mode="before") @classmethod def default_model_name(cls, v: Any, info) -> str: """Default model_name to name if not provided.""" - if not v and info.data.get('name'): - return info.data['name'] + if not v and info.data.get("name"): + return info.data["name"] return v or "" - - @field_validator('model_url', mode='before') + + @field_validator("model_url", mode="before") @classmethod def validate_url(cls, v: Any) -> str: """Ensure URL doesn't have trailing whitespace.""" if isinstance(v, str): return v.strip() return v or "" - - @field_validator('api_key', mode='before') + + @field_validator("api_key", mode="before") @classmethod def sanitize_api_key(cls, v: Any) -> Optional[str]: """Sanitize API key, treat empty strings as None.""" if isinstance(v, str): v = v.strip() - if not v or v.lower() in ['dummy', 'none', 'null']: + if not v or v.lower() in ["dummy", "none", "null"]: return None return v return v - - @model_validator(mode='after') + + @model_validator(mode="after") def validate_model(self) -> ModelDef: """Cross-field validation.""" # Ensure embedding models have dimensions specified - if self.model_type == 'embedding' and not self.embedding_dimensions: + if self.model_type == "embedding" and not self.embedding_dimensions: # Common defaults defaults = { - 'text-embedding-3-small': 1536, - 'text-embedding-3-large': 3072, - 'text-embedding-ada-002': 1536, - 'nomic-embed-text-v1.5': 768, + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, + "nomic-embed-text-v1.5": 768, } if 
self.model_name in defaults: self.embedding_dimensions = defaults[self.model_name] - + return self @@ -180,52 +199,49 @@ class AppModels(BaseModel): """ model_config = ConfigDict( - extra='allow', + extra="allow", validate_assignment=True, ) defaults: Dict[str, str] = Field( - default_factory=dict, - description="Default model names for each model_type" + default_factory=dict, description="Default model names for each model_type" ) models: Dict[str, ModelDef] = Field( default_factory=dict, - description="All available model definitions keyed by name" + description="All available model definitions keyed by name", ) memory: Dict[str, Any] = Field( - default_factory=dict, - description="Memory service configuration" + default_factory=dict, description="Memory service configuration" ) speaker_recognition: Dict[str, Any] = Field( - default_factory=dict, - description="Speaker recognition service configuration" + default_factory=dict, description="Speaker recognition service configuration" ) chat: Dict[str, Any] = Field( default_factory=dict, - description="Chat service configuration including system prompt" + description="Chat service configuration including system prompt", ) llm_operations: Dict[str, LLMOperationConfig] = Field( default_factory=dict, - description="Per-operation LLM configuration (temperature, model override, etc.)" + description="Per-operation LLM configuration (temperature, model override, etc.)", ) - + def get_by_name(self, name: str) -> Optional[ModelDef]: """Get a model by its unique name. - + Args: name: Model name to look up - + Returns: ModelDef if found, None otherwise """ return self.models.get(name) - + def get_default(self, model_type: str) -> Optional[ModelDef]: """Get the default model for a given type. - + Args: model_type: Type of model (llm, embedding, stt, tts, etc.) - + Returns: Default ModelDef for the type, or first available model of that type, or None if no models of that type exist @@ -236,25 +252,25 @@ def get_default(self, model_type: str) -> Optional[ModelDef]: model = self.get_by_name(name) if model: return model - + # Fallback: first model of that type for m in self.models.values(): if m.model_type == model_type: return m - + return None - + def get_all_by_type(self, model_type: str) -> List[ModelDef]: """Get all models of a specific type. - + Args: model_type: Type of model to filter by - + Returns: List of ModelDef objects matching the type """ return [m for m in self.models.values() if m.model_type == model_type] - + def list_model_types(self) -> List[str]: """Get all unique model types in the registry. @@ -340,6 +356,7 @@ def _find_config_path() -> Path: Path to config.yml """ from advanced_omi_backend.config import get_config_yml_path + return get_config_yml_path() @@ -413,13 +430,13 @@ def load_models_config(force_reload: bool = False) -> Optional[AppModels]: def get_models_registry() -> Optional[AppModels]: """Get the global models registry. - + This is the primary interface for accessing model configurations. The registry is loaded once and cached for performance. 
- + Returns: AppModels instance, or None if config.yml not found - + Example: >>> registry = get_models_registry() >>> if registry: diff --git a/backends/advanced/src/advanced_omi_backend/models/__init__.py b/backends/advanced/src/advanced_omi_backend/models/__init__.py index a19fa0db..4dbedc18 100644 --- a/backends/advanced/src/advanced_omi_backend/models/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/models/__init__.py @@ -7,4 +7,4 @@ # Models can be imported directly from their files # e.g. from .job import TranscriptionJob -# e.g. from .conversation import Conversation, create_conversation \ No newline at end of file +# e.g. from .conversation import Conversation, create_conversation diff --git a/backends/advanced/src/advanced_omi_backend/models/audio_chunk.py b/backends/advanced/src/advanced_omi_backend/models/audio_chunk.py index 5f3b4c1d..d8ed0125 100644 --- a/backends/advanced/src/advanced_omi_backend/models/audio_chunk.py +++ b/backends/advanced/src/advanced_omi_backend/models/audio_chunk.py @@ -41,10 +41,7 @@ class AudioChunkDocument(Document): conversation_id: Indexed(str) = Field( description="Parent conversation ID (UUID format)" ) - chunk_index: int = Field( - description="Sequential chunk number (0-based)", - ge=0 - ) + chunk_index: int = Field(description="Sequential chunk number (0-based)", ge=0) # Audio data audio_data: bytes = Field( @@ -53,61 +50,48 @@ class AudioChunkDocument(Document): # Size tracking original_size: int = Field( - description="Original PCM size in bytes (before compression)", - gt=0 + description="Original PCM size in bytes (before compression)", gt=0 ) compressed_size: int = Field( - description="Opus-encoded size in bytes (after compression)", - gt=0 + description="Opus-encoded size in bytes (after compression)", gt=0 ) # Time boundaries start_time: float = Field( - description="Start time in seconds from conversation start", - ge=0.0 + description="Start time in seconds from conversation start", ge=0.0 ) end_time: float = Field( - description="End time in seconds from conversation start", - gt=0.0 + description="End time in seconds from conversation start", gt=0.0 ) duration: float = Field( - description="Chunk duration in seconds (typically 10.0)", - gt=0.0 + description="Chunk duration in seconds (typically 10.0)", gt=0.0 ) # Audio format - sample_rate: int = Field( - default=16000, - description="Original PCM sample rate (Hz)" - ) + sample_rate: int = Field(default=16000, description="Original PCM sample rate (Hz)") channels: int = Field( - default=1, - description="Number of audio channels (1=mono, 2=stereo)" + default=1, description="Number of audio channels (1=mono, 2=stereo)" ) # Optional analysis has_speech: Optional[bool] = Field( - default=None, - description="Voice Activity Detection result (if available)" + default=None, description="Voice Activity Detection result (if available)" ) # Metadata created_at: datetime = Field( - default_factory=datetime.utcnow, - description="Chunk creation timestamp" + default_factory=datetime.utcnow, description="Chunk creation timestamp" ) # Soft delete fields deleted: bool = Field( - default=False, - description="Whether this chunk was soft-deleted" + default=False, description="Whether this chunk was soft-deleted" ) deleted_at: Optional[datetime] = Field( - default=None, - description="When the chunk was marked as deleted" + default=None, description="When the chunk was marked as deleted" ) - @field_serializer('audio_data') + @field_serializer("audio_data") def serialize_audio_data(self, 
v: bytes) -> Binary: """ Convert bytes to BSON Binary for MongoDB storage. @@ -121,20 +105,18 @@ def serialize_audio_data(self, v: bytes) -> Binary: class Settings: """Beanie document settings.""" + name = "audio_chunks" indexes = [ # Primary query: Retrieve chunks in order for a conversation [("conversation_id", 1), ("chunk_index", 1)], - # Conversation lookup and counting "conversation_id", - # Maintenance queries (cleanup, monitoring) "created_at", - # Soft delete filtering - "deleted" + "deleted", ] @property diff --git a/backends/advanced/src/advanced_omi_backend/models/conversation.py b/backends/advanced/src/advanced_omi_backend/models/conversation.py index 79b6d798..62d0286e 100644 --- a/backends/advanced/src/advanced_omi_backend/models/conversation.py +++ b/backends/advanced/src/advanced_omi_backend/models/conversation.py @@ -22,179 +22,238 @@ class Conversation(Document): class MemoryProvider(str, Enum): """Supported memory providers.""" + CHRONICLE = "chronicle" OPENMEMORY_MCP = "openmemory_mcp" FRIEND_LITE = "friend_lite" # Legacy value class ConversationStatus(str, Enum): """Conversation processing status.""" + ACTIVE = "active" # Has running jobs or open websocket COMPLETED = "completed" # All jobs succeeded FAILED = "failed" # One or more jobs failed class EndReason(str, Enum): """Reason for conversation ending.""" + USER_STOPPED = "user_stopped" # User manually stopped recording - INACTIVITY_TIMEOUT = "inactivity_timeout" # No speech detected for threshold period - WEBSOCKET_DISCONNECT = "websocket_disconnect" # Connection lost (Bluetooth, network, etc.) + INACTIVITY_TIMEOUT = ( + "inactivity_timeout" # No speech detected for threshold period + ) + WEBSOCKET_DISCONNECT = ( + "websocket_disconnect" # Connection lost (Bluetooth, network, etc.) + ) MAX_DURATION = "max_duration" # Hit maximum conversation duration - CLOSE_REQUESTED = "close_requested" # External close signal (API, plugin, button) + CLOSE_REQUESTED = ( + "close_requested" # External close signal (API, plugin, button) + ) ERROR = "error" # Processing error forced conversation end UNKNOWN = "unknown" # Unknown or legacy reason # Nested Models class Word(BaseModel): """Individual word with timestamp in a transcript.""" + word: str = Field(description="Word text") start: float = Field(description="Start time in seconds") end: float = Field(description="End time in seconds") confidence: Optional[float] = Field(None, description="Confidence score (0-1)") speaker: Optional[int] = Field(None, description="Speaker ID from diarization") - speaker_confidence: Optional[float] = Field(None, description="Speaker diarization confidence") + speaker_confidence: Optional[float] = Field( + None, description="Speaker diarization confidence" + ) class SegmentType(str, Enum): """Type of transcript segment.""" + SPEECH = "speech" - EVENT = "event" # Non-speech: [laughter], [music], etc. - NOTE = "note" # User-inserted annotation/tag + EVENT = "event" # Non-speech: [laughter], [music], etc. 
+ NOTE = "note" # User-inserted annotation/tag class SpeakerSegment(BaseModel): """Individual speaker segment in a transcript.""" + start: float = Field(description="Start time in seconds") end: float = Field(description="End time in seconds") text: str = Field(description="Transcript text for this segment") speaker: str = Field(description="Speaker identifier") segment_type: str = Field( default="speech", - description="Type: speech, event (non-speech from ASR), or note (user-inserted)" + description="Type: speech, event (non-speech from ASR), or note (user-inserted)", + ) + identified_as: Optional[str] = Field( + None, + description="Speaker name from speaker recognition (None if not identified)", ) - identified_as: Optional[str] = Field(None, description="Speaker name from speaker recognition (None if not identified)") confidence: Optional[float] = Field(None, description="Confidence score (0-1)") - words: List["Conversation.Word"] = Field(default_factory=list, description="Word-level timestamps for this segment") + words: List["Conversation.Word"] = Field( + default_factory=list, description="Word-level timestamps for this segment" + ) class TranscriptVersion(BaseModel): """Version of a transcript with processing metadata.""" + version_id: str = Field(description="Unique version identifier") transcript: Optional[str] = Field(None, description="Full transcript text") words: List["Conversation.Word"] = Field( default_factory=list, - description="Word-level timestamps for entire transcript" + description="Word-level timestamps for entire transcript", ) segments: List["Conversation.SpeakerSegment"] = Field( default_factory=list, - description="Speaker segments (filled by speaker recognition)" + description="Speaker segments (filled by speaker recognition)", + ) + provider: Optional[str] = Field( + None, + description="Transcription provider used (deepgram, parakeet, vibevoice, etc.)", + ) + model: Optional[str] = Field( + None, description="Model used (e.g., nova-3, parakeet)" ) - provider: Optional[str] = Field(None, description="Transcription provider used (deepgram, parakeet, vibevoice, etc.)") - model: Optional[str] = Field(None, description="Model used (e.g., nova-3, parakeet)") created_at: datetime = Field(description="When this version was created") - processing_time_seconds: Optional[float] = Field(None, description="Time taken to process") + processing_time_seconds: Optional[float] = Field( + None, description="Time taken to process" + ) diarization_source: Optional[str] = Field( None, - description="Source of speaker diarization: 'provider' (transcription service), 'pyannote' (speaker recognition), or None" + description="Source of speaker diarization: 'provider' (transcription service), 'pyannote' (speaker recognition), or None", + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional provider-specific metadata" ) - metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional provider-specific metadata") class MemoryVersion(BaseModel): """Version of memory extraction with processing metadata.""" + version_id: str = Field(description="Unique version identifier") memory_count: int = Field(description="Number of memories extracted") - transcript_version_id: str = Field(description="Which transcript version was used") - provider: "Conversation.MemoryProvider" = Field(description="Memory provider used") - model: Optional[str] = Field(None, description="Model used (e.g., gpt-4o-mini, llama3)") + transcript_version_id: str = Field( + 
description="Which transcript version was used" + ) + provider: "Conversation.MemoryProvider" = Field( + description="Memory provider used" + ) + model: Optional[str] = Field( + None, description="Model used (e.g., gpt-4o-mini, llama3)" + ) created_at: datetime = Field(description="When this version was created") - processing_time_seconds: Optional[float] = Field(None, description="Time taken to process") - metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional provider-specific metadata") + processing_time_seconds: Optional[float] = Field( + None, description="Time taken to process" + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, description="Additional provider-specific metadata" + ) # Core identifiers - conversation_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique conversation identifier") + conversation_id: Indexed(str, unique=True) = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique conversation identifier", + ) user_id: Indexed(str) = Field(description="User who owns this conversation") client_id: Indexed(str) = Field(description="Client device identifier") # External file tracking (for deduplication of imported files) external_source_id: Optional[str] = Field( None, - description="External file identifier (e.g., Google Drive file_id) for deduplication" + description="External file identifier (e.g., Google Drive file_id) for deduplication", ) external_source_type: Optional[str] = Field( - None, - description="Type of external source (gdrive, dropbox, s3, etc.)" + None, description="Type of external source (gdrive, dropbox, s3, etc.)" ) # MongoDB chunk-based audio storage (new system) audio_chunks_count: Optional[int] = Field( - None, - description="Total number of 10-second audio chunks stored in MongoDB" + None, description="Total number of 10-second audio chunks stored in MongoDB" ) audio_total_duration: Optional[float] = Field( - None, - description="Total audio duration in seconds (sum of all chunks)" + None, description="Total audio duration in seconds (sum of all chunks)" ) audio_compression_ratio: Optional[float] = Field( None, - description="Compression ratio (compressed_size / original_size), typically ~0.047 for Opus" + description="Compression ratio (compressed_size / original_size), typically ~0.047 for Opus", ) # Markers (e.g., button events) captured during the session markers: List[Dict[str, Any]] = Field( default_factory=list, - description="Markers captured during audio session (button events, bookmarks, etc.)" + description="Markers captured during audio session (button events, bookmarks, etc.)", ) # Creation metadata - created_at: Indexed(datetime) = Field(default_factory=datetime.utcnow, description="When the conversation was created") + created_at: Indexed(datetime) = Field( + default_factory=datetime.utcnow, description="When the conversation was created" + ) # Processing status tracking - deleted: bool = Field(False, description="Whether this conversation was deleted due to processing failure") - deletion_reason: Optional[str] = Field(None, description="Reason for deletion (no_meaningful_speech, audio_file_not_ready, etc.)") - deleted_at: Optional[datetime] = Field(None, description="When the conversation was marked as deleted") + deleted: bool = Field( + False, + description="Whether this conversation was deleted due to processing failure", + ) + deletion_reason: Optional[str] = Field( + None, + description="Reason for deletion 
(no_meaningful_speech, audio_file_not_ready, etc.)", + ) + deleted_at: Optional[datetime] = Field( + None, description="When the conversation was marked as deleted" + ) # Always persist audio flag and processing status processing_status: Optional[str] = Field( None, - description="Processing status: pending_transcription, transcription_failed, completed" + description="Processing status: pending_transcription, transcription_failed, completed", ) always_persist: bool = Field( default=False, - description="Flag indicating conversation was created for audio persistence" + description="Flag indicating conversation was created for audio persistence", ) # Conversation completion tracking - end_reason: Optional["Conversation.EndReason"] = Field(None, description="Reason why the conversation ended") - completed_at: Optional[datetime] = Field(None, description="When the conversation was completed/closed") + end_reason: Optional["Conversation.EndReason"] = Field( + None, description="Reason why the conversation ended" + ) + completed_at: Optional[datetime] = Field( + None, description="When the conversation was completed/closed" + ) # Star/favorite - starred: bool = Field(False, description="Whether this conversation is starred/favorited") - starred_at: Optional[datetime] = Field(None, description="When the conversation was starred") + starred: bool = Field( + False, description="Whether this conversation is starred/favorited" + ) + starred_at: Optional[datetime] = Field( + None, description="When the conversation was starred" + ) # Summary fields (auto-generated from transcript) title: Optional[str] = Field(None, description="Auto-generated conversation title") - summary: Optional[str] = Field(None, description="Auto-generated short summary (1-2 sentences)") - detailed_summary: Optional[str] = Field(None, description="Auto-generated detailed summary (comprehensive, corrected content)") + summary: Optional[str] = Field( + None, description="Auto-generated short summary (1-2 sentences)" + ) + detailed_summary: Optional[str] = Field( + None, + description="Auto-generated detailed summary (comprehensive, corrected content)", + ) # Versioned processing transcript_versions: List["Conversation.TranscriptVersion"] = Field( - default_factory=list, - description="All transcript processing attempts" + default_factory=list, description="All transcript processing attempts" ) memory_versions: List["Conversation.MemoryVersion"] = Field( - default_factory=list, - description="All memory extraction attempts" + default_factory=list, description="All memory extraction attempts" ) # Active version pointers active_transcript_version: Optional[str] = Field( - None, - description="Version ID of currently active transcript" + None, description="Version ID of currently active transcript" ) active_memory_version: Optional[str] = Field( - None, - description="Version ID of currently active memory extraction" + None, description="Version ID of currently active memory extraction" ) # Legacy fields removed - use transcript_versions[active_transcript_version] and memory_versions[active_memory_version] # Frontend should access: conversation.active_transcript.segments, conversation.active_transcript.transcript - @model_validator(mode='before') + @model_validator(mode="before") @classmethod def clean_legacy_data(cls, data: Any) -> Any: """Clean up legacy/malformed data before Pydantic validation.""" @@ -203,26 +262,32 @@ def clean_legacy_data(cls, data: Any) -> Any: return data # Fix malformed transcript_versions (from old schema 
versions) - if 'transcript_versions' in data and isinstance(data['transcript_versions'], list): - for version in data['transcript_versions']: + if "transcript_versions" in data and isinstance( + data["transcript_versions"], list + ): + for version in data["transcript_versions"]: if isinstance(version, dict): # If segments is not a list, clear it - if 'segments' in version and not isinstance(version['segments'], list): - version['segments'] = [] + if "segments" in version and not isinstance( + version["segments"], list + ): + version["segments"] = [] # If transcript is a dict, clear it - if 'transcript' in version and isinstance(version['transcript'], dict): - version['transcript'] = None + if "transcript" in version and isinstance( + version["transcript"], dict + ): + version["transcript"] = None # Normalize provider to lowercase (legacy data had "Deepgram" instead of "deepgram") - if 'provider' in version and isinstance(version['provider'], str): - version['provider'] = version['provider'].lower() + if "provider" in version and isinstance(version["provider"], str): + version["provider"] = version["provider"].lower() # Fix speaker IDs in segments (legacy data had integers, need strings) - if 'segments' in version and isinstance(version['segments'], list): - for segment in version['segments']: - if isinstance(segment, dict) and 'speaker' in segment: - if isinstance(segment['speaker'], int): - segment['speaker'] = f"Speaker {segment['speaker']}" - elif not isinstance(segment['speaker'], str): - segment['speaker'] = "unknown" + if "segments" in version and isinstance(version["segments"], list): + for segment in version["segments"]: + if isinstance(segment, dict) and "speaker" in segment: + if isinstance(segment["speaker"], int): + segment["speaker"] = f"Speaker {segment['speaker']}" + elif not isinstance(segment["speaker"], str): + segment["speaker"] = "unknown" return data @@ -325,7 +390,7 @@ def add_transcript_version( model: Optional[str] = None, processing_time_seconds: Optional[float] = None, metadata: Optional[Dict[str, Any]] = None, - set_as_active: bool = True + set_as_active: bool = True, ) -> "Conversation.TranscriptVersion": """Add a new transcript version and optionally set it as active.""" new_version = Conversation.TranscriptVersion( @@ -337,7 +402,7 @@ def add_transcript_version( model=model, created_at=datetime.now(), processing_time_seconds=processing_time_seconds, - metadata=metadata or {} + metadata=metadata or {}, ) self.transcript_versions.append(new_version) @@ -356,7 +421,7 @@ def add_memory_version( model: Optional[str] = None, processing_time_seconds: Optional[float] = None, metadata: Optional[Dict[str, Any]] = None, - set_as_active: bool = True + set_as_active: bool = True, ) -> "Conversation.MemoryVersion": """Add a new memory version and optionally set it as active.""" new_version = Conversation.MemoryVersion( @@ -367,7 +432,7 @@ def add_memory_version( model=model, created_at=datetime.now(), processing_time_seconds=processing_time_seconds, - metadata=metadata or {} + metadata=metadata or {}, ) self.memory_versions.append(new_version) @@ -399,12 +464,27 @@ class Settings: "conversation_id", "user_id", "created_at", - [("user_id", 1), ("deleted", 1), ("created_at", -1)], # Compound index for paginated list queries - IndexModel([("external_source_id", 1)], sparse=True), # Sparse index for deduplication + [ + ("user_id", 1), + ("deleted", 1), + ("created_at", -1), + ], # Compound index for paginated list queries + IndexModel( + [("external_source_id", 1)], sparse=True 
+ ), # Sparse index for deduplication IndexModel( - [("title", "text"), ("summary", "text"), ("detailed_summary", "text"), - ("transcript_versions.transcript", "text")], - weights={"title": 10, "summary": 5, "detailed_summary": 3, "transcript_versions.transcript": 1}, + [ + ("title", "text"), + ("summary", "text"), + ("detailed_summary", "text"), + ("transcript_versions.transcript", "text"), + ], + weights={ + "title": 10, + "summary": 5, + "detailed_summary": 3, + "transcript_versions.transcript": 1, + }, name="conversation_text_search", ), ] @@ -458,4 +538,4 @@ def create_conversation( if conversation_id is not None: conv_data["conversation_id"] = conversation_id - return Conversation(**conv_data) \ No newline at end of file + return Conversation(**conv_data) diff --git a/backends/advanced/src/advanced_omi_backend/models/job.py b/backends/advanced/src/advanced_omi_backend/models/job.py index f7f44d4c..91870643 100644 --- a/backends/advanced/src/advanced_omi_backend/models/job.py +++ b/backends/advanced/src/advanced_omi_backend/models/job.py @@ -27,11 +27,12 @@ _beanie_initialized = False _beanie_init_lock = asyncio.Lock() + async def _ensure_beanie_initialized(): """Ensure Beanie is initialized in the current process (for RQ workers).""" global _beanie_initialized async with _beanie_init_lock: - if _beanie_initialized: + if _beanie_initialized: return try: import os @@ -85,10 +86,11 @@ class JobPriority(str, Enum): - NORMAL: 5 minutes timeout (default) - LOW: 3 minutes timeout """ - URGENT = "urgent" # 1 - Process immediately - HIGH = "high" # 2 - Process before normal - NORMAL = "normal" # 3 - Default priority - LOW = "low" # 4 - Process when idle + + URGENT = "urgent" # 1 - Process immediately + HIGH = "high" # 2 - Process before normal + NORMAL = "normal" # 3 - Default priority + LOW = "low" # 4 - Process when idle class BaseRQJob(ABC): @@ -188,6 +190,7 @@ def run(self, **kwargs) -> Dict[str, Any]: asyncio.set_event_loop(loop) try: + async def process(): await self._setup() try: @@ -207,11 +210,15 @@ async def process(): except Exception as e: elapsed = time.time() - self.job_start_time - logger.error(f"❌ {job_name} failed after {elapsed:.2f}s: {e}", exc_info=True) + logger.error( + f"❌ {job_name} failed after {elapsed:.2f}s: {e}", exc_info=True + ) raise -def async_job(redis: bool = True, beanie: bool = True, timeout: int = 300, result_ttl: int = 3600): +def async_job( + redis: bool = True, beanie: bool = True, timeout: int = 300, result_ttl: int = 3600 +): """ Decorator to convert async functions into RQ-compatible job functions. 
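The async_job decorator touched in these hunks turns an async coroutine into a plain callable that an RQ worker can invoke synchronously: it spins up a fresh event loop, optionally injects a Redis client and Beanie setup, then awaits the wrapped function, and it attaches timeout/result_ttl attributes for RQ to pick up. A minimal sketch of that wrap-async-in-sync pattern follows; the names async_job_sketch and example_job are illustrative stand-ins, not part of this patch, and the real decorator also handles the Redis/Beanie wiring and logging shown in the hunks.

import asyncio
from functools import wraps


def async_job_sketch(timeout: int = 300, result_ttl: int = 3600):
    """Simplified pattern: make an async function callable from a sync RQ worker."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # RQ calls jobs synchronously, so own a private event loop per invocation
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return loop.run_until_complete(func(*args, **kwargs))
            finally:
                loop.close()

        # Attributes RQ can read when the job is enqueued without explicit overrides
        wrapper.timeout = timeout
        wrapper.result_ttl = result_ttl
        return wrapper

    return decorator


@async_job_sketch(timeout=600)
async def example_job(conversation_id: str) -> dict:
    await asyncio.sleep(0)  # stand-in for real async work (DB, Redis, HTTP)
    return {"conversation_id": conversation_id, "status": "done"}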
@@ -239,6 +246,7 @@ async def my_job(arg1, arg2, redis_client=None): queue.enqueue(my_job, arg1_value, arg2_value) # Uses timeout=600 queue.enqueue(my_job, arg1_value, arg2_value, job_timeout=1200) # Override """ + def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs) -> Dict[str, Any]: @@ -254,6 +262,7 @@ def wrapper(*args, **kwargs) -> Dict[str, Any]: asyncio.set_event_loop(loop) try: + async def process(): nonlocal redis_client @@ -267,8 +276,9 @@ async def process(): from advanced_omi_backend.controllers.queue_controller import ( REDIS_URL, ) + redis_client = redis_async.from_url(REDIS_URL) - kwargs['redis_client'] = redis_client + kwargs["redis_client"] = redis_client logger.debug(f"Redis client created") try: @@ -292,7 +302,9 @@ async def process(): except Exception as e: elapsed = time.time() - start_time - logger.error(f"❌ {job_name} failed after {elapsed:.2f}s: {e}", exc_info=True) + logger.error( + f"❌ {job_name} failed after {elapsed:.2f}s: {e}", exc_info=True + ) raise # Store default job configuration as attributes for RQ introspection @@ -300,4 +312,5 @@ async def process(): wrapper.result_ttl = result_ttl return wrapper - return decorator \ No newline at end of file + + return decorator diff --git a/backends/advanced/src/advanced_omi_backend/models/user.py b/backends/advanced/src/advanced_omi_backend/models/user.py index 7291f9bb..77e0cc38 100644 --- a/backends/advanced/src/advanced_omi_backend/models/user.py +++ b/backends/advanced/src/advanced_omi_backend/models/user.py @@ -80,7 +80,9 @@ def user_id(self) -> str: """Return string representation of MongoDB ObjectId for backward compatibility.""" return str(self.id) - def register_client(self, client_id: str, device_name: Optional[str] = None) -> None: + def register_client( + self, client_id: str, device_name: Optional[str] = None + ) -> None: """Register a new client for this user.""" # Check if client already exists if client_id in self.registered_clients: diff --git a/backends/advanced/src/advanced_omi_backend/models/waveform.py b/backends/advanced/src/advanced_omi_backend/models/waveform.py index caf6fd49..ccf1158c 100644 --- a/backends/advanced/src/advanced_omi_backend/models/waveform.py +++ b/backends/advanced/src/advanced_omi_backend/models/waveform.py @@ -32,12 +32,10 @@ class WaveformData(Document): # Metadata duration_seconds: float = Field(description="Total audio duration in seconds") created_at: datetime = Field( - default_factory=datetime.utcnow, - description="When this waveform was generated" + default_factory=datetime.utcnow, description="When this waveform was generated" ) processing_time_seconds: Optional[float] = Field( - None, - description="Time taken to generate waveform" + None, description="Time taken to generate waveform" ) class Settings: diff --git a/backends/advanced/src/advanced_omi_backend/plugins/__init__.py b/backends/advanced/src/advanced_omi_backend/plugins/__init__.py index 90c47460..35b0a2b0 100644 --- a/backends/advanced/src/advanced_omi_backend/plugins/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/plugins/__init__.py @@ -20,13 +20,13 @@ from .services import PluginServices __all__ = [ - 'BasePlugin', - 'ButtonActionType', - 'ButtonState', - 'ConversationCloseReason', - 'PluginContext', - 'PluginEvent', - 'PluginResult', - 'PluginRouter', - 'PluginServices', + "BasePlugin", + "ButtonActionType", + "ButtonState", + "ConversationCloseReason", + "PluginContext", + "PluginEvent", + "PluginResult", + "PluginRouter", + "PluginServices", ] diff 
--git a/backends/advanced/src/advanced_omi_backend/routers/api_router.py b/backends/advanced/src/advanced_omi_backend/routers/api_router.py index e4c89531..63e7180d 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/api_router.py +++ b/backends/advanced/src/advanced_omi_backend/routers/api_router.py @@ -47,12 +47,15 @@ router.include_router(obsidian_router) router.include_router(system_router) router.include_router(queue_router) -router.include_router(health_router) # Also include under /api for frontend compatibility +router.include_router( + health_router +) # Also include under /api for frontend compatibility # Conditionally include test routes (only in test environments) if os.getenv("DEBUG_DIR"): try: from .modules.test_routes import router as test_router + router.include_router(test_router) logger.info("βœ… Test routes loaded (test environment detected)") except Exception as e: diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py b/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py index 501377fc..a65463d7 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py @@ -35,19 +35,19 @@ from .websocket_routes import router as websocket_router __all__ = [ - "admin_router", - "annotation_router", - "audio_router", - "chat_router", - "client_router", - "conversation_router", - "finetuning_router", - "health_router", - "knowledge_graph_router", - "memory_router", - "obsidian_router", - "queue_router", - "system_router", - "user_router", - "websocket_router", + "admin_router", + "annotation_router", + "audio_router", + "chat_router", + "client_router", + "conversation_router", + "finetuning_router", + "health_router", + "knowledge_graph_router", + "memory_router", + "obsidian_router", + "queue_router", + "system_router", + "user_router", + "websocket_router", ] diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/admin_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/admin_routes.py index 49594dd0..2b6f0dd9 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/admin_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/admin_routes.py @@ -21,32 +21,29 @@ def require_admin(current_user: User = Depends(current_active_user)) -> User: """Dependency to require admin/superuser permissions.""" if not current_user.is_superuser: - raise HTTPException( - status_code=403, - detail="Admin permissions required" - ) + raise HTTPException(status_code=403, detail="Admin permissions required") return current_user @router.get("/cleanup/settings") -async def get_cleanup_settings_admin( - admin: User = Depends(require_admin) -): +async def get_cleanup_settings_admin(admin: User = Depends(require_admin)): """Get current cleanup settings (admin only).""" from advanced_omi_backend.config import get_cleanup_settings settings = get_cleanup_settings() return { **settings, - "note": "Cleanup settings are stored in /app/data/cleanup_config.json" + "note": "Cleanup settings are stored in /app/data/cleanup_config.json", } @router.post("/cleanup") async def trigger_cleanup( dry_run: bool = Query(False, description="Preview what would be deleted"), - retention_days: Optional[int] = Query(None, description="Override retention period"), - admin: User = Depends(require_admin) + retention_days: Optional[int] = Query( + None, description="Override retention period" + ), + admin: User = 
Depends(require_admin), ): """Manually trigger cleanup of soft-deleted conversations (admin only).""" try: @@ -64,7 +61,9 @@ async def trigger_cleanup( job_timeout="30m", ) - logger.info(f"Admin {admin.email} triggered cleanup job {job.id} (dry_run={dry_run}, retention={retention_days or 'default'})") + logger.info( + f"Admin {admin.email} triggered cleanup job {job.id} (dry_run={dry_run}, retention={retention_days or 'default'})" + ) return JSONResponse( status_code=200, @@ -73,22 +72,23 @@ async def trigger_cleanup( "job_id": job.id, "retention_days": retention_days or "default (from config)", "dry_run": dry_run, - "note": "Check job status at /api/queue/jobs/{job_id}" - } + "note": "Check job status at /api/queue/jobs/{job_id}", + }, ) except Exception as e: logger.error(f"Failed to trigger cleanup: {e}") return JSONResponse( - status_code=500, - content={"error": f"Failed to trigger cleanup: {str(e)}"} + status_code=500, content={"error": f"Failed to trigger cleanup: {str(e)}"} ) @router.get("/cleanup/preview") async def preview_cleanup( - retention_days: Optional[int] = Query(None, description="Preview with specific retention period"), - admin: User = Depends(require_admin) + retention_days: Optional[int] = Query( + None, description="Preview with specific retention period" + ), + admin: User = Depends(require_admin), ): """Preview what would be deleted by cleanup (admin only).""" try: @@ -100,26 +100,24 @@ async def preview_cleanup( # Use provided retention or default from config if retention_days is None: settings_dict = get_cleanup_settings() - retention_days = settings_dict['retention_days'] + retention_days = settings_dict["retention_days"] cutoff_date = datetime.utcnow() - timedelta(days=retention_days) # Count conversations that would be deleted count = await Conversation.find( - Conversation.deleted == True, - Conversation.deleted_at < cutoff_date + Conversation.deleted == True, Conversation.deleted_at < cutoff_date ).count() return { "retention_days": retention_days, "cutoff_date": cutoff_date.isoformat(), "conversations_to_delete": count, - "note": f"Conversations deleted before {cutoff_date.date()} would be purged" + "note": f"Conversations deleted before {cutoff_date.date()} would be purged", } except Exception as e: logger.error(f"Failed to preview cleanup: {e}") return JSONResponse( - status_code=500, - content={"error": f"Failed to preview cleanup: {str(e)}"} + status_code=500, content={"error": f"Failed to preview cleanup: {str(e)}"} ) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py index fd1c659f..81065a8e 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py @@ -49,7 +49,10 @@ def _safe_filename(conversation: "Conversation") -> str: @router.post("/upload_audio_from_gdrive") async def upload_audio_from_drive_folder( - gdrive_folder_id: str = Query(..., description="Google Drive Folder ID containing audio files (e.g., the string after /folders/ in the URL)"), + gdrive_folder_id: str = Query( + ..., + description="Google Drive Folder ID containing audio files (e.g., the string after /folders/ in the URL)", + ), current_user: User = Depends(current_superuser), device_name: str = Query(default="upload"), ): @@ -67,7 +70,9 @@ async def upload_audio_from_drive_folder( async def get_conversation_audio( conversation_id: str, request: Request, - 
token: Optional[str] = Query(default=None, description="JWT token for audio element access"), + token: Optional[str] = Query( + default=None, description="JWT token for audio element access" + ), current_user: Optional[User] = Depends(current_active_user_optional), ): """ @@ -111,7 +116,9 @@ async def get_conversation_audio( raise HTTPException(status_code=404, detail="Conversation not found") # Check ownership (admins can access all) - if not current_user.is_superuser and conversation.user_id != str(current_user.user_id): + if not current_user.is_superuser and conversation.user_id != str( + current_user.user_id + ): raise HTTPException(status_code=403, detail="Access denied") # Reconstruct WAV from MongoDB chunks @@ -123,8 +130,7 @@ async def get_conversation_audio( except Exception as e: # Reconstruction failed raise HTTPException( - status_code=500, - detail=f"Failed to reconstruct audio: {str(e)}" + status_code=500, detail=f"Failed to reconstruct audio: {str(e)}" ) # Handle Range requests for seeking support @@ -143,7 +149,7 @@ async def get_conversation_audio( "Accept-Ranges": "bytes", "X-Audio-Source": "mongodb-chunks", "X-Chunk-Count": str(conversation.audio_chunks_count or 0), - } + }, ) # Parse Range header (e.g., "bytes=0-1023") @@ -159,7 +165,7 @@ async def get_conversation_audio( content_length = range_end - range_start + 1 # Extract requested byte range - range_data = wav_data[range_start:range_end + 1] + range_data = wav_data[range_start : range_end + 1] # Return 206 Partial Content with Range headers return Response( @@ -172,22 +178,21 @@ async def get_conversation_audio( "Accept-Ranges": "bytes", "Content-Disposition": f'inline; filename="{filename}.wav"', "X-Audio-Source": "mongodb-chunks", - } + }, ) except (ValueError, IndexError) as e: # Invalid Range header, return 416 Range Not Satisfiable return Response( - status_code=416, - headers={ - "Content-Range": f"bytes */{file_size}" - } + status_code=416, headers={"Content-Range": f"bytes */{file_size}"} ) @router.get("/stream_audio/{conversation_id}") async def stream_conversation_audio( conversation_id: str, - token: Optional[str] = Query(default=None, description="JWT token for audio element access"), + token: Optional[str] = Query( + default=None, description="JWT token for audio element access" + ), current_user: Optional[User] = Depends(current_active_user_optional), ): """ @@ -230,12 +235,16 @@ async def stream_conversation_audio( raise HTTPException(status_code=404, detail="Conversation not found") # Check ownership (admins can access all) - if not current_user.is_superuser and conversation.user_id != str(current_user.user_id): + if not current_user.is_superuser and conversation.user_id != str( + current_user.user_id + ): raise HTTPException(status_code=403, detail="Access denied") # Check if chunks exist if not conversation.audio_chunks_count or conversation.audio_chunks_count == 0: - raise HTTPException(status_code=404, detail="No audio data for this conversation") + raise HTTPException( + status_code=404, detail="No audio data for this conversation" + ) async def stream_chunks(): """Generator that yields WAV data in batches.""" @@ -249,6 +258,7 @@ async def stream_chunks(): # We'll write a placeholder size since we're streaming wav_header = io.BytesIO() import wave + with wave.open(wav_header, "wb") as wav: wav.setnchannels(CHANNELS) wav.setsampwidth(SAMPLE_WIDTH) @@ -268,7 +278,7 @@ async def stream_chunks(): chunks = await retrieve_audio_chunks( conversation_id=conversation_id, start_index=start_index, - 
limit=batch_size + limit=batch_size, ) if not chunks: @@ -292,7 +302,7 @@ async def stream_chunks(): "X-Audio-Source": "mongodb-chunks-stream", "X-Chunk-Count": str(conversation.audio_chunks_count or 0), "X-Total-Duration": str(conversation.audio_total_duration or 0), - } + }, ) @@ -301,7 +311,9 @@ async def get_audio_chunk_range( conversation_id: str, start_time: float = Query(..., description="Start time in seconds"), end_time: float = Query(..., description="End time in seconds"), - token: Optional[str] = Query(default=None, description="JWT token for audio element access"), + token: Optional[str] = Query( + default=None, description="JWT token for audio element access" + ), current_user: Optional[User] = Depends(current_active_user_optional), ): """ @@ -331,8 +343,11 @@ async def get_audio_chunk_range( 400: If time range is invalid """ import logging + logger = logging.getLogger(__name__) - logger.info(f"🎡 Audio chunk request: conversation={conversation_id[:8]}..., start={start_time:.2f}s, end={end_time:.2f}s") + logger.info( + f"🎡 Audio chunk request: conversation={conversation_id[:8]}..., start={start_time:.2f}s, end={end_time:.2f}s" + ) # Try token param if header auth failed if not current_user and token: @@ -350,27 +365,38 @@ async def get_audio_chunk_range( raise HTTPException(status_code=404, detail="Conversation not found") # Check ownership (admins can access all) - if not current_user.is_superuser and conversation.user_id != str(current_user.user_id): + if not current_user.is_superuser and conversation.user_id != str( + current_user.user_id + ): raise HTTPException(status_code=403, detail="Access denied") # Validate time range if start_time < 0 or end_time <= start_time: raise HTTPException(status_code=400, detail="Invalid time range") - if conversation.audio_total_duration and end_time > conversation.audio_total_duration: + if ( + conversation.audio_total_duration + and end_time > conversation.audio_total_duration + ): end_time = conversation.audio_total_duration # Use the dedicated segment reconstruction function from advanced_omi_backend.utils.audio_chunk_utils import reconstruct_audio_segment try: - wav_data = await reconstruct_audio_segment(conversation_id, start_time, end_time) - logger.info(f"βœ… Returning WAV: {len(wav_data)} bytes for range {start_time:.2f}s - {end_time:.2f}s") + wav_data = await reconstruct_audio_segment( + conversation_id, start_time, end_time + ) + logger.info( + f"βœ… Returning WAV: {len(wav_data)} bytes for range {start_time:.2f}s - {end_time:.2f}s" + ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: logger.error(f"Failed to reconstruct audio segment: {e}") - raise HTTPException(status_code=500, detail=f"Failed to reconstruct audio: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to reconstruct audio: {str(e)}" + ) return StreamingResponse( io.BytesIO(wav_data), @@ -381,7 +407,7 @@ async def get_audio_chunk_range( "X-Audio-Duration": str(end_time - start_time), "X-Start-Time": str(start_time), "X-End-Time": str(end_time), - } + }, ) @@ -389,7 +415,9 @@ async def get_audio_chunk_range( async def upload_audio_files( current_user: User = Depends(current_superuser), files: list[UploadFile] = File(...), - device_name: str = Query(default="upload", description="Device name for uploaded files"), + device_name: str = Query( + default="upload", description="Device name for uploaded files" + ), ): """ Upload and process audio files. Admin only. 
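The audio routes reformatted above serve reconstructed WAV bytes and honour HTTP Range requests so the web player can seek: a valid "bytes=start-end" header yields a 206 Partial Content slice with a Content-Range header, while an unparseable range falls back to 416. A minimal sketch of that byte-range handling, standalone and outside FastAPI; parse_range_header and the sample sizes are illustrative assumptions, not helpers defined in this patch.

def parse_range_header(range_header: str, file_size: int) -> tuple[int, int]:
    """Parse a header such as 'bytes=0-1023' into inclusive start/end offsets."""
    unit, _, spec = range_header.partition("=")
    if unit.strip().lower() != "bytes":
        raise ValueError("unsupported range unit")
    start_text, _, end_text = spec.partition("-")
    start = int(start_text) if start_text else 0
    end = int(end_text) if end_text else file_size - 1
    if start < 0 or start > end or end >= file_size:
        raise ValueError("range not satisfiable")  # caller maps this to HTTP 416
    return start, end


wav_data = bytes(1024 * 1024)                      # stand-in for reconstructed WAV bytes
start, end = parse_range_header("bytes=0-1023", len(wav_data))
body = wav_data[start : end + 1]                   # slice returned with 206 Partial Content
content_range = f"bytes {start}-{end}/{len(wav_data)}"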
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/chat_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/chat_routes.py index fdb73e5d..0d296340 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/chat_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/chat_routes.py @@ -31,18 +31,31 @@ # --- OpenAI-compatible chat completion models --- + class ChatCompletionMessage(BaseModel): - role: str = Field(..., description="The role of the message author (system, user, assistant)") + role: str = Field( + ..., description="The role of the message author (system, user, assistant)" + ) content: str = Field(..., description="The message content") class ChatCompletionRequest(BaseModel): - messages: List[ChatCompletionMessage] = Field(..., min_length=1, description="List of messages in the conversation") - model: Optional[str] = Field(None, description="Model to use (ignored, uses server-configured model)") + messages: List[ChatCompletionMessage] = Field( + ..., min_length=1, description="List of messages in the conversation" + ) + model: Optional[str] = Field( + None, description="Model to use (ignored, uses server-configured model)" + ) stream: Optional[bool] = Field(True, description="Whether to stream the response") - temperature: Optional[float] = Field(None, description="Sampling temperature (ignored, uses server config)") - session_id: Optional[str] = Field(None, description="Chronicle session ID (creates new if not provided)") - include_obsidian_memory: Optional[bool] = Field(False, description="Whether to include Obsidian vault context") + temperature: Optional[float] = Field( + None, description="Sampling temperature (ignored, uses server config)" + ) + session_id: Optional[str] = Field( + None, description="Chronicle session ID (creates new if not provided)" + ) + include_obsidian_memory: Optional[bool] = Field( + False, description="Whether to include Obsidian vault context" + ) class ChatCompletionChunkDelta(BaseModel): @@ -104,7 +117,9 @@ class ChatSessionCreateRequest(BaseModel): class ChatSessionUpdateRequest(BaseModel): - title: str = Field(..., min_length=1, max_length=200, description="New session title") + title: str = Field( + ..., min_length=1, max_length=200, description="New session title" + ) class ChatStatisticsResponse(BaseModel): @@ -115,99 +130,97 @@ class ChatStatisticsResponse(BaseModel): @router.post("/sessions", response_model=ChatSessionResponse) async def create_chat_session( - request: ChatSessionCreateRequest, - current_user: User = Depends(current_active_user) + request: ChatSessionCreateRequest, current_user: User = Depends(current_active_user) ): """Create a new chat session.""" try: chat_service = get_chat_service() session = await chat_service.create_session( - user_id=str(current_user.id), - title=request.title + user_id=str(current_user.id), title=request.title ) - + return ChatSessionResponse( session_id=session.session_id, title=session.title, created_at=session.created_at.isoformat(), - updated_at=session.updated_at.isoformat() + updated_at=session.updated_at.isoformat(), ) except Exception as e: logger.error(f"Failed to create chat session for user {current_user.id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to create chat session" + detail="Failed to create chat session", ) @router.get("/sessions", response_model=List[ChatSessionResponse]) async def get_chat_sessions( - limit: int = 50, - current_user: User = 
Depends(current_active_user) + limit: int = 50, current_user: User = Depends(current_active_user) ): """Get all chat sessions for the current user.""" try: chat_service = get_chat_service() sessions = await chat_service.get_user_sessions( - user_id=str(current_user.id), - limit=min(limit, 100) # Cap at 100 + user_id=str(current_user.id), limit=min(limit, 100) # Cap at 100 ) - + # Get message counts for each session (this could be optimized with aggregation) session_responses = [] for session in sessions: messages = await chat_service.get_session_messages( session_id=session.session_id, user_id=str(current_user.id), - limit=1 # We just need count, but MongoDB doesn't have efficient count + limit=1, # We just need count, but MongoDB doesn't have efficient count ) - - session_responses.append(ChatSessionResponse( - session_id=session.session_id, - title=session.title, - created_at=session.created_at.isoformat(), - updated_at=session.updated_at.isoformat(), - message_count=len(messages) # This is approximate for efficiency - )) - + + session_responses.append( + ChatSessionResponse( + session_id=session.session_id, + title=session.title, + created_at=session.created_at.isoformat(), + updated_at=session.updated_at.isoformat(), + message_count=len(messages), # This is approximate for efficiency + ) + ) + return session_responses except Exception as e: logger.error(f"Failed to get chat sessions for user {current_user.id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve chat sessions" + detail="Failed to retrieve chat sessions", ) @router.get("/sessions/{session_id}", response_model=ChatSessionResponse) async def get_chat_session( - session_id: str, - current_user: User = Depends(current_active_user) + session_id: str, current_user: User = Depends(current_active_user) ): """Get a specific chat session.""" try: chat_service = get_chat_service() session = await chat_service.get_session(session_id, str(current_user.id)) - + if not session: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Chat session not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found" ) - + return ChatSessionResponse( session_id=session.session_id, title=session.title, created_at=session.created_at.isoformat(), - updated_at=session.updated_at.isoformat() + updated_at=session.updated_at.isoformat(), ) except HTTPException: raise except Exception as e: - logger.error(f"Failed to get chat session {session_id} for user {current_user.id}: {e}") + logger.error( + f"Failed to get chat session {session_id} for user {current_user.id}: {e}" + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve chat session" + detail="Failed to retrieve chat session", ) @@ -215,100 +228,100 @@ async def get_chat_session( async def update_chat_session( session_id: str, request: ChatSessionUpdateRequest, - current_user: User = Depends(current_active_user) + current_user: User = Depends(current_active_user), ): """Update a chat session's title.""" try: chat_service = get_chat_service() - + # Verify session exists and belongs to user session = await chat_service.get_session(session_id, str(current_user.id)) if not session: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Chat session not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found" ) - + # Update the title success = await chat_service.update_session_title( session_id, str(current_user.id), 
request.title ) - + if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to update session title" + detail="Failed to update session title", ) - + # Return updated session - updated_session = await chat_service.get_session(session_id, str(current_user.id)) + updated_session = await chat_service.get_session( + session_id, str(current_user.id) + ) return ChatSessionResponse( session_id=updated_session.session_id, title=updated_session.title, created_at=updated_session.created_at.isoformat(), - updated_at=updated_session.updated_at.isoformat() + updated_at=updated_session.updated_at.isoformat(), ) except HTTPException: raise except Exception as e: - logger.error(f"Failed to update chat session {session_id} for user {current_user.id}: {e}") + logger.error( + f"Failed to update chat session {session_id} for user {current_user.id}: {e}" + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to update chat session" + detail="Failed to update chat session", ) @router.delete("/sessions/{session_id}") async def delete_chat_session( - session_id: str, - current_user: User = Depends(current_active_user) + session_id: str, current_user: User = Depends(current_active_user) ): """Delete a chat session and all its messages.""" try: chat_service = get_chat_service() success = await chat_service.delete_session(session_id, str(current_user.id)) - + if not success: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Chat session not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found" ) - + return {"message": "Chat session deleted successfully"} except HTTPException: raise except Exception as e: - logger.error(f"Failed to delete chat session {session_id} for user {current_user.id}: {e}") + logger.error( + f"Failed to delete chat session {session_id} for user {current_user.id}: {e}" + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to delete chat session" + detail="Failed to delete chat session", ) @router.get("/sessions/{session_id}/messages", response_model=List[ChatMessageResponse]) async def get_session_messages( - session_id: str, - limit: int = 100, - current_user: User = Depends(current_active_user) + session_id: str, limit: int = 100, current_user: User = Depends(current_active_user) ): """Get all messages in a chat session.""" try: chat_service = get_chat_service() - + # Verify session exists and belongs to user session = await chat_service.get_session(session_id, str(current_user.id)) if not session: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Chat session not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found" ) - + messages = await chat_service.get_session_messages( session_id=session_id, user_id=str(current_user.id), - limit=min(limit, 200) # Cap at 200 + limit=min(limit, 200), # Cap at 200 ) - + return [ ChatMessageResponse( message_id=msg.message_id, @@ -316,24 +329,25 @@ async def get_session_messages( role=msg.role, content=msg.content, timestamp=msg.timestamp.isoformat(), - memories_used=msg.memories_used + memories_used=msg.memories_used, ) for msg in messages ] except HTTPException: raise except Exception as e: - logger.error(f"Failed to get messages for session {session_id}, user {current_user.id}: {e}") + logger.error( + f"Failed to get messages for session {session_id}, user {current_user.id}: {e}" + ) raise HTTPException( 
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve messages" + detail="Failed to retrieve messages", ) @router.post("/completions") async def chat_completions( - request: ChatCompletionRequest, - current_user: User = Depends(current_active_user) + request: ChatCompletionRequest, current_user: User = Depends(current_active_user) ): """OpenAI-compatible chat completions endpoint with streaming support.""" try: @@ -349,7 +363,7 @@ async def chat_completions( if not session: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Chat session not found" + detail="Chat session not found", ) # Extract the latest user message @@ -357,7 +371,7 @@ async def chat_completions( if not user_messages: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="At least one user message is required" + detail="At least one user message is required", ) message_content = user_messages[-1].content @@ -368,9 +382,14 @@ async def chat_completions( if request.stream: return StreamingResponse( _stream_openai_format( - chat_service, session_id, str(current_user.id), - message_content, request.include_obsidian_memory, - completion_id, created, model_name, + chat_service, + session_id, + str(current_user.id), + message_content, + request.include_obsidian_memory, + completion_id, + created, + model_name, ), media_type="text/event-stream", headers={ @@ -381,9 +400,14 @@ async def chat_completions( ) else: return await _non_streaming_response( - chat_service, session_id, str(current_user.id), - message_content, request.include_obsidian_memory, - completion_id, created, model_name, + chat_service, + session_id, + str(current_user.id), + message_content, + request.include_obsidian_memory, + completion_id, + created, + model_name, ) except HTTPException: @@ -392,14 +416,19 @@ async def chat_completions( logger.error(f"Failed to process message for user {current_user.id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to process message" + detail="Failed to process message", ) async def _stream_openai_format( - chat_service, session_id: str, user_id: str, - message_content: str, include_obsidian_memory: bool, - completion_id: str, created: int, model_name: str, + chat_service, + session_id: str, + user_id: str, + message_content: str, + include_obsidian_memory: bool, + completion_id: str, + created: int, + model_name: str, ): """Map internal streaming events to OpenAI SSE chunk format.""" previous_text = "" @@ -415,10 +444,14 @@ async def _stream_openai_format( if event_type == "memory_context": # First chunk: send role + chronicle metadata chunk = ChatCompletionChunk( - id=completion_id, created=created, model=model_name, - choices=[ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta(role="assistant"), - )], + id=completion_id, + created=created, + model=model_name, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta(role="assistant"), + ) + ], chronicle_metadata={ "session_id": session_id, **event["data"], @@ -429,24 +462,32 @@ async def _stream_openai_format( elif event_type == "token": # Internal events carry accumulated text; compute delta accumulated = event["data"] - delta_text = accumulated[len(previous_text):] + delta_text = accumulated[len(previous_text) :] previous_text = accumulated if delta_text: chunk = ChatCompletionChunk( - id=completion_id, created=created, model=model_name, - choices=[ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta(content=delta_text), - )], + 
id=completion_id, + created=created, + model=model_name, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta(content=delta_text), + ) + ], ) yield f"data: {chunk.model_dump_json()}\n\n" elif event_type == "complete": chunk = ChatCompletionChunk( - id=completion_id, created=created, model=model_name, - choices=[ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta(), - finish_reason="stop", - )], + id=completion_id, + created=created, + model=model_name, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta(), + finish_reason="stop", + ) + ], chronicle_metadata={ "session_id": session_id, "message_id": event["data"].get("message_id"), @@ -473,9 +514,14 @@ async def _stream_openai_format( async def _non_streaming_response( - chat_service, session_id: str, user_id: str, - message_content: str, include_obsidian_memory: bool, - completion_id: str, created: int, model_name: str, + chat_service, + session_id: str, + user_id: str, + message_content: str, + include_obsidian_memory: bool, + completion_id: str, + created: int, + model_name: str, ) -> ChatCompletionResponse: """Collect all events and return a single ChatCompletionResponse.""" full_content = "" @@ -506,9 +552,13 @@ async def _non_streaming_response( id=completion_id, created=created, model=model_name, - choices=[ChatCompletionChoice( - message=ChatCompletionMessage(role="assistant", content=full_content.strip()), - )], + choices=[ + ChatCompletionChoice( + message=ChatCompletionMessage( + role="assistant", content=full_content.strip() + ), + ) + ], usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, session_id=session_id, chronicle_metadata=metadata, @@ -516,62 +566,62 @@ async def _non_streaming_response( @router.get("/statistics", response_model=ChatStatisticsResponse) -async def get_chat_statistics( - current_user: User = Depends(current_active_user) -): +async def get_chat_statistics(current_user: User = Depends(current_active_user)): """Get chat statistics for the current user.""" try: chat_service = get_chat_service() stats = await chat_service.get_chat_statistics(str(current_user.id)) - + return ChatStatisticsResponse( total_sessions=stats["total_sessions"], total_messages=stats["total_messages"], - last_chat=stats["last_chat"].isoformat() if stats["last_chat"] else None + last_chat=stats["last_chat"].isoformat() if stats["last_chat"] else None, ) except Exception as e: logger.error(f"Failed to get chat statistics for user {current_user.id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve chat statistics" + detail="Failed to retrieve chat statistics", ) @router.post("/sessions/{session_id}/extract-memories") async def extract_memories_from_session( - session_id: str, - current_user: User = Depends(current_active_user) + session_id: str, current_user: User = Depends(current_active_user) ): """Extract memories from a chat session.""" try: chat_service = get_chat_service() - + # Extract memories from the session - success, memory_ids, memory_count = await chat_service.extract_memories_from_session( - session_id=session_id, - user_id=str(current_user.id) + success, memory_ids, memory_count = ( + await chat_service.extract_memories_from_session( + session_id=session_id, user_id=str(current_user.id) + ) ) - + if success: return { "success": True, "memory_ids": memory_ids, "count": memory_count, - "message": f"Successfully extracted {memory_count} memories from chat session" + "message": f"Successfully extracted 
{memory_count} memories from chat session", } else: return { "success": False, "memory_ids": [], "count": 0, - "message": "Failed to extract memories from chat session" + "message": "Failed to extract memories from chat session", } - + except Exception as e: - logger.error(f"Failed to extract memories from session {session_id} for user {current_user.id}: {e}") + logger.error( + f"Failed to extract memories from session {session_id} for user {current_user.id}: {e}" + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to extract memories from chat session" + detail="Failed to extract memories from chat session", ) @@ -583,15 +633,11 @@ async def chat_health_check(): # Simple health check - verify service can be initialized if not chat_service._initialized: await chat_service.initialize() - - return { - "status": "healthy", - "service": "chat", - "timestamp": time.time() - } + + return {"status": "healthy", "service": "chat", "timestamp": time.time()} except Exception as e: logger.error(f"Chat service health check failed: {e}") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Chat service is not available" - ) \ No newline at end of file + detail="Chat service is not available", + ) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py index 7a89fd5f..f46daf2e 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py @@ -26,28 +26,44 @@ async def close_current_conversation( current_user: User = Depends(current_active_user), ): """Close the current active conversation for a client. Works for both connected and disconnected clients.""" - return await conversation_controller.close_current_conversation(client_id, current_user) + return await conversation_controller.close_current_conversation( + client_id, current_user + ) @router.get("") async def get_conversations( - include_deleted: bool = Query(False, description="Include soft-deleted conversations"), - include_unprocessed: bool = Query(False, description="Include orphan audio sessions (always_persist with failed/pending transcription)"), - starred_only: bool = Query(False, description="Only return starred/favorited conversations"), + include_deleted: bool = Query( + False, description="Include soft-deleted conversations" + ), + include_unprocessed: bool = Query( + False, + description="Include orphan audio sessions (always_persist with failed/pending transcription)", + ), + starred_only: bool = Query( + False, description="Only return starred/favorited conversations" + ), limit: int = Query(200, ge=1, le=500, description="Max conversations to return"), offset: int = Query(0, ge=0, description="Number of conversations to skip"), - sort_by: str = Query("created_at", description="Sort field: created_at, title, audio_total_duration"), + sort_by: str = Query( + "created_at", description="Sort field: created_at, title, audio_total_duration" + ), sort_order: str = Query("desc", description="Sort direction: asc or desc"), - current_user: User = Depends(current_active_user) + current_user: User = Depends(current_active_user), ): """Get conversations. 
Admins see all conversations, users see only their own.""" return await conversation_controller.get_conversations( - current_user, include_deleted, include_unprocessed, starred_only, limit, offset, - sort_by=sort_by, sort_order=sort_order, + current_user, + include_deleted, + include_unprocessed, + starred_only, + limit, + offset, + sort_by=sort_by, + sort_order=sort_order, ) - @router.get("/search") async def search_conversations( q: str = Query(..., min_length=1, description="Text search query"), @@ -56,13 +72,14 @@ async def search_conversations( current_user: User = Depends(current_active_user), ): """Full-text search across conversation titles, summaries, and transcripts.""" - return await conversation_controller.search_conversations(q, current_user, limit, offset) + return await conversation_controller.search_conversations( + q, current_user, limit, offset + ) @router.get("/{conversation_id}") async def get_conversation_detail( - conversation_id: str, - current_user: User = Depends(current_active_user) + conversation_id: str, current_user: User = Depends(current_active_user) ): """Get a specific conversation with full transcript details.""" return await conversation_controller.get_conversation(conversation_id, current_user) @@ -94,24 +111,28 @@ async def reprocess_transcript( conversation_id: str, current_user: User = Depends(current_active_user) ): """Reprocess transcript for a conversation. Users can only reprocess their own conversations.""" - return await conversation_controller.reprocess_transcript(conversation_id, current_user) + return await conversation_controller.reprocess_transcript( + conversation_id, current_user + ) @router.post("/{conversation_id}/reprocess-memory") async def reprocess_memory( conversation_id: str, current_user: User = Depends(current_active_user), - transcript_version_id: str = Query(default="active") + transcript_version_id: str = Query(default="active"), ): """Reprocess memory extraction for a specific transcript version. Users can only reprocess their own conversations.""" - return await conversation_controller.reprocess_memory(conversation_id, transcript_version_id, current_user) + return await conversation_controller.reprocess_memory( + conversation_id, transcript_version_id, current_user + ) @router.post("/{conversation_id}/reprocess-speakers") async def reprocess_speakers( conversation_id: str, current_user: User = Depends(current_active_user), - transcript_version_id: str = Query(default="active") + transcript_version_id: str = Query(default="active"), ): """ Re-run speaker identification/diarization on existing transcript. @@ -127,9 +148,7 @@ async def reprocess_speakers( Job status with job_id and new version_id """ return await conversation_controller.reprocess_speakers( - conversation_id, - transcript_version_id, - current_user + conversation_id, transcript_version_id, current_user ) @@ -137,20 +156,24 @@ async def reprocess_speakers( async def activate_transcript_version( conversation_id: str, version_id: str, - current_user: User = Depends(current_active_user) + current_user: User = Depends(current_active_user), ): """Activate a specific transcript version. 
Users can only modify their own conversations.""" - return await conversation_controller.activate_transcript_version(conversation_id, version_id, current_user) + return await conversation_controller.activate_transcript_version( + conversation_id, version_id, current_user + ) @router.post("/{conversation_id}/activate-memory/{version_id}") async def activate_memory_version( conversation_id: str, version_id: str, - current_user: User = Depends(current_active_user) + current_user: User = Depends(current_active_user), ): """Activate a specific memory version. Users can only modify their own conversations.""" - return await conversation_controller.activate_memory_version(conversation_id, version_id, current_user) + return await conversation_controller.activate_memory_version( + conversation_id, version_id, current_user + ) @router.get("/{conversation_id}/versions") @@ -158,13 +181,14 @@ async def get_conversation_version_history( conversation_id: str, current_user: User = Depends(current_active_user) ): """Get version history for a conversation. Users can only access their own conversations.""" - return await conversation_controller.get_conversation_version_history(conversation_id, current_user) + return await conversation_controller.get_conversation_version_history( + conversation_id, current_user + ) @router.get("/{conversation_id}/waveform") async def get_conversation_waveform( - conversation_id: str, - current_user: User = Depends(current_active_user) + conversation_id: str, current_user: User = Depends(current_active_user) ): """ Get or generate waveform visualization data for a conversation. @@ -208,37 +232,38 @@ async def get_conversation_waveform( # If waveform exists, return cached version if waveform: - logger.info(f"Returning cached waveform for conversation {conversation_id[:12]}") + logger.info( + f"Returning cached waveform for conversation {conversation_id[:12]}" + ) return waveform.model_dump(exclude={"id", "revision_id"}) # Generate waveform on-demand - logger.info(f"Generating waveform on-demand for conversation {conversation_id[:12]}") + logger.info( + f"Generating waveform on-demand for conversation {conversation_id[:12]}" + ) waveform_dict = await generate_waveform_data( - conversation_id=conversation_id, - sample_rate=3 + conversation_id=conversation_id, sample_rate=3 ) if not waveform_dict.get("success"): error_msg = waveform_dict.get("error", "Unknown error") logger.error(f"Waveform generation failed: {error_msg}") raise HTTPException( - status_code=500, - detail=f"Waveform generation failed: {error_msg}" + status_code=500, detail=f"Waveform generation failed: {error_msg}" ) # Return generated waveform (already saved to database by generator) return { "samples": waveform_dict["samples"], "sample_rate": waveform_dict["sample_rate"], - "duration_seconds": waveform_dict["duration_seconds"] + "duration_seconds": waveform_dict["duration_seconds"], } @router.get("/{conversation_id}/metadata") async def get_conversation_metadata( - conversation_id: str, - current_user: User = Depends(current_active_user) + conversation_id: str, current_user: User = Depends(current_active_user) ) -> dict: """ Get conversation metadata (duration, etc.) without loading audio. 
@@ -270,7 +295,7 @@ async def get_conversation_metadata( "conversation_id": conversation_id, "duration": conversation.audio_total_duration or 0.0, "created_at": conversation.created_at, - "has_audio": (conversation.audio_total_duration or 0.0) > 0 + "has_audio": (conversation.audio_total_duration or 0.0) > 0, } @@ -278,8 +303,10 @@ async def get_conversation_metadata( async def get_audio_segment( conversation_id: str, start: float = Query(0.0, description="Start time in seconds"), - duration: Optional[float] = Query(None, description="Duration in seconds (omit for full audio)"), - current_user: User = Depends(current_active_user) + duration: Optional[float] = Query( + None, description="Duration in seconds (omit for full audio)" + ), + current_user: User = Depends(current_active_user), ) -> Response: """ Get audio segment from a conversation. @@ -297,6 +324,7 @@ async def get_audio_segment( WAV audio bytes (16kHz, mono) for the requested time range """ import time + request_start = time.time() # Verify conversation exists and user has access @@ -314,7 +342,9 @@ async def get_audio_segment( # Calculate end time total_duration = conversation.audio_total_duration or 0.0 if total_duration == 0: - raise HTTPException(status_code=404, detail="No audio available for this conversation") + raise HTTPException( + status_code=404, detail="No audio available for this conversation" + ) if duration is None: end = total_duration @@ -323,18 +353,23 @@ async def get_audio_segment( # Validate time range if start < 0 or start >= total_duration: - raise HTTPException(status_code=400, detail=f"Invalid start time: {start}s (max: {total_duration}s)") + raise HTTPException( + status_code=400, + detail=f"Invalid start time: {start}s (max: {total_duration}s)", + ) # Get audio chunks for time range try: wav_bytes = await reconstruct_audio_segment( - conversation_id=conversation_id, - start_time=start, - end_time=end + conversation_id=conversation_id, start_time=start, end_time=end ) except Exception as e: - logger.error(f"Failed to reconstruct audio segment for {conversation_id[:12]}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to reconstruct audio: {str(e)}") + logger.error( + f"Failed to reconstruct audio segment for {conversation_id[:12]}: {e}" + ) + raise HTTPException( + status_code=500, detail=f"Failed to reconstruct audio: {str(e)}" + ) request_time = time.time() - request_start logger.info( @@ -351,15 +386,14 @@ async def get_audio_segment( "Content-Disposition": f"attachment; filename=segment_{start}_{end}.wav", "X-Audio-Start": str(start), "X-Audio-End": str(end), - "X-Audio-Duration": str(end - start) - } + "X-Audio-Duration": str(end - start), + }, ) @router.post("/{conversation_id}/star") async def toggle_star( - conversation_id: str, - current_user: User = Depends(current_active_user) + conversation_id: str, current_user: User = Depends(current_active_user) ): """Toggle the starred/favorite status of a conversation.""" return await conversation_controller.toggle_star(conversation_id, current_user) @@ -369,16 +403,19 @@ async def toggle_star( async def delete_conversation( conversation_id: str, permanent: bool = Query(False, description="Permanently delete (admin only)"), - current_user: User = Depends(current_active_user) + current_user: User = Depends(current_active_user), ): """Soft delete a conversation (or permanently delete if admin).""" - return await conversation_controller.delete_conversation(conversation_id, current_user, permanent) + return await 
conversation_controller.delete_conversation( + conversation_id, current_user, permanent + ) @router.post("/{conversation_id}/restore") async def restore_conversation( - conversation_id: str, - current_user: User = Depends(current_active_user) + conversation_id: str, current_user: User = Depends(current_active_user) ): """Restore a soft-deleted conversation.""" - return await conversation_controller.restore_conversation(conversation_id, current_user) + return await conversation_controller.restore_conversation( + conversation_id, current_user + ) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py index cd4bff97..4ea42139 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py @@ -48,7 +48,7 @@ else: _llm_def = _embed_def = _vs_def = None -QDRANT_BASE_URL = (_vs_def.model_params.get("host") if _vs_def else "qdrant") +QDRANT_BASE_URL = _vs_def.model_params.get("host") if _vs_def else "qdrant" QDRANT_PORT = str(_vs_def.model_params.get("port") if _vs_def else "6333") @@ -58,7 +58,7 @@ async def auth_health_check(): try: # Test database connectivity await mongo_client.admin.command("ping") - + # Test memory service if available if memory_service: try: @@ -69,12 +69,12 @@ async def auth_health_check(): memory_status = "degraded" else: memory_status = "unavailable" - + return { "status": "ok", - "database": "ok", + "database": "ok", "memory_service": memory_status, - "timestamp": int(time.time()) + "timestamp": int(time.time()), } except Exception as e: logger.error(f"Auth health check failed: {e}") @@ -84,8 +84,8 @@ async def auth_health_check(): "status": "error", "detail": "Service connectivity check failed", "error_type": "connection_failure", - "timestamp": int(time.time()) - } + "timestamp": int(time.time()), + }, ) @@ -129,16 +129,18 @@ async def health_check(): else "Not configured" ), "transcription_provider": ( - REGISTRY.get_default("stt").name if REGISTRY and REGISTRY.get_default("stt") + REGISTRY.get_default("stt").name + if REGISTRY and REGISTRY.get_default("stt") else "not configured" ), - "provider_type": ( transcription_provider.mode if transcription_provider else "none" ), "chunk_dir": str(os.getenv("CHUNK_DIR", "./audio_chunks")), "active_clients": get_client_manager().get_client_count(), - "new_conversation_timeout_minutes": float(os.getenv("NEW_CONVERSATION_TIMEOUT_MINUTES", "1.5")), + "new_conversation_timeout_minutes": float( + os.getenv("NEW_CONVERSATION_TIMEOUT_MINUTES", "1.5") + ), "llm_provider": (_llm_def.model_provider if _llm_def else None), "llm_model": (_llm_def.model_name if _llm_def else None), "llm_base_url": (_llm_def.model_url if _llm_def else None), @@ -204,14 +206,14 @@ async def health_check(): "worker_count": worker_count, "active_workers": active_workers, "idle_workers": idle_workers, - "queues": queue_health.get("queues", {}) + "queues": queue_health.get("queues", {}), } else: health_status["services"]["redis"] = { "status": f"❌ Connection Failed: {queue_health.get('redis_connection')}", "healthy": False, "critical": True, - "worker_count": 0 + "worker_count": 0, } overall_healthy = False critical_services_healthy = False @@ -221,7 +223,7 @@ async def health_check(): "status": "❌ Connection Timeout (5s)", "healthy": False, "critical": True, - "worker_count": 0 + "worker_count": 0, } overall_healthy = False 
critical_services_healthy = False @@ -230,7 +232,7 @@ async def health_check(): "status": f"❌ Connection Failed: {str(e)}", "healthy": False, "critical": True, - "worker_count": 0 + "worker_count": 0, } overall_healthy = False critical_services_healthy = False @@ -267,7 +269,7 @@ async def health_check(): if memory_provider == "chronicle": try: # Test Chronicle memory service connection with timeout - test_success = await asyncio.wait_for(memory_service.test_connection(), timeout=8.0) + test_success = await asyncio.wait_for( + memory_service.test_connection(), timeout=8.0 + ) if test_success: health_status["services"]["memory_service"] = { "status": "✅ Chronicle Memory Connected", @@ -361,7 +365,8 @@ async def health_check(): # Make a health check request to the speaker service async with aiohttp.ClientSession() as session: async with session.get( - f"{speaker_service_url}/health", timeout=aiohttp.ClientTimeout(total=5) + f"{speaker_service_url}/health", + timeout=aiohttp.ClientTimeout(total=5), ) as response: if response.status == 200: health_status["services"]["speaker_recognition"] = { @@ -401,7 +406,8 @@ async def health_check(): # Make a health check request to the OpenMemory MCP service async with aiohttp.ClientSession() as session: async with session.get( - f"{openmemory_mcp_url}/api/v1/apps/", timeout=aiohttp.ClientTimeout(total=5) + f"{openmemory_mcp_url}/api/v1/apps/", + timeout=aiohttp.ClientTimeout(total=5), ) as response: if response.status == 200: health_status["services"]["openmemory_mcp"] = { @@ -464,7 +470,9 @@ async def health_check(): if not service["healthy"] and not service.get("critical", True) ] if unhealthy_optional: - messages.append(f"Optional services unavailable: {', '.join(unhealthy_optional)}") + messages.append( + f"Optional services unavailable: {', '.join(unhealthy_optional)}" + ) health_status["message"] = "; ".join(messages) @@ -476,15 +484,21 @@ async def readiness_check(): """Simple readiness check for container orchestration.""" # Use debug level for health check to reduce log spam logger.debug("Readiness check requested") - + # Only check critical services for readiness try: # Quick MongoDB ping to ensure we can serve requests await asyncio.wait_for(mongo_client.admin.command("ping"), timeout=2.0) - return JSONResponse(content={"status": "ready", "timestamp": int(time.time())}, status_code=200) + return JSONResponse( + content={"status": "ready", "timestamp": int(time.time())}, status_code=200 + ) except Exception as e: logger.error(f"Readiness check failed: {e}") return JSONResponse( - content={"status": "not_ready", "error": str(e), "timestamp": int(time.time())}, - status_code=503 + content={ + "status": "not_ready", + "error": str(e), + "timestamp": int(time.time()), + }, + status_code=503, ) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py index 1b8ae1cf..6fcebea1 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py @@ -37,6 +37,7 @@ class UpdateEntityRequest(BaseModel): """Request model for updating entity fields.""" + name: Optional[str] = None details: Optional[str] = None icon: Optional[str] = None @@ -44,6 +45,7 @@ class UpdatePromiseRequest(BaseModel): """Request model for updating promise status.""" + status: str # pending,
in_progress, completed, cancelled diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/memory_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/memory_routes.py index 409f7b85..89ec6091 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/memory_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/memory_routes.py @@ -21,6 +21,7 @@ class AddMemoryRequest(BaseModel): """Request model for adding a memory.""" + content: str source_id: Optional[str] = None @@ -29,7 +30,9 @@ class AddMemoryRequest(BaseModel): async def get_memories( current_user: User = Depends(current_active_user), limit: int = Query(default=50, ge=1, le=1000), - user_id: Optional[str] = Query(default=None, description="User ID filter (admin only)"), + user_id: Optional[str] = Query( + default=None, description="User ID filter (admin only)" + ), ): """Get memories. Users see only their own memories, admins can see all or filter by user.""" return await memory_controller.get_memories(current_user, limit, user_id) @@ -39,10 +42,14 @@ async def get_memories( async def get_memories_with_transcripts( current_user: User = Depends(current_active_user), limit: int = Query(default=50, ge=1, le=1000), - user_id: Optional[str] = Query(default=None, description="User ID filter (admin only)"), + user_id: Optional[str] = Query( + default=None, description="User ID filter (admin only)" + ), ): """Get memories with their source transcripts. Users see only their own memories, admins can see all or filter by user.""" - return await memory_controller.get_memories_with_transcripts(current_user, limit, user_id) + return await memory_controller.get_memories_with_transcripts( + current_user, limit, user_id + ) @router.get("/search") @@ -50,30 +57,44 @@ async def search_memories( query: str = Query(..., description="Search query"), current_user: User = Depends(current_active_user), limit: int = Query(default=20, ge=1, le=100), - score_threshold: float = Query(default=0.0, ge=0.0, le=1.0, description="Minimum similarity score (0.0 = no threshold)"), - user_id: Optional[str] = Query(default=None, description="User ID filter (admin only)"), + score_threshold: float = Query( + default=0.0, + ge=0.0, + le=1.0, + description="Minimum similarity score (0.0 = no threshold)", + ), + user_id: Optional[str] = Query( + default=None, description="User ID filter (admin only)" + ), ): """Search memories by text query with configurable similarity threshold. Users can only search their own memories, admins can search all or filter by user.""" - return await memory_controller.search_memories(query, current_user, limit, score_threshold, user_id) + return await memory_controller.search_memories( + query, current_user, limit, score_threshold, user_id + ) @router.post("") async def add_memory( - request: AddMemoryRequest, - current_user: User = Depends(current_active_user) + request: AddMemoryRequest, current_user: User = Depends(current_active_user) ): """Add a memory directly from content text. 
The service will extract structured memories from the provided content.""" - return await memory_controller.add_memory(request.content, current_user, request.source_id) + return await memory_controller.add_memory( + request.content, current_user, request.source_id + ) @router.delete("/{memory_id}") -async def delete_memory(memory_id: str, current_user: User = Depends(current_active_user)): +async def delete_memory( + memory_id: str, current_user: User = Depends(current_active_user) +): """Delete a memory by ID. Users can only delete their own memories, admins can delete any.""" return await memory_controller.delete_memory(memory_id, current_user) @router.get("/admin") -async def get_all_memories_admin(current_user: User = Depends(current_superuser), limit: int = 200): +async def get_all_memories_admin( + current_user: User = Depends(current_superuser), limit: int = 200 +): """Get all memories across all users for admin review. Admin only.""" return await memory_controller.get_all_memories_admin(current_user, limit) @@ -82,7 +103,9 @@ async def get_all_memories_admin(current_user: User = Depends(current_superuser) async def get_memory_by_id( memory_id: str, current_user: User = Depends(current_active_user), - user_id: Optional[str] = Query(default=None, description="User ID filter (admin only)"), + user_id: Optional[str] = Query( + default=None, description="User ID filter (admin only)" + ), ): """Get a single memory by ID. Users can only access their own memories, admins can access any.""" return await memory_controller.get_memory_by_id(memory_id, current_user, user_id) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/obsidian_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/obsidian_routes.py index b02ed426..8d76b947 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/obsidian_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/obsidian_routes.py @@ -1,4 +1,3 @@ - import json import logging import os @@ -25,20 +24,23 @@ router = APIRouter(prefix="/obsidian", tags=["obsidian"]) + class IngestRequest(BaseModel): vault_path: str + @router.post("/ingest") async def ingest_obsidian_vault( - request: IngestRequest, - current_user: User = Depends(current_active_user) + request: IngestRequest, current_user: User = Depends(current_active_user) ): """ Immediate/synchronous ingestion endpoint (legacy). Not recommended for UI. Prefer the upload_zip + start endpoints to enable progress reporting. """ if not os.path.exists(request.vault_path): - raise HTTPException(status_code=400, detail=f"Path not found: {request.vault_path}") + raise HTTPException( + status_code=400, detail=f"Path not found: {request.vault_path}" + ) try: result = await obsidian_service.ingest_vault(request.vault_path) @@ -50,15 +52,16 @@ async def ingest_obsidian_vault( @router.post("/upload_zip") async def upload_obsidian_zip( - file: UploadFile = File(...), - current_user: User = Depends(current_superuser) + file: UploadFile = File(...), current_user: User = Depends(current_superuser) ): """ Upload a zipped Obsidian vault. Returns a job_id that can be started later. Uses upload_files_async pattern from upload_files.py for proper file handling. 
""" - if not file.filename.lower().endswith('.zip'): - raise HTTPException(status_code=400, detail="Please upload a .zip file of your Obsidian vault") + if not file.filename.lower().endswith(".zip"): + raise HTTPException( + status_code=400, detail="Please upload a .zip file of your Obsidian vault" + ) job_id = str(uuid.uuid4()) base_dir = Path("/app/data/obsidian_jobs") @@ -67,21 +70,23 @@ async def upload_obsidian_zip( job_dir.mkdir(parents=True, exist_ok=True) zip_path = job_dir / "vault.zip" extract_dir = job_dir / "vault" - + # Use upload_files_async pattern for proper file handling with cleanup zip_file_handle = None try: # Read file content file_content = await file.read() - + # Save zip file using proper file handling pattern from upload_files_async try: - zip_file_handle = open(zip_path, 'wb') + zip_file_handle = open(zip_path, "wb") zip_file_handle.write(file_content) except IOError as e: logger.error(f"Error writing zip file {zip_path}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to save uploaded zip: {e}") - + raise HTTPException( + status_code=500, detail=f"Failed to save uploaded zip: {e}" + ) + # Extract zip file using utility function try: extract_zip(zip_path, extract_dir) @@ -90,10 +95,12 @@ async def upload_obsidian_zip( raise HTTPException(status_code=400, detail=f"Invalid zip file: {e}") except ZipExtractionError as e: logger.error(f"Error extracting zip file: {e}") - raise HTTPException(status_code=500, detail=f"Failed to extract zip file: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to extract zip file: {e}" + ) total = count_markdown_files(str(extract_dir)) - + # Store pending job state in Redis pending_state = { "status": "ready", @@ -101,16 +108,20 @@ async def upload_obsidian_zip( "processed": 0, "errors": [], "vault_path": str(extract_dir), - "job_id": job_id + "job_id": job_id, } - redis_conn.set(f"obsidian_pending:{job_id}", json.dumps(pending_state), ex=3600*24) # 24h expiry + redis_conn.set( + f"obsidian_pending:{job_id}", json.dumps(pending_state), ex=3600 * 24 + ) # 24h expiry return {"job_id": job_id, "vault_path": str(extract_dir), "total_files": total} except HTTPException: raise except Exception as e: logger.exception(f"Failed to process uploaded zip: {e}") - raise HTTPException(status_code=500, detail=f"Failed to process uploaded zip: {e}") + raise HTTPException( + status_code=500, detail=f"Failed to process uploaded zip: {e}" + ) finally: # Ensure file handle is closed (following upload_files_async pattern) if zip_file_handle: @@ -123,17 +134,17 @@ async def upload_obsidian_zip( @router.post("/start") async def start_ingestion( job_id: str = Body(..., embed=True), - current_user: User = Depends(current_active_user) + current_user: User = Depends(current_active_user), ): # Check if job is pending pending_key = f"obsidian_pending:{job_id}" pending_data = redis_conn.get(pending_key) - + if pending_data: try: job_data = json.loads(pending_data) vault_path = job_data.get("vault_path") - + # Enqueue to RQ rq_job = default_queue.enqueue( ingest_obsidian_vault_job, @@ -141,27 +152,31 @@ async def start_ingestion( vault_path, # arg2 job_id=job_id, # Set RQ job ID to match our ID description=f"Obsidian ingestion for job {job_id}", - job_timeout=3600 # 1 hour timeout + job_timeout=3600, # 1 hour timeout ) - + # Remove pending key redis_conn.delete(pending_key) - - return {"message": "Ingestion started", "job_id": job_id, "rq_job_id": rq_job.id} + + return { + "message": "Ingestion started", + "job_id": job_id, + "rq_job_id": 
rq_job.id, + } except Exception as e: logger.exception(f"Failed to start job {job_id}: {e}") raise HTTPException(status_code=500, detail=f"Failed to start job: {e}") - + # Check if already in RQ try: job = Job.fetch(job_id, connection=redis_conn) status = job.get_status() if status in ("queued", "started", "deferred", "scheduled"): - raise HTTPException(status_code=400, detail=f"Job already {status}") - + raise HTTPException(status_code=400, detail=f"Job already {status}") + # If finished/failed, we could potentially restart? But for now let's say it's done. raise HTTPException(status_code=400, detail=f"Job is in state: {status}") - + except NoSuchJobError: raise HTTPException(status_code=404, detail="Job not found") @@ -171,7 +186,7 @@ async def get_status(job_id: str, current_user: User = Depends(current_active_us # 1. Try RQ first try: job = Job.fetch(job_id, connection=redis_conn) - + # Get status status = job.get_status() if status == "started": @@ -181,13 +196,18 @@ async def get_status(job_id: str, current_user: User = Depends(current_active_us meta = job.meta or {} # If meta has status, prefer it (for granular updates) - if "status" in meta and meta["status"] in ("running", "finished", "failed", "canceled"): - status = meta["status"] + if "status" in meta and meta["status"] in ( + "running", + "finished", + "failed", + "canceled", + ): + status = meta["status"] total = meta.get("total_files", 0) processed = meta.get("processed", 0) percent = int((processed / total) * 100) if total else 0 - + return { "job_id": job_id, "status": status, @@ -196,14 +216,14 @@ async def get_status(job_id: str, current_user: User = Depends(current_active_us "percent": percent, "errors": meta.get("errors", []), "vault_path": meta.get("vault_path"), - "rq_job_id": job.id + "rq_job_id": job.id, } - + except NoSuchJobError: # 2. 
Check pending pending_key = f"obsidian_pending:{job_id}" pending_data = redis_conn.get(pending_key) - + if pending_data: try: job_data = json.loads(pending_data) @@ -214,10 +234,8 @@ async def get_status(job_id: str, current_user: User = Depends(current_active_us "processed": 0, "percent": 0, "errors": [], - "vault_path": job_data.get("vault_path") + "vault_path": job_data.get("vault_path"), } except: raise HTTPException(status_code=500, detail="Failed to get job status") raise HTTPException(status_code=404, detail="Job not found") - - diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py index 29719566..745321f5 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py @@ -73,7 +73,9 @@ async def list_jobs( @router.get("/jobs/{job_id}/status") -async def get_job_status(job_id: str, current_user: User = Depends(current_active_user)): +async def get_job_status( + job_id: str, current_user: User = Depends(current_active_user) +): """Get just the status of a specific job (lightweight endpoint).""" try: job = Job.fetch(job_id, connection=redis_conn) @@ -193,7 +195,9 @@ async def cancel_job(job_id: str, current_user: User = Depends(current_active_us @router.get("/jobs/by-client/{client_id}") -async def get_jobs_by_client(client_id: str, current_user: User = Depends(current_active_user)): +async def get_jobs_by_client( + client_id: str, current_user: User = Depends(current_active_user) +): """Get all jobs associated with a specific client device.""" try: from rq.registry import ( @@ -242,11 +246,17 @@ def process_job_and_dependents(job, queue_name, base_status): all_jobs.append( { "job_id": job.id, - "job_type": (job.func_name.split(".")[-1] if job.func_name else "unknown"), + "job_type": ( + job.func_name.split(".")[-1] if job.func_name else "unknown" + ), "queue": queue_name, "status": status, - "created_at": (job.created_at.isoformat() if job.created_at else None), - "started_at": (job.started_at.isoformat() if job.started_at else None), + "created_at": ( + job.created_at.isoformat() if job.created_at else None + ), + "started_at": ( + job.started_at.isoformat() if job.started_at else None + ), "ended_at": job.ended_at.isoformat() if job.ended_at else None, "description": job.description or "", "result": job.result, @@ -323,13 +333,17 @@ def process_job_and_dependents(job, queue_name, base_status): # Sort by created_at all_jobs.sort(key=lambda x: x["created_at"] or "", reverse=False) - logger.info(f"Found {len(all_jobs)} jobs for client {client_id} (including dependents)") + logger.info( + f"Found {len(all_jobs)} jobs for client {client_id} (including dependents)" + ) return {"client_id": client_id, "jobs": all_jobs, "total": len(all_jobs)} except Exception as e: logger.error(f"Failed to get jobs for client {client_id}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get jobs for client: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get jobs for client: {str(e)}" + ) @router.get("/events") @@ -349,7 +363,9 @@ async def get_events( if not router_instance: return {"events": [], "total": 0} - events = router_instance.get_recent_events(limit=limit, event_type=event_type or None) + events = router_instance.get_recent_events( + limit=limit, event_type=event_type or None + ) return {"events": events, "total": len(events)} except Exception 
as e: logger.error(f"Failed to get events: {e}") @@ -467,7 +483,9 @@ async def get_queue_worker_details(current_user: User = Depends(current_active_u except Exception as e: logger.error(f"Failed to get queue worker details: {e}") - raise HTTPException(status_code=500, detail=f"Failed to get worker details: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to get worker details: {str(e)}" + ) @router.get("/streams") @@ -498,7 +516,9 @@ async def get_stream_stats( async def get_stream_info(stream_key): try: - stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key + stream_name = ( + stream_key.decode() if isinstance(stream_key, bytes) else stream_key + ) # Get basic stream info info = await audio_service.redis.xinfo_stream(stream_name) @@ -558,7 +578,9 @@ async def get_stream_info(stream_key): "name": group_dict.get("name", "unknown"), "consumers": group_dict.get("consumers", 0), "pending": group_dict.get("pending", 0), - "last_delivered_id": group_dict.get("last-delivered-id", "N/A"), + "last_delivered_id": group_dict.get( + "last-delivered-id", "N/A" + ), "consumer_details": consumers, } ) @@ -569,7 +591,9 @@ async def get_stream_info(stream_key): "stream_name": stream_name, "length": info[b"length"], "first_entry_id": ( - info[b"first-entry"][0].decode() if info[b"first-entry"] else None + info[b"first-entry"][0].decode() + if info[b"first-entry"] + else None ), "last_entry_id": ( info[b"last-entry"][0].decode() if info[b"last-entry"] else None @@ -581,7 +605,9 @@ async def get_stream_info(stream_key): return None # Fetch all stream info in parallel - streams_info_results = await asyncio.gather(*[get_stream_info(key) for key in stream_keys]) + streams_info_results = await asyncio.gather( + *[get_stream_info(key) for key in stream_keys] + ) streams_info = [info for info in streams_info_results if info is not None] return { @@ -607,7 +633,9 @@ class FlushAllJobsRequest(BaseModel): @router.post("/flush") -async def flush_jobs(request: FlushJobsRequest, current_user: User = Depends(current_active_user)): +async def flush_jobs( + request: FlushJobsRequest, current_user: User = Depends(current_active_user) +): """Flush old inactive jobs based on age and status.""" if not current_user.is_superuser: raise HTTPException(status_code=403, detail="Admin access required") @@ -623,7 +651,9 @@ async def flush_jobs(request: FlushJobsRequest, current_user: User = Depends(cur from advanced_omi_backend.controllers.queue_controller import get_queue - cutoff_time = datetime.now(timezone.utc) - timedelta(hours=request.older_than_hours) + cutoff_time = datetime.now(timezone.utc) - timedelta( + hours=request.older_than_hours + ) total_removed = 0 # Get all queues @@ -655,7 +685,9 @@ async def flush_jobs(request: FlushJobsRequest, current_user: User = Depends(cur except Exception as e: logger.error(f"Error deleting job {job_id}: {e}") - if "canceled" in request.statuses: # RQ standard (US spelling), not "cancelled" + if ( + "canceled" in request.statuses + ): # RQ standard (US spelling), not "cancelled" registry = CanceledJobRegistry(queue=queue) for job_id in registry.get_job_ids(): try: @@ -737,8 +769,12 @@ async def flush_all_jobs( registries.append(("finished", FinishedJobRegistry(queue=queue))) for registry_name, registry in registries: - job_ids = list(registry.get_job_ids()) # Convert to list to avoid iterator issues - logger.info(f"Flushing {len(job_ids)} jobs from {queue_name}/{registry_name}") + job_ids = list( + registry.get_job_ids() + ) # Convert to list 
to avoid iterator issues + logger.info( + f"Flushing {len(job_ids)} jobs from {queue_name}/{registry_name}" + ) for job_id in job_ids: try: @@ -748,7 +784,9 @@ async def flush_all_jobs( # Skip session-level jobs (e.g., speech_detection, audio_persistence) # These run for the entire session and should not be killed by test cleanup if job.meta and job.meta.get("session_level"): - logger.info(f"Skipping session-level job {job_id} ({job.description})") + logger.info( + f"Skipping session-level job {job_id} ({job.description})" + ) continue # Handle running jobs differently to avoid worker deadlock @@ -759,7 +797,9 @@ async def flush_all_jobs( from rq.command import send_stop_job_command send_stop_job_command(redis_conn, job_id) - logger.info(f"Sent stop command to worker for job {job_id}") + logger.info( + f"Sent stop command to worker for job {job_id}" + ) # Don't delete yet - let worker move it to canceled/failed registry # It will be cleaned up on next flush or by worker cleanup continue @@ -770,9 +810,13 @@ async def flush_all_jobs( # If stop fails, try to cancel it (may already be finishing) try: job.cancel() - logger.info(f"Cancelled job {job_id} after stop failed") + logger.info( + f"Cancelled job {job_id} after stop failed" + ) except Exception as cancel_error: - logger.warning(f"Could not cancel job {job_id}: {cancel_error}") + logger.warning( + f"Could not cancel job {job_id}: {cancel_error}" + ) # For non-running jobs, safe to delete immediately job.delete() @@ -787,7 +831,9 @@ async def flush_all_jobs( f"Removed stale job reference {job_id} from {registry_name} registry" ) except Exception as reg_error: - logger.error(f"Could not remove {job_id} from registry: {reg_error}") + logger.error( + f"Could not remove {job_id} from registry: {reg_error}" + ) # Also clean up audio streams and consumer locks deleted_keys = 0 @@ -801,7 +847,9 @@ async def flush_all_jobs( # Delete audio streams cursor = 0 while True: - cursor, keys = await async_redis.scan(cursor, match="audio:*", count=1000) + cursor, keys = await async_redis.scan( + cursor, match="audio:*", count=1000 + ) if keys: await async_redis.delete(*keys) deleted_keys += len(keys) @@ -811,7 +859,9 @@ async def flush_all_jobs( # Delete consumer locks cursor = 0 while True: - cursor, keys = await async_redis.scan(cursor, match="consumer:*", count=1000) + cursor, keys = await async_redis.scan( + cursor, match="consumer:*", count=1000 + ) if keys: await async_redis.delete(*keys) deleted_keys += len(keys) @@ -840,7 +890,9 @@ async def flush_all_jobs( except Exception as e: logger.error(f"Failed to flush all jobs: {e}") - raise HTTPException(status_code=500, detail=f"Failed to flush all jobs: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to flush all jobs: {str(e)}" + ) @router.get("/sessions") @@ -860,7 +912,9 @@ async def get_redis_sessions( session_keys = [] cursor = b"0" while cursor and len(session_keys) < limit: - cursor, keys = await redis_client.scan(cursor, match="audio:session:*", count=limit) + cursor, keys = await redis_client.scan( + cursor, match="audio:session:*", count=limit + ) session_keys.extend(keys[: limit - len(session_keys)]) # Get session info @@ -872,8 +926,12 @@ async def get_redis_sessions( session_id = key.decode().replace("audio:session:", "") # Get conversation count for this session - conversation_count_key = f"session:conversation_count:{session_id}" - conversation_count_bytes = await redis_client.get(conversation_count_key) + conversation_count_key = ( + 
f"session:conversation_count:{session_id}" + ) + conversation_count_bytes = await redis_client.get( + conversation_count_key + ) conversation_count = ( int(conversation_count_bytes.decode()) if conversation_count_bytes @@ -884,16 +942,25 @@ async def get_redis_sessions( { "session_id": session_id, "user_id": session_data.get(b"user_id", b"").decode(), - "client_id": session_data.get(b"client_id", b"").decode(), - "stream_name": session_data.get(b"stream_name", b"").decode(), + "client_id": session_data.get( + b"client_id", b"" + ).decode(), + "stream_name": session_data.get( + b"stream_name", b"" + ).decode(), "provider": session_data.get(b"provider", b"").decode(), "mode": session_data.get(b"mode", b"").decode(), "status": session_data.get(b"status", b"").decode(), - "started_at": session_data.get(b"started_at", b"").decode(), + "started_at": session_data.get( + b"started_at", b"" + ).decode(), "chunks_published": int( - session_data.get(b"chunks_published", b"0").decode() or 0 + session_data.get(b"chunks_published", b"0").decode() + or 0 ), - "last_chunk_at": session_data.get(b"last_chunk_at", b"").decode(), + "last_chunk_at": session_data.get( + b"last_chunk_at", b"" + ).decode(), "conversation_count": conversation_count, } ) @@ -936,7 +1003,9 @@ async def clear_old_sessions( session_keys = [] cursor = b"0" while cursor: - cursor, keys = await redis_client.scan(cursor, match="audio:session:*", count=100) + cursor, keys = await redis_client.scan( + cursor, match="audio:session:*", count=100 + ) session_keys.extend(keys) # Check each session and delete if old @@ -955,13 +1024,18 @@ async def clear_old_sessions( except Exception as e: logger.error(f"Error processing session {key}: {e}") - return {"deleted_count": deleted_count, "cutoff_seconds": older_than_seconds} + return { + "deleted_count": deleted_count, + "cutoff_seconds": older_than_seconds, + } finally: await redis_client.aclose() except Exception as e: logger.error(f"Failed to clear sessions: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"Failed to clear sessions: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to clear sessions: {str(e)}" + ) @router.get("/dashboard") @@ -1013,11 +1087,17 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100): if status_name == "queued": job_ids = queue.job_ids[:limit] elif status_name == "started": # RQ standard, not "processing" - job_ids = list(StartedJobRegistry(queue=queue).get_job_ids())[:limit] + job_ids = list(StartedJobRegistry(queue=queue).get_job_ids())[ + :limit + ] elif status_name == "finished": # RQ standard, not "completed" - job_ids = list(FinishedJobRegistry(queue=queue).get_job_ids())[:limit] + job_ids = list(FinishedJobRegistry(queue=queue).get_job_ids())[ + :limit + ] elif status_name == "failed": - job_ids = list(FailedJobRegistry(queue=queue).get_job_ids())[:limit] + job_ids = list(FailedJobRegistry(queue=queue).get_job_ids())[ + :limit + ] else: continue @@ -1028,7 +1108,9 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100): # Check user permission if not current_user.is_superuser: - job_user_id = job.kwargs.get("user_id") if job.kwargs else None + job_user_id = ( + job.kwargs.get("user_id") if job.kwargs else None + ) if job_user_id != str(current_user.user_id): continue @@ -1037,24 +1119,38 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100): { "job_id": job.id, "job_type": ( - job.func_name.split(".")[-1] if job.func_name else "unknown" + job.func_name.split(".")[-1] + if 
job.func_name + else "unknown" + ), + "user_id": ( + job.kwargs.get("user_id") + if job.kwargs + else None ), - "user_id": (job.kwargs.get("user_id") if job.kwargs else None), "status": status_name, "priority": "normal", # RQ doesn't have priority concept "data": {"description": job.description or ""}, "result": job.result, "meta": job.meta if job.meta else {}, "kwargs": job.kwargs if job.kwargs else {}, - "error_message": (str(job.exc_info) if job.exc_info else None), + "error_message": ( + str(job.exc_info) if job.exc_info else None + ), "created_at": ( - job.created_at.isoformat() if job.created_at else None + job.created_at.isoformat() + if job.created_at + else None ), "started_at": ( - job.started_at.isoformat() if job.started_at else None + job.started_at.isoformat() + if job.started_at + else None ), "ended_at": ( - job.ended_at.isoformat() if job.ended_at else None + job.ended_at.isoformat() + if job.ended_at + else None ), "retry_count": 0, # RQ doesn't track this by default "max_retries": 0, @@ -1170,7 +1266,11 @@ def get_job_status(job): # Check user permission if not current_user.is_superuser: - job_user_id = job.kwargs.get("user_id") if job.kwargs else None + job_user_id = ( + job.kwargs.get("user_id") + if job.kwargs + else None + ) if job_user_id != str(current_user.user_id): continue @@ -1186,13 +1286,19 @@ def get_job_status(job): "queue": queue_name, "status": get_job_status(job), "created_at": ( - job.created_at.isoformat() if job.created_at else None + job.created_at.isoformat() + if job.created_at + else None ), "started_at": ( - job.started_at.isoformat() if job.started_at else None + job.started_at.isoformat() + if job.started_at + else None ), "ended_at": ( - job.ended_at.isoformat() if job.ended_at else None + job.ended_at.isoformat() + if job.ended_at + else None ), "description": job.description or "", "result": job.result, @@ -1255,12 +1361,20 @@ async def fetch_events(): ) queued_jobs = results[0] if not isinstance(results[0], Exception) else [] - started_jobs = results[1] if not isinstance(results[1], Exception) else [] # RQ standard - finished_jobs = results[2] if not isinstance(results[2], Exception) else [] # RQ standard + started_jobs = ( + results[1] if not isinstance(results[1], Exception) else [] + ) # RQ standard + finished_jobs = ( + results[2] if not isinstance(results[2], Exception) else [] + ) # RQ standard failed_jobs = results[3] if not isinstance(results[3], Exception) else [] - stats = results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0} + stats = ( + results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0} + ) streaming_status = ( - results[5] if not isinstance(results[5], Exception) else {"active_sessions": []} + results[5] + if not isinstance(results[5], Exception) + else {"active_sessions": []} ) events = results[6] if not isinstance(results[6], Exception) else [] recent_conversations = [] @@ -1279,7 +1393,9 @@ async def fetch_events(): { "conversation_id": conv.conversation_id, "user_id": str(conv.user_id) if conv.user_id else None, - "created_at": (conv.created_at.isoformat() if conv.created_at else None), + "created_at": ( + conv.created_at.isoformat() if conv.created_at else None + ), "title": conv.title, "summary": conv.summary, "transcript_text": ( @@ -1307,4 +1423,6 @@ async def fetch_events(): except Exception as e: logger.error(f"Failed to get dashboard data: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"Failed to get dashboard data: {str(e)}") + raise HTTPException( + 
status_code=500, detail=f"Failed to get dashboard data: {str(e)}" + ) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py index 277d7dc1..0f580be7 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py @@ -29,6 +29,7 @@ # Request models for memory config endpoints class MemoryConfigRequest(BaseModel): """Request model for memory configuration validation and updates.""" + config_yaml: str @@ -68,8 +69,7 @@ async def get_diarization_settings(current_user: User = Depends(current_superuse @router.post("/diarization-settings") async def save_diarization_settings( - settings: dict, - current_user: User = Depends(current_superuser) + settings: dict, current_user: User = Depends(current_superuser) ): """Save diarization settings. Admin only.""" return await system_controller.save_diarization_settings_controller(settings) @@ -83,32 +83,36 @@ async def get_misc_settings(current_user: User = Depends(current_superuser)): @router.post("/misc-settings") async def save_misc_settings( - settings: dict, - current_user: User = Depends(current_superuser) + settings: dict, current_user: User = Depends(current_superuser) ): """Save miscellaneous configuration settings. Admin only.""" return await system_controller.save_misc_settings_controller(settings) @router.get("/cleanup-settings") -async def get_cleanup_settings( - current_user: User = Depends(current_superuser) -): +async def get_cleanup_settings(current_user: User = Depends(current_superuser)): """Get cleanup configuration settings. Admin only.""" return await system_controller.get_cleanup_settings_controller(current_user) @router.post("/cleanup-settings") async def save_cleanup_settings( - auto_cleanup_enabled: bool = Body(..., description="Enable automatic cleanup of soft-deleted conversations"), - retention_days: int = Body(..., ge=1, le=365, description="Number of days to keep soft-deleted conversations"), - current_user: User = Depends(current_superuser) + auto_cleanup_enabled: bool = Body( + ..., description="Enable automatic cleanup of soft-deleted conversations" + ), + retention_days: int = Body( + ..., + ge=1, + le=365, + description="Number of days to keep soft-deleted conversations", + ), + current_user: User = Depends(current_superuser), ): """Save cleanup configuration settings. 
Admin only.""" return await system_controller.save_cleanup_settings_controller( auto_cleanup_enabled=auto_cleanup_enabled, retention_days=retention_days, - user=current_user + user=current_user, ) @@ -120,11 +124,12 @@ async def get_speaker_configuration(current_user: User = Depends(current_active_ @router.post("/speaker-configuration") async def update_speaker_configuration( - primary_speakers: list[dict], - current_user: User = Depends(current_active_user) + primary_speakers: list[dict], current_user: User = Depends(current_active_user) ): """Update current user's primary speakers configuration.""" - return await system_controller.update_speaker_configuration(current_user, primary_speakers) + return await system_controller.update_speaker_configuration( + current_user, primary_speakers + ) @router.get("/enrolled-speakers") @@ -141,6 +146,7 @@ async def get_speaker_service_status(current_user: User = Depends(current_superu # LLM Operations Configuration Endpoints + @router.get("/admin/llm-operations") async def get_llm_operations(current_user: User = Depends(current_superuser)): """Get LLM operation configurations. Admin only.""" @@ -149,8 +155,7 @@ async def get_llm_operations(current_user: User = Depends(current_superuser)): @router.post("/admin/llm-operations") async def save_llm_operations( - operations: dict, - current_user: User = Depends(current_superuser) + operations: dict, current_user: User = Depends(current_superuser) ): """Save LLM operation configurations. Admin only.""" return await system_controller.save_llm_operations(operations) @@ -159,7 +164,7 @@ async def save_llm_operations( @router.post("/admin/llm-operations/test") async def test_llm_model( model_name: Optional[str] = Body(None, embed=True), - current_user: User = Depends(current_superuser) + current_user: User = Depends(current_superuser), ): """Test an LLM model connection with a trivial prompt. Admin only.""" return await system_controller.test_llm_model(model_name) @@ -171,10 +176,11 @@ async def get_memory_config_raw(current_user: User = Depends(current_superuser)) """Get memory configuration YAML from config.yml. Admin only.""" return await system_controller.get_memory_config_raw() + @router.post("/admin/memory/config/raw") async def update_memory_config_raw( config_yaml: str = Body(..., media_type="text/plain"), - current_user: User = Depends(current_superuser) + current_user: User = Depends(current_superuser), ): """Save memory YAML to config.yml and hot-reload. Admin only.""" return await system_controller.update_memory_config_raw(config_yaml) @@ -191,8 +197,7 @@ async def validate_memory_config_raw( @router.post("/admin/memory/config/validate") async def validate_memory_config( - request: MemoryConfigRequest, - current_user: User = Depends(current_superuser) + request: MemoryConfigRequest, current_user: User = Depends(current_superuser) ): """Validate memory configuration YAML sent as JSON (used by tests). Admin only.""" return await system_controller.validate_memory_config(request.config_yaml) @@ -212,6 +217,7 @@ async def delete_all_user_memories(current_user: User = Depends(current_active_u # Chat Configuration Management Endpoints + @router.get("/admin/chat/config", response_class=Response) async def get_chat_config(current_user: User = Depends(current_superuser)): """Get chat configuration as YAML. 
Admin only.""" @@ -225,13 +231,12 @@ async def get_chat_config(current_user: User = Depends(current_superuser)): @router.post("/admin/chat/config") async def save_chat_config( - request: Request, - current_user: User = Depends(current_superuser) + request: Request, current_user: User = Depends(current_superuser) ): """Save chat configuration from YAML. Admin only.""" try: yaml_content = await request.body() - yaml_str = yaml_content.decode('utf-8') + yaml_str = yaml_content.decode("utf-8") result = await system_controller.save_chat_config_yaml(yaml_str) return JSONResponse(content=result) except ValueError as e: @@ -243,13 +248,12 @@ async def save_chat_config( @router.post("/admin/chat/config/validate") async def validate_chat_config( - request: Request, - current_user: User = Depends(current_superuser) + request: Request, current_user: User = Depends(current_superuser) ): """Validate chat configuration YAML. Admin only.""" try: yaml_content = await request.body() - yaml_str = yaml_content.decode('utf-8') + yaml_str = yaml_content.decode("utf-8") result = await system_controller.validate_chat_config_yaml(yaml_str) return JSONResponse(content=result) except Exception as e: @@ -259,6 +263,7 @@ async def validate_chat_config( # Plugin Configuration Management Endpoints + @router.get("/admin/plugins/config", response_class=Response) async def get_plugins_config(current_user: User = Depends(current_superuser)): """Get plugins configuration as YAML. Admin only.""" @@ -272,13 +277,12 @@ async def get_plugins_config(current_user: User = Depends(current_superuser)): @router.post("/admin/plugins/config") async def save_plugins_config( - request: Request, - current_user: User = Depends(current_superuser) + request: Request, current_user: User = Depends(current_superuser) ): """Save plugins configuration from YAML. Admin only.""" try: yaml_content = await request.body() - yaml_str = yaml_content.decode('utf-8') + yaml_str = yaml_content.decode("utf-8") result = await system_controller.save_plugins_config_yaml(yaml_str) return JSONResponse(content=result) except ValueError as e: @@ -290,13 +294,12 @@ async def save_plugins_config( @router.post("/admin/plugins/config/validate") async def validate_plugins_config( - request: Request, - current_user: User = Depends(current_superuser) + request: Request, current_user: User = Depends(current_superuser) ): """Validate plugins configuration YAML. Admin only.""" try: yaml_content = await request.body() - yaml_str = yaml_content.decode('utf-8') + yaml_str = yaml_content.decode("utf-8") result = await system_controller.validate_plugins_config_yaml(yaml_str) return JSONResponse(content=result) except Exception as e: @@ -306,6 +309,7 @@ async def validate_plugins_config( # Structured Plugin Configuration Endpoints (Form-based UI) + @router.post("/admin/plugins/reload") async def reload_plugins( request: Request, @@ -360,9 +364,16 @@ async def get_plugins_health(current_user: User = Depends(current_superuser)): """Get plugin health status for all registered plugins. 
Admin only.""" try: from advanced_omi_backend.services.plugin_service import get_plugin_router + plugin_router = get_plugin_router() if not plugin_router: - return {"total": 0, "initialized": 0, "failed": 0, "registered": 0, "plugins": []} + return { + "total": 0, + "initialized": 0, + "failed": 0, + "registered": 0, + "plugins": [], + } return plugin_router.get_health_summary() except Exception as e: logger.error(f"Failed to get plugins health: {e}") @@ -377,6 +388,7 @@ async def get_plugins_connectivity(current_user: User = Depends(current_superuse """ try: from advanced_omi_backend.services.plugin_service import get_plugin_router + plugin_router = get_plugin_router() if not plugin_router: return {"plugins": {}} @@ -406,6 +418,7 @@ async def get_plugins_metadata(current_user: User = Depends(current_superuser)): class PluginConfigRequest(BaseModel): """Request model for structured plugin configuration updates.""" + orchestration: Optional[dict] = None settings: Optional[dict] = None env_vars: Optional[dict] = None @@ -413,6 +426,7 @@ class PluginConfigRequest(BaseModel): class CreatePluginRequest(BaseModel): """Request model for creating a new plugin.""" + plugin_name: str description: str events: list[str] = [] @@ -421,12 +435,14 @@ class CreatePluginRequest(BaseModel): class WritePluginCodeRequest(BaseModel): """Request model for writing plugin code.""" + code: str config_yml: Optional[str] = None class PluginAssistantRequest(BaseModel): """Request model for plugin assistant chat.""" + messages: list[dict] @@ -434,7 +450,7 @@ class PluginAssistantRequest(BaseModel): async def update_plugin_config_structured( plugin_id: str, config: PluginConfigRequest, - current_user: User = Depends(current_superuser) + current_user: User = Depends(current_superuser), ): """Update plugin configuration from structured JSON (form data). Admin only. @@ -445,7 +461,9 @@ async def update_plugin_config_structured( """ try: config_dict = config.dict(exclude_none=True) - result = await system_controller.update_plugin_config_structured(plugin_id, config_dict) + result = await system_controller.update_plugin_config_structured( + plugin_id, config_dict + ) return JSONResponse(content=result) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -458,7 +476,7 @@ async def update_plugin_config_structured( async def test_plugin_connection( plugin_id: str, config: PluginConfigRequest, - current_user: User = Depends(current_superuser) + current_user: User = Depends(current_superuser), ): """Test plugin connection/configuration without saving. Admin only. 
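The structured plugin-config endpoint above builds its payload from `PluginConfigRequest` and forwards `config.dict(exclude_none=True)` to the controller. A minimal sketch of that shape, with illustrative field values (the event name and env var below are made up, not taken from the repo):

```python
# Hedged sketch: how a form submission maps through PluginConfigRequest before
# reaching system_controller.update_plugin_config_structured.
from typing import Optional

from pydantic import BaseModel


class PluginConfigRequest(BaseModel):
    """Mirror of the request model defined in system_routes.py."""

    orchestration: Optional[dict] = None
    settings: Optional[dict] = None
    env_vars: Optional[dict] = None


# An admin edit that only touches settings and env_vars (values are examples):
config = PluginConfigRequest(
    settings={"enabled": True, "events": ["transcript.complete"]},
    env_vars={"WEBHOOK_URL": "https://example.com/hook"},
)

# exclude_none=True drops the untouched orchestration field, so the controller
# only receives the keys the admin actually edited.
config_dict = config.dict(exclude_none=True)
assert "orchestration" not in config_dict
```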
@@ -490,7 +508,9 @@ async def create_plugin( plugin_code=request.plugin_code, ) if not result.get("success"): - raise HTTPException(status_code=400, detail=result.get("error", "Unknown error")) + raise HTTPException( + status_code=400, detail=result.get("error", "Unknown error") + ) return JSONResponse(content=result) except HTTPException: raise @@ -513,7 +533,9 @@ async def write_plugin_code( config_yml=request.config_yml, ) if not result.get("success"): - raise HTTPException(status_code=400, detail=result.get("error", "Unknown error")) + raise HTTPException( + status_code=400, detail=result.get("error", "Unknown error") + ) return JSONResponse(content=result) except HTTPException: raise @@ -535,7 +557,9 @@ async def delete_plugin( remove_files=remove_files, ) if not result.get("success"): - raise HTTPException(status_code=400, detail=result.get("error", "Unknown error")) + raise HTTPException( + status_code=400, detail=result.get("error", "Unknown error") + ) return JSONResponse(content=result) except HTTPException: raise @@ -572,25 +596,34 @@ async def event_stream(): @router.get("/streaming/status") -async def get_streaming_status(request: Request, current_user: User = Depends(current_superuser)): +async def get_streaming_status( + request: Request, current_user: User = Depends(current_superuser) +): """Get status of active streaming sessions and Redis Streams health. Admin only.""" return await session_controller.get_streaming_status(request) @router.post("/streaming/cleanup") -async def cleanup_stuck_stream_workers(request: Request, current_user: User = Depends(current_superuser)): +async def cleanup_stuck_stream_workers( + request: Request, current_user: User = Depends(current_superuser) +): """Clean up stuck Redis Stream workers and pending messages. Admin only.""" return await queue_controller.cleanup_stuck_stream_workers(request) @router.post("/streaming/cleanup-sessions") -async def cleanup_old_sessions(request: Request, max_age_seconds: int = 3600, current_user: User = Depends(current_superuser)): +async def cleanup_old_sessions( + request: Request, + max_age_seconds: int = 3600, + current_user: User = Depends(current_superuser), +): """Clean up old session tracking metadata. Admin only.""" return await session_controller.cleanup_old_sessions(request, max_age_seconds) # Memory Provider Configuration Endpoints + @router.get("/admin/memory/provider") async def get_memory_provider(current_user: User = Depends(current_superuser)): """Get current memory provider configuration. Admin only.""" @@ -600,7 +633,7 @@ async def get_memory_provider(current_user: User = Depends(current_superuser)): @router.post("/admin/memory/provider") async def set_memory_provider( provider: str = Body(..., embed=True), - current_user: User = Depends(current_superuser) + current_user: User = Depends(current_superuser), ): """Set memory provider and restart backend services. 
Admin only.""" return await system_controller.set_memory_provider(provider) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py index 349fe33d..a5488450 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py @@ -37,7 +37,7 @@ async def clear_test_plugin_events(): # Clear events from all plugins that have storage for plugin_id, plugin in plugin_router.plugins.items(): - if hasattr(plugin, 'storage') and plugin.storage: + if hasattr(plugin, "storage") and plugin.storage: try: cleared = await plugin.storage.clear_events() total_cleared += cleared @@ -45,10 +45,7 @@ async def clear_test_plugin_events(): except Exception as e: logger.error(f"Error clearing events from plugin '{plugin_id}': {e}") - return { - "message": "Test plugin events cleared", - "events_cleared": total_cleared - } + return {"message": "Test plugin events cleared", "events_cleared": total_cleared} @router.get("/plugins/events/count") @@ -65,23 +62,33 @@ async def get_test_plugin_event_count(event_type: Optional[str] = None): plugin_router = get_plugin_router() if not plugin_router: - return {"count": 0, "event_type": event_type, "message": "No plugin router initialized"} + return { + "count": 0, + "event_type": event_type, + "message": "No plugin router initialized", + } # Get count from first plugin with storage (usually test_event plugin) for plugin_id, plugin in plugin_router.plugins.items(): - if hasattr(plugin, 'storage') and plugin.storage: + if hasattr(plugin, "storage") and plugin.storage: try: count = await plugin.storage.get_event_count(event_type) return { "count": count, "event_type": event_type, - "plugin_id": plugin_id + "plugin_id": plugin_id, } except Exception as e: - logger.error(f"Error getting event count from plugin '{plugin_id}': {e}") + logger.error( + f"Error getting event count from plugin '{plugin_id}': {e}" + ) raise HTTPException(status_code=500, detail=str(e)) - return {"count": 0, "event_type": event_type, "message": "No plugin with storage found"} + return { + "count": 0, + "event_type": event_type, + "message": "No plugin with storage found", + } @router.get("/plugins/events") @@ -102,7 +109,7 @@ async def get_test_plugin_events(event_type: Optional[str] = None): # Get events from first plugin with storage for plugin_id, plugin in plugin_router.plugins.items(): - if hasattr(plugin, 'storage') and plugin.storage: + if hasattr(plugin, "storage") and plugin.storage: try: if event_type: events = await plugin.storage.get_events_by_type(event_type) @@ -113,7 +120,7 @@ async def get_test_plugin_events(event_type: Optional[str] = None): "events": events, "count": len(events), "event_type": event_type, - "plugin_id": plugin_id + "plugin_id": plugin_id, } except Exception as e: logger.error(f"Error getting events from plugin '{plugin_id}': {e}") diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/user_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/user_routes.py index 12ed5c63..f66b5d39 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/user_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/user_routes.py @@ -24,13 +24,17 @@ async def get_users(current_user: User = Depends(current_superuser)): @router.post("") -async def create_user(user_data: UserCreate, current_user: User = 
Depends(current_superuser)): +async def create_user( + user_data: UserCreate, current_user: User = Depends(current_superuser) +): """Create a new user. Admin only.""" return await user_controller.create_user(user_data) @router.put("/{user_id}") -async def update_user(user_id: str, user_data: UserUpdate, current_user: User = Depends(current_superuser)): +async def update_user( + user_id: str, user_data: UserUpdate, current_user: User = Depends(current_superuser) +): """Update a user. Admin only.""" return await user_controller.update_user(user_id, user_data) @@ -43,4 +47,6 @@ async def delete_user( delete_memories: bool = False, ): """Delete a user and optionally their associated data. Admin only.""" - return await user_controller.delete_user(user_id, delete_conversations, delete_memories) + return await user_controller.delete_user( + user_id, delete_conversations, delete_memories + ) diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py index 4b244343..11a24f88 100644 --- a/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py +++ b/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py @@ -19,6 +19,7 @@ # Create router router = APIRouter(tags=["websocket"]) + @router.websocket("/ws") async def ws_endpoint( ws: WebSocket, @@ -42,11 +43,13 @@ async def ws_endpoint( codec = codec.lower() if codec not in ["pcm", "opus"]: logger.warning(f"Unsupported codec requested: {codec}") - await ws.close(code=1008, reason=f"Unsupported codec: {codec}. Supported: pcm, opus") + await ws.close( + code=1008, reason=f"Unsupported codec: {codec}. Supported: pcm, opus" + ) return # Route to appropriate handler if codec == "opus": await handle_omi_websocket(ws, token, device_name) else: - await handle_pcm_websocket(ws, token, device_name) \ No newline at end of file + await handle_pcm_websocket(ws, token, device_name) diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_service.py b/backends/advanced/src/advanced_omi_backend/services/audio_service.py index 992ede75..09914e24 100644 --- a/backends/advanced/src/advanced_omi_backend/services/audio_service.py +++ b/backends/advanced/src/advanced_omi_backend/services/audio_service.py @@ -42,9 +42,9 @@ def __init__(self, redis_url: Optional[str] = None): self.memory_events_stream = "memory:events" # Consumer group names (action verbs - what they DO) - self.audio_writer = "audio-file-writer" # Writes audio chunks to WAV files - self.memory_enqueuer = "memory-job-enqueuer" # Enqueues memory extraction jobs - self.event_listener = "event-listener" # Listens for completion events + self.audio_writer = "audio-file-writer" # Writes audio chunks to WAV files + self.memory_enqueuer = "memory-job-enqueuer" # Enqueues memory extraction jobs + self.event_listener = "event-listener" # Listens for completion events async def connect(self): """Connect to Redis with connection pooling.""" @@ -55,7 +55,7 @@ async def connect(self): max_connections=20, # Allow multiple concurrent operations socket_keepalive=True, socket_connect_timeout=5, - retry_on_timeout=True + retry_on_timeout=True, ) logger.info(f"Audio stream service connected to Redis at {self.redis_url}") @@ -84,7 +84,7 @@ async def publish_audio_chunk( user_email: str, audio_chunk: AudioChunk, audio_uuid: Optional[str] = None, - timestamp: Optional[int] = None + timestamp: Optional[int] = None, ) -> str: """ Publish audio chunk to Redis 
Stream. @@ -135,12 +135,11 @@ async def publish_audio_chunk( # Ensure consumer group exists for this stream try: await self.redis.xgroup_create( - stream_name, - self.audio_writer, - id="0", - mkstream=True + stream_name, self.audio_writer, id="0", mkstream=True + ) + audio_logger.debug( + f"Created consumer group {self.audio_writer} for {stream_name}" ) - audio_logger.debug(f"Created consumer group {self.audio_writer} for {stream_name}") except aioredis.ResponseError as e: if "BUSYGROUP" not in str(e): raise @@ -152,7 +151,7 @@ async def publish_transcript_event( audio_uuid: str, conversation_id: str, status: str, - error: Optional[str] = None + error: Optional[str] = None, ): """ Publish transcript completion event. @@ -189,7 +188,7 @@ async def publish_transcript_event( self.transcript_events_stream, self.memory_enqueuer, id="0", - mkstream=True + mkstream=True, ) except aioredis.ResponseError as e: if "BUSYGROUP" not in str(e): @@ -200,7 +199,7 @@ async def publish_memory_event( conversation_id: str, status: str, memory_count: int = 0, - error: Optional[str] = None + error: Optional[str] = None, ): """ Publish memory processing event. @@ -234,21 +233,14 @@ async def publish_memory_event( # Ensure consumer group exists try: await self.redis.xgroup_create( - self.memory_events_stream, - self.event_listener, - id="0", - mkstream=True + self.memory_events_stream, self.event_listener, id="0", mkstream=True ) except aioredis.ResponseError as e: if "BUSYGROUP" not in str(e): raise async def consume_audio_stream( - self, - consumer_name: str, - callback, - block_ms: int = 5000, - count: int = 10 + self, consumer_name: str, callback, block_ms: int = 5000, count: int = 10 ): """ Consume audio chunks from all client streams. @@ -288,7 +280,7 @@ async def consume_audio_stream( consumer_name, streams_dict, count=count, - block=block_ms + block=block_ms, ) for stream_name, stream_messages in messages: @@ -299,15 +291,13 @@ async def consume_audio_stream( # Acknowledge message await self.redis.xack( - stream_name, - self.audio_writer, - message_id + stream_name, self.audio_writer, message_id ) except Exception as e: logger.error( f"Error processing audio message {message_id.decode()}: {e}", - exc_info=True + exc_info=True, ) except Exception as e: diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py b/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py index 19b76874..4cfd5a90 100644 --- a/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py +++ b/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py @@ -49,8 +49,12 @@ async def get_session_results(self, session_id: str) -> list[dict]: "text": fields[b"text"].decode(), "confidence": float(fields[b"confidence"].decode()), "provider": fields[b"provider"].decode(), - "chunk_id": fields.get(b"chunk_id", b"unknown").decode(), # Handle missing chunk_id gracefully - "processing_time": float(fields.get(b"processing_time", b"0.0").decode()), + "chunk_id": fields.get( + b"chunk_id", b"unknown" + ).decode(), # Handle missing chunk_id gracefully + "processing_time": float( + fields.get(b"processing_time", b"0.0").decode() + ), "timestamp": float(fields[b"timestamp"].decode()), } @@ -104,7 +108,7 @@ async def get_combined_results(self, session_id: str) -> dict: "segments": [], "chunk_count": 0, "total_confidence": 0.0, - "provider": None + "provider": None, } # Combine ALL final results for cumulative speech detection @@ -150,7 +154,7 @@ async def 
get_combined_results(self, session_id: str) -> dict: "segments": all_segments, "chunk_count": len(results), "total_confidence": avg_confidence, - "provider": provider + "provider": provider, } logger.info( @@ -162,10 +166,7 @@ async def get_combined_results(self, session_id: str) -> dict: return combined async def get_realtime_results( - self, - session_id: str, - last_id: str = "0", - timeout_ms: int = 1000 + self, session_id: str, last_id: str = "0", timeout_ms: int = 1000 ) -> tuple[list[dict], str]: """ Get new results since last_id (for real-time streaming). @@ -183,9 +184,7 @@ async def get_realtime_results( try: # Read new messages since last_id messages = await self.redis_client.xread( - {stream_name: last_id}, - count=10, - block=timeout_ms + {stream_name: last_id}, count=10, block=timeout_ms ) results = [] @@ -199,14 +198,18 @@ async def get_realtime_results( "text": fields[b"text"].decode(), "confidence": float(fields[b"confidence"].decode()), "provider": fields[b"provider"].decode(), - "chunk_id": fields.get(b"chunk_id", b"unknown").decode(), # Handle missing chunk_id gracefully + "chunk_id": fields.get( + b"chunk_id", b"unknown" + ).decode(), # Handle missing chunk_id gracefully } # Optional fields if b"words" in fields: result["words"] = json.loads(fields[b"words"].decode()) if b"segments" in fields: - result["segments"] = json.loads(fields[b"segments"].decode()) + result["segments"] = json.loads( + fields[b"segments"].decode() + ) results.append(result) new_last_id = message_id.decode() @@ -214,5 +217,7 @@ async def get_realtime_results( return results, new_last_id except Exception as e: - logger.error(f"πŸ”„ Error getting realtime results for session {session_id}: {e}") + logger.error( + f"πŸ”„ Error getting realtime results for session {session_id}: {e}" + ) return [], last_id diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py b/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py index 455ebebe..8902ae5d 100644 --- a/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py +++ b/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py @@ -23,7 +23,9 @@ class BaseAudioStreamConsumer(ABC): Writes results to transcription:results:{session_id}. """ - def __init__(self, provider_name: str, redis_client: redis.Redis, buffer_chunks: int = 30): + def __init__( + self, provider_name: str, redis_client: redis.Redis, buffer_chunks: int = 30 + ): """ Initialize consumer. 
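`get_realtime_results` returns the new entries plus an updated cursor, so a caller can tail a session's `transcription:results:{session_id}` stream by feeding `new_last_id` back in. A minimal polling sketch, assuming `aggregator` is an already-initialized instance of the aggregator class shown above:

```python
import asyncio


async def tail_session_results(aggregator, session_id: str) -> None:
    """Tail a session's transcription results using the (results, last_id) cursor."""
    last_id = "0"  # "0" means: start from the beginning of the stream
    while True:
        results, last_id = await aggregator.get_realtime_results(
            session_id, last_id=last_id, timeout_ms=1000
        )
        for result in results:
            # Each entry carries text / confidence / provider / chunk_id, as decoded above.
            print(f"[{result['provider']}] {result['text']}")
        # xread already blocks for up to timeout_ms, so no extra sleep is strictly
        # needed; yield control anyway to cooperate with other tasks.
        await asyncio.sleep(0)
```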
@@ -50,7 +52,9 @@ def __init__(self, provider_name: str, redis_client: redis.Redis, buffer_chunks: self.active_streams = {} # {stream_name: True} # Buffering: accumulate chunks per session - self.session_buffers = {} # {session_id: {"chunks": [], "chunk_ids": [], "sample_rate": int}} + self.session_buffers = ( + {} + ) # {session_id: {"chunks": [], "chunk_ids": [], "sample_rate": int}} async def discover_streams(self) -> list[str]: """ @@ -67,7 +71,9 @@ async def discover_streams(self) -> list[str]: cursor, match=self.stream_pattern, count=100 ) if keys: - streams.extend([k.decode() if isinstance(k, bytes) else k for k in keys]) + streams.extend( + [k.decode() if isinstance(k, bytes) else k for k in keys] + ) return streams @@ -76,16 +82,17 @@ async def setup_consumer_group(self, stream_name: str): # Create consumer group (ignore error if already exists) try: await self.redis_client.xgroup_create( - stream_name, - self.group_name, - "0", - mkstream=True + stream_name, self.group_name, "0", mkstream=True + ) + logger.debug( + f"➑️ Created consumer group {self.group_name} for {stream_name}" ) - logger.debug(f"➑️ Created consumer group {self.group_name} for {stream_name}") except redis_exceptions.ResponseError as e: if "BUSYGROUP" not in str(e): raise - logger.debug(f"➑️ Consumer group {self.group_name} already exists for {stream_name}") + logger.debug( + f"➑️ Consumer group {self.group_name} already exists for {stream_name}" + ) async def cleanup_dead_consumers(self, idle_threshold_ms: int = 30000): """ @@ -100,7 +107,7 @@ async def cleanup_dead_consumers(self, idle_threshold_ms: int = 30000): try: # Get all consumers in the group consumers = await self.redis_client.execute_command( - 'XINFO', 'CONSUMERS', self.input_stream, self.group_name + "XINFO", "CONSUMERS", self.input_stream, self.group_name ) if not consumers: @@ -115,9 +122,13 @@ async def cleanup_dead_consumers(self, idle_threshold_ms: int = 30000): # Parse consumer fields (flat key-value pairs within each consumer) for j in range(0, len(consumer_info), 2): - if j+1 < len(consumer_info): - key = consumer_info[j].decode() if isinstance(consumer_info[j], bytes) else str(consumer_info[j]) - value = consumer_info[j+1] + if j + 1 < len(consumer_info): + key = ( + consumer_info[j].decode() + if isinstance(consumer_info[j], bytes) + else str(consumer_info[j]) + ) + value = consumer_info[j + 1] if isinstance(value, bytes): try: value = value.decode() @@ -142,11 +153,19 @@ async def cleanup_dead_consumers(self, idle_threshold_ms: int = 30000): if is_dead: # If consumer has pending messages, claim and ACK them first if consumer_pending > 0: - logger.info(f"πŸ”„ Claiming {consumer_pending} pending messages from dead consumer {consumer_name}") + logger.info( + f"πŸ”„ Claiming {consumer_pending} pending messages from dead consumer {consumer_name}" + ) try: pending_messages = await self.redis_client.execute_command( - 'XPENDING', self.input_stream, self.group_name, '-', '+', str(consumer_pending), consumer_name + "XPENDING", + self.input_stream, + self.group_name, + "-", + "+", + str(consumer_pending), + consumer_name, ) # Parse pending messages (groups of 4: msg_id, consumer, idle_ms, delivery_count) @@ -159,28 +178,49 @@ async def cleanup_dead_consumers(self, idle_threshold_ms: int = 30000): # Claim to ourselves and ACK immediately try: await self.redis_client.execute_command( - 'XCLAIM', self.input_stream, self.group_name, self.consumer_name, '0', msg_id + "XCLAIM", + self.input_stream, + self.group_name, + self.consumer_name, + "0", + 
msg_id, + ) + await self.redis_client.xack( + self.input_stream, self.group_name, msg_id ) - await self.redis_client.xack(self.input_stream, self.group_name, msg_id) claimed_count += 1 except Exception as claim_error: - logger.warning(f"Failed to claim/ack message {msg_id}: {claim_error}") + logger.warning( + f"Failed to claim/ack message {msg_id}: {claim_error}" + ) except Exception as pending_error: - logger.warning(f"Failed to process pending messages for {consumer_name}: {pending_error}") + logger.warning( + f"Failed to process pending messages for {consumer_name}: {pending_error}" + ) # Delete the dead consumer try: await self.redis_client.execute_command( - 'XGROUP', 'DELCONSUMER', self.input_stream, self.group_name, consumer_name + "XGROUP", + "DELCONSUMER", + self.input_stream, + self.group_name, + consumer_name, ) deleted_count += 1 - logger.info(f"🧹 Deleted dead consumer {consumer_name} (idle: {consumer_idle_ms}ms)") + logger.info( + f"🧹 Deleted dead consumer {consumer_name} (idle: {consumer_idle_ms}ms)" + ) except Exception as delete_error: - logger.warning(f"Failed to delete consumer {consumer_name}: {delete_error}") + logger.warning( + f"Failed to delete consumer {consumer_name}: {delete_error}" + ) if deleted_count > 0 or claimed_count > 0: - logger.info(f"βœ… Cleanup complete: deleted {deleted_count} dead consumers, claimed {claimed_count} pending messages") + logger.info( + f"βœ… Cleanup complete: deleted {deleted_count} dead consumers, claimed {claimed_count} pending messages" + ) except Exception as e: logger.error(f"❌ Failed to cleanup dead consumers: {e}", exc_info=True) @@ -204,7 +244,9 @@ async def transcribe_audio(self, audio_data: bytes, sample_rate: int) -> dict: async def start_consuming(self): """Discover and consume from multiple streams using Redis consumer groups.""" self.running = True - logger.info(f"➑️ Starting dynamic stream consumer: {self.consumer_name} (group: {self.group_name})") + logger.info( + f"➑️ Starting dynamic stream consumer: {self.consumer_name} (group: {self.group_name})" + ) last_discovery = 0 discovery_interval = 10 # Discover new streams every 10 seconds @@ -223,7 +265,9 @@ async def start_consuming(self): # Setup consumer group for this stream (no manual lock needed) await self.setup_consumer_group(stream_name) self.active_streams[stream_name] = True - logger.info(f"βœ… Now consuming from {stream_name} (group: {self.group_name})") + logger.info( + f"βœ… Now consuming from {stream_name} (group: {self.group_name})" + ) last_discovery = current_time @@ -241,14 +285,18 @@ async def start_consuming(self): self.consumer_name, streams_dict, count=1, - block=1000 # Block for 1 second + block=1000, # Block for 1 second ) if not messages: continue for stream_name, msgs in messages: - stream_name_str = stream_name.decode() if isinstance(stream_name, bytes) else stream_name + stream_name_str = ( + stream_name.decode() + if isinstance(stream_name, bytes) + else stream_name + ) for message_id, fields in msgs: await self.process_message(message_id, fields, stream_name_str) @@ -260,11 +308,15 @@ async def start_consuming(self): # Extract stream name from error message for stream_name in list(self.active_streams.keys()): if stream_name in error_msg: - logger.warning(f"➑️ [{self.consumer_name}] Stream {stream_name} was deleted, removing from active streams") + logger.warning( + f"➑️ [{self.consumer_name}] Stream {stream_name} was deleted, removing from active streams" + ) # Remove from active streams del self.active_streams[stream_name] - logger.info(f"➑️ 
[{self.consumer_name}] Removed {stream_name}, {len(self.active_streams)} streams remaining") + logger.info( + f"➑️ [{self.consumer_name}] Removed {stream_name}, {len(self.active_streams)} streams remaining" + ) break else: # Other ResponseError - log and continue @@ -273,7 +325,10 @@ async def start_consuming(self): await asyncio.sleep(1) except Exception as e: - logger.error(f"➑️ [{self.consumer_name}] Error in dynamic consume loop: {e}", exc_info=True) + logger.error( + f"➑️ [{self.consumer_name}] Error in dynamic consume loop: {e}", + exc_info=True, + ) await asyncio.sleep(1) async def process_message(self, message_id: bytes, fields: dict, stream_name: str): @@ -295,7 +350,9 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st # Check for end-of-session signal if chunk_id == "END": - logger.info(f"➑️ [{self.consumer_name}] {self.provider_name}: Received END signal for session {session_id}") + logger.info( + f"➑️ [{self.consumer_name}] {self.provider_name}: Received END signal for session {session_id}" + ) # Flush buffer for this session if it has any chunks if session_id in self.session_buffers: @@ -306,7 +363,9 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st # Combine buffered chunks combined_audio = b"".join(buffer["chunks"]) - combined_chunk_id = f"{buffer['chunk_ids'][0]}-{buffer['chunk_ids'][-1]}" + combined_chunk_id = ( + f"{buffer['chunk_ids'][0]}-{buffer['chunk_ids'][-1]}" + ) logger.info( f"➑️ [{self.consumer_name}] {self.provider_name}: Flushing {len(buffer['chunks'])} remaining chunks " @@ -314,7 +373,9 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st ) # Transcribe remaining audio - result = await self.transcribe_audio(combined_audio, buffer["sample_rate"]) + result = await self.transcribe_audio( + combined_audio, buffer["sample_rate"] + ) # Store result processing_time = time.time() - start_time @@ -325,19 +386,27 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st confidence=result.get("confidence", 0.0), words=result.get("words", []), segments=result.get("segments", []), - processing_time=processing_time + processing_time=processing_time, ) # ACK all buffered messages for msg_id in buffer["message_ids"]: - await self.redis_client.xack(stream_name, self.group_name, msg_id) + await self.redis_client.xack( + stream_name, self.group_name, msg_id + ) # Trim stream to remove ACKed messages (keep only last 1000 for safety) try: - await self.redis_client.xtrim(stream_name, maxlen=1000, approximate=True) - logger.debug(f"🧹 Trimmed audio stream {stream_name} to max 1000 entries") + await self.redis_client.xtrim( + stream_name, maxlen=1000, approximate=True + ) + logger.debug( + f"🧹 Trimmed audio stream {stream_name} to max 1000 entries" + ) except Exception as trim_error: - logger.warning(f"Failed to trim stream {stream_name}: {trim_error}") + logger.warning( + f"Failed to trim stream {stream_name}: {trim_error}" + ) logger.info( f"➑️ [{self.consumer_name}] {self.provider_name}: Flushed buffer for session {session_id} " @@ -358,7 +427,7 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st "chunk_ids": [], "sample_rate": sample_rate, "message_ids": [], - "audio_offset_seconds": 0.0 # Track cumulative audio duration + "audio_offset_seconds": 0.0, # Track cumulative audio duration } # Add to buffer (skip empty audio data from END signals) @@ -382,14 +451,24 @@ async def process_message(self, message_id: bytes, fields: 
dict, stream_name: st # Combine buffered chunks combined_audio = b"".join(buffer["chunks"]) - combined_chunk_id = f"{buffer['chunk_ids'][0]}-{buffer['chunk_ids'][-1]}" + combined_chunk_id = ( + f"{buffer['chunk_ids'][0]}-{buffer['chunk_ids'][-1]}" + ) # Calculate audio duration for this chunk (16-bit PCM, 1 channel) - audio_duration_seconds = len(combined_audio) / (sample_rate * 2) # 2 bytes per sample + audio_duration_seconds = len(combined_audio) / ( + sample_rate * 2 + ) # 2 bytes per sample audio_offset = buffer["audio_offset_seconds"] # Log individual chunk IDs to detect duplicates - chunk_list = ", ".join(buffer['chunk_ids'][:5] + ['...'] + buffer['chunk_ids'][-5:]) if len(buffer['chunk_ids']) > 10 else ", ".join(buffer['chunk_ids']) + chunk_list = ( + ", ".join( + buffer["chunk_ids"][:5] + ["..."] + buffer["chunk_ids"][-5:] + ) + if len(buffer["chunk_ids"]) > 10 + else ", ".join(buffer["chunk_ids"]) + ) logger.info( f"➑️ [{self.consumer_name}] {self.provider_name}: Transcribing {len(buffer['chunks'])} chunks " @@ -415,7 +494,9 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st adjusted_word["end"] = word.get("end", 0.0) + audio_offset adjusted_words.append(adjusted_word) - logger.debug(f"➑️ [{self.consumer_name}] Adjusted {len(adjusted_segments)} segments by +{audio_offset:.1f}s") + logger.debug( + f"➑️ [{self.consumer_name}] Adjusted {len(adjusted_segments)} segments by +{audio_offset:.1f}s" + ) # Store result with adjusted timestamps processing_time = time.time() - start_time @@ -426,7 +507,7 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st confidence=result.get("confidence", 0.0), words=adjusted_words, segments=adjusted_segments, - processing_time=processing_time + processing_time=processing_time, ) # Update audio offset for next chunk @@ -438,8 +519,12 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st # Trim stream to remove ACKed messages (keep only last 1000 for safety) try: - await self.redis_client.xtrim(stream_name, maxlen=1000, approximate=True) - logger.debug(f"🧹 Trimmed audio stream {stream_name} to max 1000 entries") + await self.redis_client.xtrim( + stream_name, maxlen=1000, approximate=True + ) + logger.debug( + f"🧹 Trimmed audio stream {stream_name} to max 1000 entries" + ) except Exception as trim_error: logger.warning(f"Failed to trim stream {stream_name}: {trim_error}") @@ -456,7 +541,7 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st except Exception as e: logger.error( f"➑️ [{self.consumer_name}] {self.provider_name}: Failed to process chunk {fields.get(b'chunk_id', b'unknown').decode()}: {e}", - exc_info=True + exc_info=True, ) async def store_result( @@ -467,7 +552,7 @@ async def store_result( confidence: float, words: list, segments: list, - processing_time: float + processing_time: float, ): """ Store transcription result in Redis Stream. 
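The buffering arithmetic above is easy to check by hand: 16-bit mono PCM is 2 bytes per sample, so duration is `bytes / (sample_rate * 2)`, and each flushed buffer advances `audio_offset_seconds` by that amount before the next batch of segment timestamps is shifted. A small worked example with made-up numbers:

```python
# 5 seconds of 16 kHz, 16-bit, mono PCM: 16_000 samples/s * 5 s * 2 bytes = 160_000 bytes
sample_rate = 16_000
combined_audio = b"\x00\x00" * (16_000 * 5)

audio_duration_seconds = len(combined_audio) / (sample_rate * 2)
assert audio_duration_seconds == 5.0

# Suppose the previous flush already covered 12.5 s of this session's audio.
audio_offset = 12.5
segment = {"start": 1.2, "end": 3.4, "text": "hello"}  # provider-local timestamps

adjusted_segment = {
    **segment,
    "start": segment["start"] + audio_offset,  # 13.7 s into the session
    "end": segment["end"] + audio_offset,      # 15.9 s into the session
}

# The offset carried forward for the next flush becomes 12.5 + 5.0 = 17.5 s.
next_offset = audio_offset + audio_duration_seconds
assert next_offset == 17.5
```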
@@ -502,7 +587,7 @@ async def store_result( session_results_stream, result_data, maxlen=1000, # Keep max 1k results per session - approximate=True + approximate=True, ) logger.debug( diff --git a/backends/advanced/src/advanced_omi_backend/services/capabilities.py b/backends/advanced/src/advanced_omi_backend/services/capabilities.py index 85837920..fc9d7de0 100644 --- a/backends/advanced/src/advanced_omi_backend/services/capabilities.py +++ b/backends/advanced/src/advanced_omi_backend/services/capabilities.py @@ -26,7 +26,9 @@ class TranscriptCapability(str, Enum): WORD_TIMESTAMPS = "word_timestamps" # Word-level timing data SEGMENTS = "segments" # Speaker segments in output - DIARIZATION = "diarization" # Speaker labels in segments (Speaker 0, Speaker 1, etc.) + DIARIZATION = ( + "diarization" # Speaker labels in segments (Speaker 0, Speaker 1, etc.) + ) class FeatureRequirement(str, Enum): @@ -99,7 +101,9 @@ def check_requirements( return True, "OK" -def get_provider_capabilities(transcript_version: "Conversation.TranscriptVersion") -> dict: +def get_provider_capabilities( + transcript_version: "Conversation.TranscriptVersion", +) -> dict: """ Get provider capabilities from transcript version metadata. diff --git a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/entity_extractor.py b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/entity_extractor.py index a12cd540..2955b57f 100644 --- a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/entity_extractor.py +++ b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/entity_extractor.py @@ -161,35 +161,45 @@ def _parse_extraction_response(content: str) -> ExtractionResult: for e in data.get("entities", []): if isinstance(e, dict) and e.get("name"): entity_type = _normalize_entity_type(e.get("type", "thing")) - entities.append(ExtractedEntity( - name=e["name"].strip(), - type=entity_type, - details=e.get("details"), - icon=e.get("icon") or _get_default_icon(entity_type), - when=e.get("when"), - )) + entities.append( + ExtractedEntity( + name=e["name"].strip(), + type=entity_type, + details=e.get("details"), + icon=e.get("icon") or _get_default_icon(entity_type), + when=e.get("when"), + ) + ) # Parse relationships relationships = [] for r in data.get("relationships", []): if isinstance(r, dict) and r.get("subject") and r.get("object"): - relationships.append(ExtractedRelationship( - subject=r["subject"].strip(), - relation=_normalize_relation_type(r.get("relation", "related_to")), - object=r["object"].strip(), - )) + relationships.append( + ExtractedRelationship( + subject=r["subject"].strip(), + relation=_normalize_relation_type( + r.get("relation", "related_to") + ), + object=r["object"].strip(), + ) + ) # Parse promises promises = [] for p in data.get("promises", []): if isinstance(p, dict) and p.get("action"): - promises.append(ExtractedPromise( - action=p["action"].strip(), - to=p.get("to"), - deadline=p.get("deadline"), - )) - - logger.info(f"Extracted {len(entities)} entities, {len(relationships)} relationships, {len(promises)} promises") + promises.append( + ExtractedPromise( + action=p["action"].strip(), + to=p.get("to"), + deadline=p.get("deadline"), + ) + ) + + logger.info( + f"Extracted {len(entities)} entities, {len(relationships)} relationships, {len(promises)} promises" + ) return ExtractionResult( entities=entities, relationships=relationships, @@ -274,7 +284,9 @@ def _get_default_icon(entity_type: str) -> str: return icons.get(entity_type, "πŸ“Œ") -def 
parse_natural_datetime(text: Optional[str], reference_date: Optional[datetime] = None) -> Optional[datetime]: +def parse_natural_datetime( + text: Optional[str], reference_date: Optional[datetime] = None +) -> Optional[datetime]: """Parse natural language date/time into datetime. Args: @@ -292,9 +304,20 @@ def parse_natural_datetime(text: Optional[str], reference_date: Optional[datetim # Simple patterns - can be extended with dateparser library later weekdays = { - "monday": 0, "tuesday": 1, "wednesday": 2, "thursday": 3, - "friday": 4, "saturday": 5, "sunday": 6, - "mon": 0, "tue": 1, "wed": 2, "thu": 3, "fri": 4, "sat": 5, "sun": 6, + "monday": 0, + "tuesday": 1, + "wednesday": 2, + "thursday": 3, + "friday": 4, + "saturday": 5, + "sunday": 6, + "mon": 0, + "tue": 1, + "wed": 2, + "thu": 3, + "fri": 4, + "sat": 5, + "sun": 6, } try: @@ -318,6 +341,7 @@ def parse_natural_datetime(text: Optional[str], reference_date: Optional[datetim # Handle "in X days/weeks" import re + match = re.search(r"in (\d+) (day|week|month)s?", text_lower) if match: num = int(match.group(1)) diff --git a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/models.py b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/models.py index c4cf533c..0b78ba73 100644 --- a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/models.py +++ b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/models.py @@ -14,6 +14,7 @@ class EntityType(str, Enum): """Supported entity types in the knowledge graph.""" + PERSON = "person" PLACE = "place" ORGANIZATION = "organization" @@ -26,6 +27,7 @@ class EntityType(str, Enum): class RelationshipType(str, Enum): """Supported relationship types between entities.""" + MENTIONED_IN = "MENTIONED_IN" WORKS_AT = "WORKS_AT" LIVES_IN = "LIVES_IN" @@ -41,6 +43,7 @@ class RelationshipType(str, Enum): class PromiseStatus(str, Enum): """Status of a promise/task.""" + PENDING = "pending" IN_PROGRESS = "in_progress" COMPLETED = "completed" @@ -55,6 +58,7 @@ class Entity(BaseModel): Each entity belongs to a specific user and can have relationships with other entities. """ + id: str = Field(default_factory=lambda: str(uuid.uuid4())) name: str type: EntityType @@ -113,6 +117,7 @@ class Relationship(BaseModel): Relationships are edges in the graph connecting entities with typed connections (e.g., WORKS_AT, KNOWS, MENTIONED_IN). """ + id: str = Field(default_factory=lambda: str(uuid.uuid4())) type: RelationshipType source_id: str # Entity ID @@ -164,6 +169,7 @@ class Promise(BaseModel): Promises are commitments made during conversations that can be tracked and followed up on. 
""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) user_id: str action: str # What was promised @@ -191,7 +197,9 @@ def to_dict(self) -> Dict[str, Any]: "to_entity_name": self.to_entity_name, "status": self.status, "due_date": self.due_date.isoformat() if self.due_date else None, - "completed_at": self.completed_at.isoformat() if self.completed_at else None, + "completed_at": ( + self.completed_at.isoformat() if self.completed_at else None + ), "source_conversation_id": self.source_conversation_id, "context": self.context, "metadata": self.metadata, @@ -202,6 +210,7 @@ def to_dict(self) -> Dict[str, Any]: class ExtractedEntity(BaseModel): """Entity as extracted by LLM before Neo4j storage.""" + name: str type: str # Will be validated against EntityType details: Optional[str] = None @@ -212,6 +221,7 @@ class ExtractedEntity(BaseModel): class ExtractedRelationship(BaseModel): """Relationship as extracted by LLM.""" + subject: str # Entity name or "speaker" relation: str # Relationship type object: str # Entity name @@ -219,6 +229,7 @@ class ExtractedRelationship(BaseModel): class ExtractedPromise(BaseModel): """Promise as extracted by LLM.""" + action: str to: Optional[str] = None # Entity name deadline: Optional[str] = None # Natural language deadline @@ -226,6 +237,7 @@ class ExtractedPromise(BaseModel): class ExtractionResult(BaseModel): """Result of entity extraction from a conversation.""" + entities: List[ExtractedEntity] = Field(default_factory=list) relationships: List[ExtractedRelationship] = Field(default_factory=list) promises: List[ExtractedPromise] = Field(default_factory=list) diff --git a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py index 5f13508d..0c7bd4a8 100644 --- a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py +++ b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py @@ -126,7 +126,9 @@ async def process_conversation( ) if not extraction.entities and not extraction.promises: - logger.debug(f"No entities extracted from conversation {conversation_id}") + logger.debug( + f"No entities extracted from conversation {conversation_id}" + ) return {"entities": 0, "relationships": 0, "promises": 0} # Create conversation entity node @@ -827,6 +829,7 @@ def _parse_metadata(self, value: Any) -> Dict[str, Any]: if isinstance(value, str): try: import json + return json.loads(value) except (json.JSONDecodeError, ValueError): return {} diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/README.md b/backends/advanced/src/advanced_omi_backend/services/memory/README.md index ba6de6a4..49b74bd0 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/README.md +++ b/backends/advanced/src/advanced_omi_backend/services/memory/README.md @@ -28,50 +28,50 @@ memory/ graph TB %% User Interface Layer User[User/Application] --> CompatService[compat_service.py
MemoryService] - + %% Compatibility Layer CompatService --> CoreService[memory_service.py
CoreMemoryService] - + %% Configuration Config[config.py
MemoryConfig] --> CoreService Config --> LLMProviders Config --> VectorStores - + %% Core Service Layer CoreService --> Base[base.py
Abstract Interfaces] - + %% Base Abstractions Base --> MemoryServiceBase[MemoryServiceBase] - Base --> LLMProviderBase[LLMProviderBase] + Base --> LLMProviderBase[LLMProviderBase] Base --> VectorStoreBase[VectorStoreBase] Base --> MemoryEntry[MemoryEntry
Data Structure] - + %% Provider Implementations subgraph LLMProviders[LLM Providers] OpenAI[OpenAIProvider] Ollama[OllamaProvider] Qwen[Qwen3EmbeddingProvider] end - + subgraph VectorStores[Vector Stores] Qdrant[QdrantVectorStore] end - + %% Inheritance relationships LLMProviderBase -.-> OpenAI LLMProviderBase -.-> Ollama VectorStoreBase -.-> Qdrant - + %% Core Service uses providers CoreService --> LLMProviders CoreService --> VectorStores - + %% External Services OpenAI --> OpenAIAPI[OpenAI API] Ollama --> OllamaAPI[Ollama Server] Qwen --> LocalModel[Local Qwen Model] Qdrant --> QdrantDB[(Qdrant Database)] - + %% Memory Flow subgraph MemoryFlow[Memory Processing Flow] Transcript[Transcript] --> Extract[Extract Memories
via LLM] @@ -80,15 +80,15 @@ graph TB Store --> Search[Semantic Search] Search --> Results[Memory Results] end - + CoreService --> MemoryFlow - + %% Styling classDef interface fill:#e1f5fe,stroke:#01579b,stroke-width:2px classDef implementation fill:#f3e5f5,stroke:#4a148c,stroke-width:2px classDef external fill:#fff3e0,stroke:#e65100,stroke-width:2px classDef data fill:#e8f5e8,stroke:#2e7d32,stroke-width:2px - + class Base,MemoryServiceBase,LLMProviderBase,VectorStoreBase interface class CompatService,CoreService,OpenAI,Ollama,Qdrant implementation class OpenAIAPI,OllamaAPI,LocalModel,QdrantDB external @@ -109,7 +109,7 @@ classDiagram +string created_at +__post_init__() } - + class MemoryServiceBase { <> +initialize() Promise~void~ @@ -121,7 +121,7 @@ classDiagram +test_connection() Promise~bool~ +shutdown() void } - + class LLMProviderBase { <> +extract_memories(text, prompt) Promise~List~string~~ @@ -129,7 +129,7 @@ classDiagram +propose_memory_actions(old_memory, new_facts, custom_prompt) Promise~Dict~ +test_connection() Promise~bool~ } - + class VectorStoreBase { <> +initialize() Promise~void~ @@ -141,7 +141,7 @@ classDiagram +delete_user_memories(user_id) Promise~int~ +test_connection() Promise~bool~ } - + %% Configuration Classes class MemoryConfig { +LLMProvider llm_provider @@ -153,7 +153,7 @@ classDiagram +bool extraction_enabled +int timeout_seconds } - + %% Core Implementation class CoreMemoryService { -MemoryConfig config @@ -174,7 +174,7 @@ classDiagram -_normalize_actions() List~dict~ -_apply_memory_actions() Promise~List~string~~ } - + %% Compatibility Layer class CompatMemoryService { -CoreMemoryService _service @@ -190,7 +190,7 @@ classDiagram +test_connection() Promise~bool~ +shutdown() void } - + %% LLM Provider Implementations class OpenAIProvider { -string api_key @@ -204,7 +204,7 @@ classDiagram +propose_memory_actions() Promise~Dict~ +test_connection() Promise~bool~ } - + class OllamaProvider { -string base_url -string model @@ -217,7 +217,7 @@ classDiagram +propose_memory_actions() Promise~Dict~ +test_connection() Promise~bool~ } - + %% Vector Store Implementation class QdrantVectorStore { -string host @@ -234,20 +234,20 @@ classDiagram +delete_user_memories() Promise~int~ +test_connection() Promise~bool~ } - + %% Relationships MemoryServiceBase <|-- CoreMemoryService : implements LLMProviderBase <|-- OpenAIProvider : implements LLMProviderBase <|-- OllamaProvider : implements VectorStoreBase <|-- QdrantVectorStore : implements - + CoreMemoryService --> MemoryConfig : uses CoreMemoryService --> LLMProviderBase : uses CoreMemoryService --> VectorStoreBase : uses CoreMemoryService --> MemoryEntry : creates - + CompatMemoryService --> CoreMemoryService : wraps - + OpenAIProvider --> MemoryEntry : creates OllamaProvider --> MemoryEntry : creates QdrantVectorStore --> MemoryEntry : stores @@ -263,7 +263,7 @@ sequenceDiagram participant LLM as LLM Provider
(OpenAI/Ollama) participant Vector as Vector Store
(Qdrant) participant Config as Configuration - + Note over User, Config: Memory Service Initialization User->>Compat: get_memory_service() Compat->>Core: __init__(config) @@ -275,23 +275,23 @@ sequenceDiagram Vector-->>Core: ready LLM-->>Core: ready Core-->>Compat: initialized - + Note over User, Config: Adding Memory from Transcript User->>Compat: add_memory(transcript, ...) Compat->>Core: add_memory(transcript, ...) - + Core->>Core: _deduplicate_memories() Core->>LLM: generate_embeddings(memory_texts) LLM->>LLM: create vector embeddings LLM-->>Core: List[embeddings] - + alt Memory Updates Enabled Core->>Vector: search_memories(embeddings, user_id) Vector-->>Core: existing_memories Core->>LLM: propose_memory_actions(old, new) LLM->>LLM: decide ADD/UPDATE/DELETE LLM-->>Core: actions_list - + loop For each action alt Action: ADD Core->>Core: create MemoryEntry @@ -306,11 +306,11 @@ sequenceDiagram Core->>Core: create MemoryEntry objects Core->>Vector: add_memories(entries) end - + Vector-->>Core: created_ids Core-->>Compat: success, memory_ids Compat-->>User: success, memory_ids - + Note over User, Config: Searching Memories User->>Compat: search_memories(query, user_id) Compat->>Core: search_memories(query, user_id) @@ -380,7 +380,7 @@ await memory_service.initialize() success, memory_ids = await memory_service.add_memory( transcript="User discussed their goals for the next quarter.", client_id="client123", - audio_uuid="audio456", + audio_uuid="audio456", user_id="user789", user_email="user@example.com" ) @@ -412,7 +412,7 @@ success, memory_ids = await service.add_memory( transcript="My favorite destination is now Tokyo instead of Paris.", client_id="client123", audio_uuid="audio456", - user_id="user789", + user_id="user789", user_email="user@example.com", allow_update=True # Enable intelligent memory updates ) @@ -427,11 +427,11 @@ class CustomLLMProvider(LLMProviderBase): async def extract_memories(self, text: str, prompt: str) -> List[str]: # Custom implementation pass - + async def generate_embeddings(self, texts: List[str]) -> List[List[float]]: - # Custom implementation + # Custom implementation pass - + # ... implement other abstract methods ``` @@ -464,7 +464,7 @@ MEMORY_TIMEOUT_SECONDS=30 ```python from advanced_omi_backend.memory.config import ( MemoryConfig, - create_openai_config, + create_openai_config, create_qdrant_config ) @@ -493,7 +493,7 @@ The service supports intelligent memory updates through LLM-driven action propos - **ADD**: Create new memories for novel information - **UPDATE**: Modify existing memories with new details -- **DELETE**: Remove outdated or incorrect memories +- **DELETE**: Remove outdated or incorrect memories - **NONE**: No action needed for redundant information ### Example Flow @@ -594,7 +594,7 @@ The service includes comprehensive error handling: The modular architecture makes it easy to: 1. **Add new LLM providers**: Inherit from `LLMProviderBase` -2. **Add new vector stores**: Inherit from `VectorStoreBase` +2. **Add new vector stores**: Inherit from `VectorStoreBase` 3. **Customize memory logic**: Override `MemoryServiceBase` methods 4. 
**Add new data formats**: Extend `MemoryEntry` or conversion logic @@ -618,7 +618,7 @@ logging.getLogger("memory_service").setLevel(logging.INFO) Log levels: - **INFO**: Service lifecycle, major operations -- **DEBUG**: Detailed processing information +- **DEBUG**: Detailed processing information - **WARNING**: Recoverable errors, fallbacks - **ERROR**: Serious errors requiring attention @@ -694,15 +694,15 @@ class AnthropicProvider(LLMProviderBase): def __init__(self, config: Dict[str, Any]): self.api_key = config["api_key"] self.model = config.get("model", "claude-3-sonnet") - + async def extract_memories(self, text: str, prompt: str) -> List[str]: # Implement using Anthropic API pass - + async def generate_embeddings(self, texts: List[str]) -> List[List[float]]: # Implement using Anthropic embeddings pass - + # ... implement other abstract methods ``` @@ -781,15 +781,15 @@ class PineconeVectorStore(VectorStoreBase): self.api_key = config["api_key"] self.environment = config["environment"] self.index_name = config["index_name"] - + async def initialize(self) -> None: # Initialize Pinecone client pass - + async def add_memories(self, memories: List[MemoryEntry]) -> List[str]: # Add to Pinecone index pass - + # ... implement other abstract methods ``` @@ -831,10 +831,10 @@ class CustomMemoryService(CoreMemoryService): async def add_memory(self, transcript: str, **kwargs): # Pre-process transcript processed_transcript = await self._custom_preprocessing(transcript) - + # Call parent method return await super().add_memory(processed_transcript, **kwargs) - + async def _custom_preprocessing(self, transcript: str) -> str: # Your custom logic here return transcript @@ -909,4 +909,4 @@ Planned improvements: - Advanced memory summarization - Multi-modal memory support (images, audio) - Memory compression and archival -- Real-time memory streaming \ No newline at end of file +- Real-time memory streaming diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/base.py b/backends/advanced/src/advanced_omi_backend/services/memory/base.py index bae18e56..3eaf7259 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/base.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/base.py @@ -136,7 +136,9 @@ async def search_memories( pass @abstractmethod - async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryEntry]: + async def get_all_memories( + self, user_id: str, limit: int = 100 + ) -> List[MemoryEntry]: """Get all memories for a specific user. Args: @@ -265,7 +267,10 @@ async def reprocess_memory( @abstractmethod async def delete_memory( - self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None + self, + memory_id: str, + user_id: Optional[str] = None, + user_email: Optional[str] = None, ) -> bool: """Delete a specific memory by ID. @@ -341,7 +346,10 @@ class LLMProviderBase(ABC): @abstractmethod async def extract_memories( - self, text: str, prompt: str, user_id: Optional[str] = None, + self, + text: str, + prompt: str, + user_id: Optional[str] = None, langfuse_session_id: Optional[str] = None, ) -> List[str]: """Extract meaningful fact memories from text using an LLM. 
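For reference against the reformatted `LLMProviderBase.extract_memories` signature above, here is a hedged stub of a custom provider in the spirit of the README's extension examples. The import path and the note that the remaining abstract methods still need implementations are assumptions, not something this diff establishes:

```python
from typing import List, Optional

# Assumed module path, mirroring backends/advanced/src/.../services/memory/base.py
from advanced_omi_backend.services.memory.base import LLMProviderBase


class SentenceSplitProvider(LLMProviderBase):
    """Toy provider: treats each sentence as one extracted memory (no LLM call)."""

    async def extract_memories(
        self,
        text: str,
        prompt: str,
        user_id: Optional[str] = None,
        langfuse_session_id: Optional[str] = None,
    ) -> List[str]:
        # A real provider would send `prompt` plus `text` to an LLM and parse the
        # reply; user_id / langfuse_session_id are only tracing metadata here.
        return [part.strip() for part in text.split(".") if part.strip()]

    # generate_embeddings, propose_memory_actions, and test_connection from
    # LLMProviderBase would still have to be implemented before instantiation.
```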
@@ -469,7 +477,11 @@ async def add_memories(self, memories: List[MemoryEntry]) -> List[str]: @abstractmethod async def search_memories( - self, query_embedding: List[float], user_id: str, limit: int, score_threshold: float = 0.0 + self, + query_embedding: List[float], + user_id: str, + limit: int, + score_threshold: float = 0.0, ) -> List[MemoryEntry]: """Search memories using vector similarity. diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/config.py b/backends/advanced/src/advanced_omi_backend/services/memory/config.py index 83070706..3050487a 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/config.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/config.py @@ -137,7 +137,9 @@ def build_memory_config_from_env() -> MemoryConfig: # Map legacy provider names to current names if memory_provider in ("friend-lite", "friend_lite"): - memory_logger.info(f"πŸ”§ Mapping legacy provider '{memory_provider}' to 'chronicle'") + memory_logger.info( + f"πŸ”§ Mapping legacy provider '{memory_provider}' to 'chronicle'" + ) memory_provider = "chronicle" if memory_provider not in [p.value for p in MemoryProvider]: @@ -178,7 +180,9 @@ def build_memory_config_from_env() -> MemoryConfig: if not llm_def: raise ValueError("No default LLM defined in config.yml") model = llm_def.model_name - embedding_model = embed_def.model_name if embed_def else "text-embedding-3-small" + embedding_model = ( + embed_def.model_name if embed_def else "text-embedding-3-small" + ) base_url = llm_def.model_url memory_logger.info( f"πŸ”§ Memory config (registry): LLM={model}, Embedding={embedding_model}, Base URL={base_url}" @@ -201,7 +205,9 @@ def build_memory_config_from_env() -> MemoryConfig: host = str(vs_def.model_params.get("host", "qdrant")) port = int(vs_def.model_params.get("port", 6333)) - collection_name = str(vs_def.model_params.get("collection_name", "chronicle_memories")) + collection_name = str( + vs_def.model_params.get("collection_name", "chronicle_memories") + ) vector_store_config = create_qdrant_config( host=host, port=port, @@ -235,7 +241,9 @@ def build_memory_config_from_env() -> MemoryConfig: ) except ImportError: - memory_logger.warning("Config loader not available, using environment variables only") + memory_logger.warning( + "Config loader not available, using environment variables only" + ) raise diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py b/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py index 0e704be3..54c948b2 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py @@ -33,32 +33,32 @@ DEFAULT_UPDATE_MEMORY_PROMPT = f""" You are a memory manager for a system. -You must compare a list of **retrieved facts** with the **existing memory** (an array of `{{id, text}}` objects). -For each memory item, decide one of four operations: **ADD**, **UPDATE**, **DELETE**, or **NONE**. +You must compare a list of **retrieved facts** with the **existing memory** (an array of `{{id, text}}` objects). +For each memory item, decide one of four operations: **ADD**, **UPDATE**, **DELETE**, or **NONE**. Your output must follow the exact XML format described. --- ## Rules -1. **ADD**: +1. **ADD**: - If a retrieved fact is new (no existing memory on that topic), create a new `` with a new `id` (numeric, non-colliding). - Always include `` with the new fact. -2. 
**UPDATE**:
-   - If a retrieved fact replaces, contradicts, or refines an existing memory, update that memory instead of deleting and adding.
-   - Keep the same `id`.
-   - Always include `<text>` with the new fact.
-   - Always include `<old_memory>` with the previous memory text.
+2. **UPDATE**:
+   - If a retrieved fact replaces, contradicts, or refines an existing memory, update that memory instead of deleting and adding.
+   - Keep the same `id`.
+   - Always include `<text>` with the new fact.
+   - Always include `<old_memory>` with the previous memory text.
   - If multiple memories are about the same topic, update **all of them** to the new fact (consolidation).
-3. **DELETE**:
-   - Use only when a retrieved fact explicitly invalidates or negates a memory (e.g., β€œI no longer like pizza”).
-   - Keep the same `id`.
+3. **DELETE**:
+   - Use only when a retrieved fact explicitly invalidates or negates a memory (e.g., β€œI no longer like pizza”).
+   - Keep the same `id`.
   - Always include `<text>` with the old memory value so the XML remains well-formed.
-4. **NONE**:
-   - If the memory is unchanged and still valid.
-   - Keep the same `id`.
+4. **NONE**:
+   - If the memory is unchanged and still valid.
+   - Keep the same `id`.
   - Always include `<text>` with the existing value.
---
@@ -80,9 +80,9 @@
 ## Examples
 ### Example 1 (Preference Update)
-Old: `[{{"id": "0", "text": "My name is John"}}, {{"id": "1", "text": "My favorite fruit is oranges"}}]`
 Facts (each should be a separate XML item):
- 1. My favorite fruit is apple
+ 1. My favorite fruit is apple
 Output:
@@ -98,9 +98,9 @@
 ### Example 2 (Contradiction / Deletion)
-Old: `[{{"id": "0", "text": "I like pizza"}}]`
 Facts (each should be a separate XML item):
- 1. I no longer like pizza
+ 1. I no longer like pizza
 Output:
@@ -112,7 +112,7 @@
 ### Example 3 (Multiple New Facts)
-Old: `[{{"id": "0", "text": "I like hiking"}}]`
 Facts (each should be a separate XML item):
 1. I enjoy rug tufting
 2. I watch YouTube tutorials
@@ -139,9 +139,9 @@
 ---
 **Important constraints**:
-- Never output both DELETE and ADD for the same topic; use UPDATE instead.
-- Every `<item>` must contain `<text>`.
-- Only include `<old_memory>` for UPDATE events.
+- Never output both DELETE and ADD for the same topic; use UPDATE instead.
+- Every `<item>` must contain `<text>`.
+- Only include `<old_memory>` for UPDATE events.
 - Do not output any text outside `<memory>...</memory>`.
 """
@@ -314,81 +314,89 @@ def build_reprocess_speaker_messages(
 **Task Objective**: Scrape blog post titles and full content from the OpenAI blog.
 **Progress Status**: 10% complete β€” 5 out of 50 blog posts processed.
-1. **Agent Action**: Opened URL "https://openai.com"
-   **Action Result**:
-   "HTML Content of the homepage including navigation bar with links: 'Blog', 'API', 'ChatGPT', etc."
-   **Key Findings**: Navigation bar loaded correctly.
-   **Navigation History**: Visited homepage: "https://openai.com"
+1. **Agent Action**: Opened URL "https://openai.com"
+   **Action Result**:
+   "HTML Content of the homepage including navigation bar with links: 'Blog', 'API', 'ChatGPT', etc."
+   **Key Findings**: Navigation bar loaded correctly.
+   **Navigation History**: Visited homepage: "https://openai.com"
    **Current Context**: Homepage loaded; ready to click on the 'Blog' link.
-2. **Agent Action**: Clicked on the "Blog" link in the navigation bar.
- **Action Result**: - "Navigated to 'https://openai.com/blog/' with the blog listing fully rendered." - **Key Findings**: Blog listing shows 10 blog previews. - **Navigation History**: Transitioned from homepage to blog listing page. +2. **Agent Action**: Clicked on the "Blog" link in the navigation bar. + **Action Result**: + "Navigated to 'https://openai.com/blog/' with the blog listing fully rendered." + **Key Findings**: Blog listing shows 10 blog previews. + **Navigation History**: Transitioned from homepage to blog listing page. **Current Context**: Blog listing page displayed. -3. **Agent Action**: Extracted the first 5 blog post links from the blog listing page. - **Action Result**: - "[ '/blog/chatgpt-updates', '/blog/ai-and-education', '/blog/openai-api-announcement', '/blog/gpt-4-release', '/blog/safety-and-alignment' ]" - **Key Findings**: Identified 5 valid blog post URLs. +3. **Agent Action**: Extracted the first 5 blog post links from the blog listing page. + **Action Result**: + "[ '/blog/chatgpt-updates', '/blog/ai-and-education', '/blog/openai-api-announcement', '/blog/gpt-4-release', '/blog/safety-and-alignment' ]" + **Key Findings**: Identified 5 valid blog post URLs. **Current Context**: URLs stored in memory for further processing. -4. **Agent Action**: Visited URL "https://openai.com/blog/chatgpt-updates" - **Action Result**: - "HTML content loaded for the blog post including full article text." - **Key Findings**: Extracted blog title "ChatGPT Updates – March 2025" and article content excerpt. +4. **Agent Action**: Visited URL "https://openai.com/blog/chatgpt-updates" + **Action Result**: + "HTML content loaded for the blog post including full article text." + **Key Findings**: Extracted blog title "ChatGPT Updates – March 2025" and article content excerpt. **Current Context**: Blog post content extracted and stored. -5. **Agent Action**: Extracted blog title and full article content from "https://openai.com/blog/chatgpt-updates" - **Action Result**: - "{{ 'title': 'ChatGPT Updates – March 2025', 'content': 'We\'re introducing new updates to ChatGPT, including improved browsing capabilities and memory recall... (full content)' }}" - **Key Findings**: Full content captured for later summarization. +5. **Agent Action**: Extracted blog title and full article content from "https://openai.com/blog/chatgpt-updates" + **Action Result**: + "{{ 'title': 'ChatGPT Updates – March 2025', 'content': 'We\'re introducing new updates to ChatGPT, including improved browsing capabilities and memory recall... (full content)' }}" + **Key Findings**: Full content captured for later summarization. **Current Context**: Data stored; ready to proceed to next blog post. ... (Additional numbered steps for subsequent actions) ``` """ -def build_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None): - if custom_update_memory_prompt is None: + +def build_update_memory_messages( + retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None +): + if custom_update_memory_prompt is None: custom_update_memory_prompt = DEFAULT_UPDATE_MEMORY_PROMPT - - if not retrieved_old_memory_dict or len(retrieved_old_memory_dict) == 0: - retrieved_old_memory_dict = "None" - - # Format facts individually to encourage separate XML items - if isinstance(response_content, list) and len(response_content) > 1: - facts_str = "Facts (each should be a separate XML item):\n" - for i, fact in enumerate(response_content): - facts_str += f" {i+1}. 
{fact}\n" - facts_str = facts_str.strip() - else: - # Single fact or non-list, use original JSON format - facts_str = "Facts: " + json.dumps(response_content, ensure_ascii=False) - - prompt = ( - "Old: " + json.dumps(retrieved_old_memory_dict, ensure_ascii=False) + "\n" + - facts_str + "\n" + - "Output:" + + if not retrieved_old_memory_dict or len(retrieved_old_memory_dict) == 0: + retrieved_old_memory_dict = "None" + + # Format facts individually to encourage separate XML items + if isinstance(response_content, list) and len(response_content) > 1: + facts_str = "Facts (each should be a separate XML item):\n" + for i, fact in enumerate(response_content): + facts_str += f" {i+1}. {fact}\n" + facts_str = facts_str.strip() + else: + # Single fact or non-list, use original JSON format + facts_str = "Facts: " + json.dumps(response_content, ensure_ascii=False) + + prompt = ( + "Old: " + + json.dumps(retrieved_old_memory_dict, ensure_ascii=False) + + "\n" + + facts_str + + "\n" + + "Output:" ) - messages = [ + messages = [ {"role": "system", "content": custom_update_memory_prompt.strip()}, - {"role": "user", "content": prompt} + {"role": "user", "content": prompt}, ] - return messages + return messages -def get_update_memory_messages(retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None): +def get_update_memory_messages( + retrieved_old_memory_dict, response_content, custom_update_memory_prompt=None +): """ Generate a formatted message for the LLM to update memory with new facts. - + Args: retrieved_old_memory_dict: List of existing memory entries with id and text response_content: List of new facts to integrate custom_update_memory_prompt: Optional custom prompt to override default - + Returns: str: Formatted prompt for the LLM """ @@ -409,7 +417,7 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content, cust "event" : "ADD" }}, {{ - "id" : "1", + "id" : "1", "text" : "", "event" : "ADD" }} @@ -419,7 +427,7 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content, cust New facts to add: {response_content} -IMPORTANT: +IMPORTANT: - When memory is empty, ALL actions must be "ADD" events - Use sequential IDs starting from 0: "0", "1", "2", etc. 
- Return ONLY valid JSON with NO extra text or thinking @@ -470,23 +478,49 @@ def get_update_memory_messages(retrieved_old_memory_dict, response_content, cust # ===== Temporal and Entity Extraction ===== + class TimeRange(BaseModel): """Represents a time range with start and end timestamps.""" - start: datetime = Field(description="ISO 8601 timestamp when the event/activity starts") + + start: datetime = Field( + description="ISO 8601 timestamp when the event/activity starts" + ) end: datetime = Field(description="ISO 8601 timestamp when the event/activity ends") - name: Optional[str] = Field(default=None, description="Optional name/label for this time range (e.g., 'wedding ceremony', 'party')") + name: Optional[str] = Field( + default=None, + description="Optional name/label for this time range (e.g., 'wedding ceremony', 'party')", + ) class TemporalEntity(BaseModel): """Structured temporal and entity information extracted from a memory fact.""" - isEvent: bool = Field(description="Whether this memory describes a scheduled event or time-bound activity") - isPerson: bool = Field(description="Whether this memory is primarily about a person or people") - isPlace: bool = Field(description="Whether this memory is primarily about a location or place") - isPromise: bool = Field(description="Whether this memory contains a commitment, promise, or agreement") - isRelationship: bool = Field(description="Whether this memory describes a relationship between people") - entities: List[str] = Field(default_factory=list, description="List of people, places, or things mentioned (e.g., ['John', 'Botanical Gardens', 'wedding'])") - timeRanges: List[TimeRange] = Field(default_factory=list, description="List of time ranges if this is a temporal memory") - emoji: Optional[str] = Field(default=None, description="Single emoji that best represents this memory") + + isEvent: bool = Field( + description="Whether this memory describes a scheduled event or time-bound activity" + ) + isPerson: bool = Field( + description="Whether this memory is primarily about a person or people" + ) + isPlace: bool = Field( + description="Whether this memory is primarily about a location or place" + ) + isPromise: bool = Field( + description="Whether this memory contains a commitment, promise, or agreement" + ) + isRelationship: bool = Field( + description="Whether this memory describes a relationship between people" + ) + entities: List[str] = Field( + default_factory=list, + description="List of people, places, or things mentioned (e.g., ['John', 'Botanical Gardens', 'wedding'])", + ) + timeRanges: List[TimeRange] = Field( + default_factory=list, + description="List of time ranges if this is a temporal memory", + ) + emoji: Optional[str] = Field( + default=None, description="Single emoji that best represents this memory" + ) def build_temporal_extraction_prompt(current_date: datetime) -> str: diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py index d1f51775..89045222 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/chronicle.py @@ -69,10 +69,15 @@ async def initialize(self) -> None: try: # Initialize LLM provider - if self.config.llm_provider in [LLMProviderEnum.OPENAI, LLMProviderEnum.OLLAMA]: + if self.config.llm_provider in [ + LLMProviderEnum.OPENAI, + LLMProviderEnum.OLLAMA, + ]: 
self.llm_provider = OpenAIProvider(self.config.llm_config) else: - raise ValueError(f"Unsupported LLM provider: {self.config.llm_provider}") + raise ValueError( + f"Unsupported LLM provider: {self.config.llm_provider}" + ) # Initialize vector store if self.config.vector_store_provider == VectorStoreProvider.QDRANT: @@ -155,7 +160,9 @@ async def add_memory( if self.config.extraction_enabled and self.config.extraction_prompt: fact_memories_text = await asyncio.wait_for( self.llm_provider.extract_memories( - transcript, self.config.extraction_prompt, user_id=user_id, + transcript, + self.config.extraction_prompt, + user_id=user_id, langfuse_session_id=source_id, ), timeout=self.config.timeout_seconds, @@ -174,7 +181,9 @@ async def add_memory( memory_logger.debug(f"🧠 fact_memories_text: {fact_memories_text}") # Simple deduplication of extracted memories within the same call fact_memories_text = self._deduplicate_memories(fact_memories_text) - memory_logger.debug(f"🧠 fact_memories_text after deduplication: {fact_memories_text}") + memory_logger.debug( + f"🧠 fact_memories_text after deduplication: {fact_memories_text}" + ) # Generate embeddings embeddings = await asyncio.wait_for( self.llm_provider.generate_embeddings(fact_memories_text), @@ -194,14 +203,24 @@ async def add_memory( if allow_update and fact_memories_text: memory_logger.info(f"πŸ” Allowing update for {source_id}") created_ids = await self._process_memory_updates( - fact_memories_text, embeddings, user_id, client_id, source_id, user_email, + fact_memories_text, + embeddings, + user_id, + client_id, + source_id, + user_email, langfuse_session_id=source_id, ) else: memory_logger.info(f"πŸ” Not allowing update for {source_id}") # Add all extracted memories normally memory_entries = self._create_memory_entries( - fact_memories_text, embeddings, client_id, source_id, user_id, user_email + fact_memories_text, + embeddings, + client_id, + source_id, + user_id, + user_email, ) # Store new entries in vector database @@ -211,10 +230,14 @@ async def add_memory( # Update database relationships if helper provided if created_ids and db_helper: - await self._update_database_relationships(db_helper, source_id, created_ids) + await self._update_database_relationships( + db_helper, source_id, created_ids + ) if created_ids: - memory_logger.info(f"βœ… Upserted {len(created_ids)} memories for {source_id}") + memory_logger.info( + f"βœ… Upserted {len(created_ids)} memories for {source_id}" + ) return True, created_ids # No memories created - this is a valid outcome (duplicates, no extractable facts, etc.) @@ -271,7 +294,9 @@ async def search_memories( memory_logger.error(f"Search memories failed: {e}") return [] - async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryEntry]: + async def get_all_memories( + self, user_id: str, limit: int = 100 + ) -> List[MemoryEntry]: """Get all memories for a specific user. 
Retrieves all stored memories for the given user without @@ -289,7 +314,9 @@ async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryE try: memories = await self.vector_store.get_memories(user_id, limit) - memory_logger.info(f"πŸ“š Retrieved {len(memories)} memories for user {user_id}") + memory_logger.info( + f"πŸ“š Retrieved {len(memories)} memories for user {user_id}" + ) return memories except Exception as e: memory_logger.error(f"Get all memories failed: {e}") @@ -325,7 +352,9 @@ async def get_memories_by_source( await self.initialize() try: - memories = await self.vector_store.get_memories_by_source(user_id, source_id, limit) + memories = await self.vector_store.get_memories_by_source( + user_id, source_id, limit + ) memory_logger.info( f"πŸ“š Retrieved {len(memories)} memories for source {source_id} (user {user_id})" ) @@ -411,7 +440,9 @@ async def update_memory( new_embedding = existing_memory.embedding else: # No existing embedding, generate one - embeddings = await self.llm_provider.generate_embeddings([new_content]) + embeddings = await self.llm_provider.generate_embeddings( + [new_content] + ) new_embedding = embeddings[0] # Update in vector store @@ -430,11 +461,16 @@ async def update_memory( return success except Exception as e: - memory_logger.error(f"Error updating memory {memory_id}: {e}", exc_info=True) + memory_logger.error( + f"Error updating memory {memory_id}: {e}", exc_info=True + ) return False async def delete_memory( - self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None + self, + memory_id: str, + user_id: Optional[str] = None, + user_email: Optional[str] = None, ) -> bool: """Delete a specific memory by ID. @@ -544,7 +580,11 @@ async def reprocess_memory( f"falling back to normal extraction" ) return await self.add_memory( - transcript, client_id, source_id, user_id, user_email, + transcript, + client_id, + source_id, + user_id, + user_email, allow_update=True, ) @@ -555,7 +595,11 @@ async def reprocess_memory( f"falling back to normal extraction" ) return await self.add_memory( - transcript, client_id, source_id, user_id, user_email, + transcript, + client_id, + source_id, + user_id, + user_email, allow_update=True, ) @@ -582,22 +626,28 @@ async def reprocess_memory( new_transcript=transcript, langfuse_session_id=source_id, ) - memory_logger.info( - f"πŸ”„ Reprocess LLM returned actions: {actions_obj}" - ) + memory_logger.info(f"πŸ”„ Reprocess LLM returned actions: {actions_obj}") except NotImplementedError: memory_logger.warning( "LLM provider does not support propose_reprocess_actions, " "falling back to normal extraction" ) return await self.add_memory( - transcript, client_id, source_id, user_id, user_email, + transcript, + client_id, + source_id, + user_id, + user_email, allow_update=True, ) except Exception as e: memory_logger.error(f"Reprocess LLM call failed: {e}") return await self.add_memory( - transcript, client_id, source_id, user_id, user_email, + transcript, + client_id, + source_id, + user_id, + user_email, allow_update=True, ) @@ -645,15 +695,17 @@ async def reprocess_memory( return True, created_ids except Exception as e: - memory_logger.error( - f"❌ Reprocess memory failed for {source_id}: {e}" - ) + memory_logger.error(f"❌ Reprocess memory failed for {source_id}: {e}") # Fall back to normal extraction on any unexpected error memory_logger.info( f"πŸ”„ Falling back to normal extraction after reprocess error" ) return await self.add_memory( - transcript, client_id, source_id, user_id, 
user_email, + transcript, + client_id, + source_id, + user_id, + user_email, allow_update=True, ) @@ -825,7 +877,9 @@ async def _process_memory_updates( for mem in candidates: retrieved_old_memory.append({"id": mem.id, "text": mem.content}) except Exception as e_search: - memory_logger.warning(f"Search failed while preparing updates: {e_search}") + memory_logger.warning( + f"Search failed while preparing updates: {e_search}" + ) # Dedupe by id and prepare temp mapping uniq = {} @@ -845,7 +899,9 @@ async def _process_memory_updates( f"πŸ” Asking LLM for actions with {len(retrieved_old_memory)} old memories " f"and {len(memories_text)} new facts" ) - memory_logger.debug(f"🧠 Individual facts being sent to LLM: {memories_text}") + memory_logger.debug( + f"🧠 Individual facts being sent to LLM: {memories_text}" + ) # add update or delete etc actions using DEFAULT_UPDATE_MEMORY_PROMPT actions_obj = await self.llm_provider.propose_memory_actions( @@ -854,7 +910,9 @@ async def _process_memory_updates( custom_prompt=None, langfuse_session_id=langfuse_session_id, ) - memory_logger.info(f"πŸ“ UpdateMemory LLM returned: {type(actions_obj)} - {actions_obj}") + memory_logger.info( + f"πŸ“ UpdateMemory LLM returned: {type(actions_obj)} - {actions_obj}" + ) except Exception as e_actions: memory_logger.error(f"LLM propose_memory_actions failed: {e_actions}") actions_obj = {} @@ -891,7 +949,9 @@ def _normalize_actions(self, actions_obj: Any) -> List[dict]: if isinstance(memory_field, list): actions_list = memory_field elif isinstance(actions_obj.get("facts"), list): - actions_list = [{"event": "ADD", "text": str(t)} for t in actions_obj["facts"]] + actions_list = [ + {"event": "ADD", "text": str(t)} for t in actions_obj["facts"] + ] else: # Pick first list field found for v in actions_obj.values(): @@ -901,7 +961,9 @@ def _normalize_actions(self, actions_obj: Any) -> List[dict]: elif isinstance(actions_obj, list): actions_list = actions_obj - memory_logger.info(f"πŸ“‹ Normalized to {len(actions_list)} actions: {actions_list}") + memory_logger.info( + f"πŸ“‹ Normalized to {len(actions_list)} actions: {actions_list}" + ) except Exception as normalize_err: memory_logger.warning(f"Failed to normalize actions: {normalize_err}") actions_list = [] @@ -951,7 +1013,9 @@ async def _apply_memory_actions( memory_logger.warning(f"Skipping action with no text: {resp}") continue - memory_logger.debug(f"Processing action: {event_type} - {action_text[:50]}...") + memory_logger.debug( + f"Processing action: {event_type} - {action_text[:50]}..." + ) base_metadata = { "source": "offline_streaming", @@ -973,7 +1037,9 @@ async def _apply_memory_actions( ) emb = gen[0] if gen else None except Exception as gen_err: - memory_logger.warning(f"Embedding generation failed for action text: {gen_err}") + memory_logger.warning( + f"Embedding generation failed for action text: {gen_err}" + ) emb = None if event_type == "ADD": @@ -995,7 +1061,9 @@ async def _apply_memory_actions( updated_at=current_time, ) ) - memory_logger.info(f"βž• Added new memory: {memory_id} - {action_text[:50]}...") + memory_logger.info( + f"βž• Added new memory: {memory_id} - {action_text[:50]}..." + ) elif event_type == "UPDATE": provided_id = resp.get("id") @@ -1015,11 +1083,15 @@ async def _apply_memory_actions( f"πŸ”„ Updated memory: {actual_id} - {action_text[:50]}..." 
) else: - memory_logger.warning(f"Failed to update memory {actual_id}") + memory_logger.warning( + f"Failed to update memory {actual_id}" + ) except Exception as update_err: memory_logger.error(f"Update memory failed: {update_err}") else: - memory_logger.warning(f"Skipping UPDATE due to missing ID or embedding") + memory_logger.warning( + f"Skipping UPDATE due to missing ID or embedding" + ) elif event_type == "DELETE": provided_id = resp.get("id") @@ -1030,14 +1102,20 @@ async def _apply_memory_actions( if deleted: memory_logger.info(f"πŸ—‘οΈ Deleted memory {actual_id}") else: - memory_logger.warning(f"Failed to delete memory {actual_id}") + memory_logger.warning( + f"Failed to delete memory {actual_id}" + ) except Exception as delete_err: memory_logger.error(f"Delete memory failed: {delete_err}") else: - memory_logger.warning(f"Skipping DELETE due to missing ID: {provided_id}") + memory_logger.warning( + f"Skipping DELETE due to missing ID: {provided_id}" + ) elif event_type == "NONE": - memory_logger.debug(f"NONE action - no changes for: {action_text[:50]}...") + memory_logger.debug( + f"NONE action - no changes for: {action_text[:50]}..." + ) continue else: memory_logger.warning(f"Unknown event type: {event_type}") @@ -1100,7 +1178,9 @@ async def example_usage(): print(f"πŸ” Found {len(results)} search results") # Get all memories - all_memories = await memory_service.get_all_memories(user_id="user789", limit=100) + all_memories = await memory_service.get_all_memories( + user_id="user789", limit=100 + ) print(f"πŸ“š Total memories: {len(all_memories)}") # Clean up test data diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py index 2d83d24c..58b187e8 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/llm_providers.py @@ -15,7 +15,10 @@ from typing import Any, Dict, List, Optional from advanced_omi_backend.model_registry import ModelDef, get_models_registry -from advanced_omi_backend.openai_factory import create_openai_client, is_langfuse_enabled +from advanced_omi_backend.openai_factory import ( + create_openai_client, + is_langfuse_enabled, +) from advanced_omi_backend.prompt_registry import get_prompt_registry from ..base import LLMProviderBase @@ -77,6 +80,7 @@ async def generate_openai_embeddings( ) return [data.embedding for data in response.data] + # TODO: Re-enable spacy when Docker build is fixed # try: # nlp = spacy.load("en_core_web_sm") @@ -86,6 +90,7 @@ async def generate_openai_embeddings( # nlp = None nlp = None # Temporarily disabled + def chunk_text_with_spacy(text: str, max_tokens: int = 100) -> List[str]: """Split text into chunks using spaCy sentence segmentation. max_tokens is the maximum number of words in a chunk. @@ -93,14 +98,14 @@ def chunk_text_with_spacy(text: str, max_tokens: int = 100) -> List[str]: # Fallback chunking when spacy is not available if nlp is None: # Simple sentence-based chunking - sentences = text.replace('\n', ' ').split('. ') + sentences = text.replace("\n", " ").split(". 
") chunks = [] current_chunk = "" current_tokens = 0 - + for sentence in sentences: sentence_tokens = len(sentence.split()) - + if current_tokens + sentence_tokens > max_tokens and current_chunk: chunks.append(current_chunk.strip()) current_chunk = sentence @@ -111,23 +116,23 @@ def chunk_text_with_spacy(text: str, max_tokens: int = 100) -> List[str]: else: current_chunk = sentence current_tokens += sentence_tokens - + if current_chunk.strip(): chunks.append(current_chunk.strip()) - + return chunks if chunks else [text] - + # Original spacy implementation when available doc = nlp(text) - + chunks = [] current_chunk = "" current_tokens = 0 - + for sent in doc.sents: sent_text = sent.text.strip() sent_tokens = len(sent_text.split()) # Simple word count - + if current_tokens + sent_tokens > max_tokens and current_chunk: chunks.append(current_chunk.strip()) current_chunk = sent_text @@ -135,12 +140,13 @@ def chunk_text_with_spacy(text: str, max_tokens: int = 100) -> List[str]: else: current_chunk += " " + sent_text if current_chunk else sent_text current_tokens += sent_tokens - + if current_chunk.strip(): chunks.append(current_chunk.strip()) - + return chunks + class OpenAIProvider(LLMProviderBase): """Config-driven LLM provider using OpenAI SDK (OpenAI-compatible). @@ -153,7 +159,9 @@ def __init__(self, config: Dict[str, Any]): # Ignore provider-specific envs; use registry as single source of truth registry = get_models_registry() if not registry: - raise RuntimeError("config.yml not found or invalid; cannot initialize model registry") + raise RuntimeError( + "config.yml not found or invalid; cannot initialize model registry" + ) self._registry = registry @@ -170,9 +178,15 @@ def __init__(self, config: Dict[str, Any]): self.model = self.llm_def.model_name # Store parameters for embeddings (use separate config if available) - self.embedding_model = (self.embed_def.model_name if self.embed_def else self.llm_def.model_name) - self.embedding_api_key = (self.embed_def.api_key if self.embed_def else self.api_key) - self.embedding_base_url = (self.embed_def.model_url if self.embed_def else self.base_url) + self.embedding_model = ( + self.embed_def.model_name if self.embed_def else self.llm_def.model_name + ) + self.embedding_api_key = ( + self.embed_def.api_key if self.embed_def else self.api_key + ) + self.embedding_base_url = ( + self.embed_def.model_url if self.embed_def else self.base_url + ) # CRITICAL: Validate API keys are present - fail fast instead of hanging if not self.api_key or self.api_key.strip() == "": @@ -182,7 +196,9 @@ def __init__(self, config: Dict[str, Any]): f"Cannot proceed without valid API credentials." ) - if self.embed_def and (not self.embedding_api_key or self.embedding_api_key.strip() == ""): + if self.embed_def and ( + not self.embedding_api_key or self.embedding_api_key.strip() == "" + ): raise RuntimeError( f"API key is missing or empty for embedding provider '{self.embed_def.model_provider}' (model: {self.embedding_model}). " f"Please set the API key in config.yml or environment variables." @@ -192,7 +208,10 @@ def __init__(self, config: Dict[str, Any]): self._client = None async def extract_memories( - self, text: str, prompt: str, user_id: Optional[str] = None, + self, + text: str, + prompt: str, + user_id: Optional[str] = None, langfuse_session_id: Optional[str] = None, ) -> List[str]: """Extract memories using OpenAI API with the enhanced fact retrieval prompt. 
@@ -223,22 +242,30 @@ async def extract_memories( text_chunks = chunk_text_with_spacy(text) # Process all chunks in sequence, not concurrently - results = [await self._process_chunk(system_prompt, chunk, i, langfuse_session_id=langfuse_session_id) for i, chunk in enumerate(text_chunks)] - + results = [ + await self._process_chunk( + system_prompt, chunk, i, langfuse_session_id=langfuse_session_id + ) + for i, chunk in enumerate(text_chunks) + ] + # Spread list of list of facts into a single list of facts cleaned_facts = [] for result in results: memory_logger.info(f"Cleaned facts: {result}") cleaned_facts.extend(result) - + return cleaned_facts - + except Exception as e: memory_logger.error(f"OpenAI memory extraction failed: {e}") return [] - + async def _process_chunk( - self, system_prompt: str, chunk: str, index: int, + self, + system_prompt: str, + chunk: str, + index: int, langfuse_session_id: Optional[str] = None, ) -> List[str]: """Process a single text chunk to extract memories using OpenAI API. @@ -312,11 +339,15 @@ async def test_connection(self) -> bool: try: # Add 10-second timeout to prevent hanging on API calls async with asyncio.timeout(10): - client = _get_openai_client(api_key=self.api_key, base_url=self.base_url, is_async=True) + client = _get_openai_client( + api_key=self.api_key, base_url=self.base_url, is_async=True + ) await client.models.list() return True except asyncio.TimeoutError: - memory_logger.error(f"OpenAI connection test timed out after 10s - check network connectivity and API endpoint") + memory_logger.error( + f"OpenAI connection test timed out after 10s - check network connectivity and API endpoint" + ) return False except Exception as e: memory_logger.error(f"OpenAI connection test failed: {e}") @@ -344,11 +375,11 @@ async def propose_memory_actions( # Generate the complete prompt using the helper function memory_logger.debug(f"🧠 Facts passed to prompt builder: {new_facts}") update_memory_messages = build_update_memory_messages( - retrieved_old_memory, - new_facts, - custom_prompt + retrieved_old_memory, new_facts, custom_prompt + ) + memory_logger.debug( + f"🧠 Generated prompt user content: {update_memory_messages[1]['content'][:200]}..." ) - memory_logger.debug(f"🧠 Generated prompt user content: {update_memory_messages[1]['content'][:200]}...") op = self._registry.get_llm_operation("memory_update") client = op.get_client(is_async=True) @@ -374,7 +405,6 @@ async def propose_memory_actions( memory_logger.error(f"OpenAI propose_memory_actions failed: {e}") return {} - async def propose_reprocess_actions( self, existing_memories: List[Dict[str, str]], @@ -466,21 +496,23 @@ async def propose_reprocess_actions( class OllamaProvider(LLMProviderBase): """Ollama LLM provider implementation. - + Provides memory extraction, embedding generation, and memory action proposals using Ollama's GPT and embedding models. - - + + Use the openai provider for ollama with different environment variables - - os.environ["OPENAI_API_KEY"] = "ollama" + + os.environ["OPENAI_API_KEY"] = "ollama" os.environ["OPENAI_BASE_URL"] = "http://localhost:11434/v1" os.environ["QDRANT_BASE_URL"] = "localhost" os.environ["OPENAI_EMBEDDER_MODEL"] = "erwan2/DeepSeek-R1-Distill-Qwen-1.5B:latest" - + """ + pass + def _parse_memories_content(content: str) -> List[str]: """ Parse LLM content to extract memory strings. 
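The `OllamaProvider` note above amounts to reusing the OpenAI-compatible code path with different environment variables. A rough sketch of that wiring with the `openai` SDK follows (it assumes an Ollama server on the default port; the model name is only an example):

```python
import asyncio
import os

from openai import AsyncOpenAI

# Ollama exposes an OpenAI-compatible API; the key value is ignored but must be set.
os.environ.setdefault("OPENAI_API_KEY", "ollama")
os.environ.setdefault("OPENAI_BASE_URL", "http://localhost:11434/v1")

client = AsyncOpenAI(
    api_key=os.environ["OPENAI_API_KEY"],
    base_url=os.environ["OPENAI_BASE_URL"],
)


async def extract_one_fact(text: str) -> str:
    # Hypothetical model name; use whatever model has been pulled into Ollama.
    resp = await client.chat.completions.create(
        model="llama3.1",
        messages=[
            {"role": "system", "content": "Extract one short fact from the text."},
            {"role": "user", "content": text},
        ],
    )
    return resp.choices[0].message.content


if __name__ == "__main__":
    print(asyncio.run(extract_one_fact("My favorite fruit is apples.")))
```

With `OPENAI_BASE_URL` pointed at `http://localhost:11434/v1` and any non-empty `OPENAI_API_KEY`, the same client code serves both hosted OpenAI and local Ollama.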
diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py index 1a4e545f..6a9972af 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/mcp_client.py @@ -70,7 +70,9 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): await self.close() - async def add_memories(self, text: str, metadata: Dict[str, Any] = None) -> List[str]: + async def add_memories( + self, text: str, metadata: Dict[str, Any] = None + ) -> List[str]: """Add memories to the OpenMemory server. Uses the REST API to create memories. OpenMemory will handle: @@ -104,18 +106,28 @@ async def add_memories(self, text: str, metadata: Dict[str, Any] = None) -> List app_id = None if apps_data.get("apps"): # Find matching app by name, prefer one with most memories - matching = [a for a in apps_data["apps"] if a["name"] == self.client_name] - memory_logger.debug(f"Matching apps for '{self.client_name}': {matching}") + matching = [ + a for a in apps_data["apps"] if a["name"] == self.client_name + ] + memory_logger.debug( + f"Matching apps for '{self.client_name}': {matching}" + ) if matching: - matching.sort(key=lambda x: x.get("total_memories_created", 0), reverse=True) + matching.sort( + key=lambda x: x.get("total_memories_created", 0), reverse=True + ) app_id = matching[0]["id"] memory_logger.info(f"Found matching app with ID: {app_id}") else: app_id = apps_data["apps"][0]["id"] - memory_logger.info(f"No matching app name, using first app ID: {app_id}") + memory_logger.info( + f"No matching app name, using first app ID: {app_id}" + ) if not app_id: - memory_logger.error("No apps found in OpenMemory - cannot create memory") + memory_logger.error( + "No apps found in OpenMemory - cannot create memory" + ) raise MCPError("No apps found in OpenMemory") # Merge custom metadata with default metadata @@ -143,9 +155,13 @@ async def add_memories(self, text: str, metadata: Dict[str, Any] = None) -> List ) memory_logger.debug(f"Full payload: {payload}") - response = await self.client.post(f"{self.server_url}/api/v1/memories/", json=payload) + response = await self.client.post( + f"{self.server_url}/api/v1/memories/", json=payload + ) - response_body = response.text[:500] if response.status_code != 200 else "..." + response_body = ( + response.text[:500] if response.status_code != 200 else "..." 
+ ) memory_logger.info( f"OpenMemory response: status={response.status_code}, body={response_body}, headers={dict(response.headers)}" ) @@ -218,7 +234,12 @@ async def search_memory(self, query: str, limit: int = 10) -> List[Dict[str, Any app_id = apps_data["apps"][0]["id"] # Use app-specific memories endpoint with search - params = {"user_id": self.user_id, "search_query": query, "page": 1, "size": limit} + params = { + "user_id": self.user_id, + "search_query": query, + "page": 1, + "size": limit, + } response = await self.client.get( f"{self.server_url}/api/v1/apps/{app_id}/memories", params=params @@ -242,7 +263,8 @@ async def search_memory(self, query: str, limit: int = 10) -> List[Dict[str, Any { "id": memory.get("id", str(uuid.uuid4())), "content": memory.get("content", "") or memory.get("text", ""), - "metadata": memory.get("metadata_", {}) or memory.get("metadata", {}), + "metadata": memory.get("metadata_", {}) + or memory.get("metadata", {}), "created_at": memory.get("created_at"), "score": memory.get("score", 0.0), # No score from list API } @@ -311,7 +333,8 @@ async def list_memories(self, limit: int = 100) -> List[Dict[str, Any]]: { "id": memory.get("id", str(uuid.uuid4())), "content": memory.get("content", "") or memory.get("text", ""), - "metadata": memory.get("metadata_", {}) or memory.get("metadata", {}), + "metadata": memory.get("metadata_", {}) + or memory.get("metadata", {}), "created_at": memory.get("created_at"), } ) @@ -377,7 +400,8 @@ async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]: try: # Use the memories endpoint with specific ID response = await self.client.get( - f"{self.server_url}/api/v1/memories/{memory_id}", params={"user_id": self.user_id} + f"{self.server_url}/api/v1/memories/{memory_id}", + params={"user_id": self.user_id}, ) if response.status_code == 404: @@ -392,7 +416,8 @@ async def get_memory(self, memory_id: str) -> Optional[Dict[str, Any]]: return { "id": result.get("id", memory_id), "content": result.get("content", "") or result.get("text", ""), - "metadata": result.get("metadata_", {}) or result.get("metadata", {}), + "metadata": result.get("metadata_", {}) + or result.get("metadata", {}), "created_at": result.get("created_at"), } @@ -454,7 +479,10 @@ async def update_memory( return False async def delete_memory( - self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None + self, + memory_id: str, + user_id: Optional[str] = None, + user_email: Optional[str] = None, ) -> bool: """Delete a specific memory by ID. diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/mock_llm_provider.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/mock_llm_provider.py index 1405be3c..65738dae 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/mock_llm_provider.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/mock_llm_provider.py @@ -29,7 +29,10 @@ def __init__(self, config: Dict[str, Any] = None): self.embedding_dimension = 384 # Standard dimension for mock embeddings async def extract_memories( - self, text: str, prompt: str, user_id: Optional[str] = None, + self, + text: str, + prompt: str, + user_id: Optional[str] = None, ) -> List[str]: """ Return predefined mock memories extracted from text. 
@@ -114,12 +117,9 @@ async def propose_memory_actions( actions = [] for idx, fact in enumerate(new_facts): - actions.append({ - "id": str(idx), - "event": "ADD", - "text": fact, - "old_memory": None - }) + actions.append( + {"id": str(idx), "event": "ADD", "text": fact, "old_memory": None} + ) return {"memory": actions} diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py index ba8484b2..227b9b37 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/openmemory_mcp.py @@ -66,8 +66,12 @@ def __init__( timeout: HTTP request timeout in seconds """ super().__init__() - self.server_url = server_url or os.getenv("OPENMEMORY_MCP_URL", "http://localhost:8765") - self.client_name = client_name or os.getenv("OPENMEMORY_CLIENT_NAME", "chronicle") + self.server_url = server_url or os.getenv( + "OPENMEMORY_MCP_URL", "http://localhost:8765" + ) + self.client_name = client_name or os.getenv( + "OPENMEMORY_CLIENT_NAME", "chronicle" + ) self.user_id = user_id or os.getenv("OPENMEMORY_USER_ID", "default") self.timeout = int(timeout or os.getenv("OPENMEMORY_TIMEOUT", "30")) self.mcp_client: Optional[MCPClient] = None @@ -95,7 +99,9 @@ async def initialize(self) -> None: # Test connection to OpenMemory MCP server is_connected = await self.mcp_client.test_connection() if not is_connected: - raise RuntimeError(f"Cannot connect to OpenMemory MCP server at {self.server_url}") + raise RuntimeError( + f"Cannot connect to OpenMemory MCP server at {self.server_url}" + ) self._initialized = True memory_logger.info( @@ -148,7 +154,9 @@ async def add_memory( # Use configured OpenMemory user (from config) for all Chronicle users # Chronicle user_id and email are stored in metadata for filtering - enriched_transcript = f"[Source: {source_id}, Client: {client_id}] {transcript}" + enriched_transcript = ( + f"[Source: {source_id}, Client: {client_id}] {transcript}" + ) memory_logger.info( f"Delegating memory processing to OpenMemory for user {user_id}, source {source_id}" @@ -168,7 +176,9 @@ async def add_memory( # Update database relationships if helper provided if memory_ids and db_helper: - await self._update_database_relationships(db_helper, source_id, memory_ids) + await self._update_database_relationships( + db_helper, source_id, memory_ids + ) if memory_ids: memory_logger.info( @@ -186,7 +196,9 @@ async def add_memory( memory_logger.error(f"❌ OpenMemory MCP error for {source_id}: {e}") raise e except Exception as e: - memory_logger.error(f"❌ OpenMemory MCP service failed for {source_id}: {e}") + memory_logger.error( + f"❌ OpenMemory MCP service failed for {source_id}: {e}" + ) raise e async def search_memories( @@ -241,7 +253,9 @@ async def search_memories( memory_logger.error(f"Search memories failed: {e}") return [] - async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryEntry]: + async def get_all_memories( + self, user_id: str, limit: int = 100 + ) -> List[MemoryEntry]: """Get all memories for a specific user. 
Retrieves all stored memories for the given user without @@ -275,7 +289,9 @@ async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryE if len(memory_entries) >= limit: break # Got enough results - memory_logger.info(f"πŸ“š Retrieved {len(memory_entries)} memories for user {user_id}") + memory_logger.info( + f"πŸ“š Retrieved {len(memory_entries)} memories for user {user_id}" + ) return memory_entries except MCPError as e: @@ -338,7 +354,9 @@ async def get_memory( # Update MCP client user context for this operation original_user_id = self.mcp_client.user_id - self.mcp_client.user_id = user_id or self.user_id # Use the actual Chronicle user's ID + self.mcp_client.user_id = ( + user_id or self.user_id + ) # Use the actual Chronicle user's ID try: result = await self.mcp_client.get_memory(memory_id) @@ -348,7 +366,9 @@ async def get_memory( return None # Convert MCP result to MemoryEntry - memory_entry = self._mcp_result_to_memory_entry(result, user_id or self.user_id) + memory_entry = self._mcp_result_to_memory_entry( + result, user_id or self.user_id + ) if memory_entry: memory_logger.info(f"πŸ“– Retrieved memory {memory_id}") return memory_entry @@ -385,7 +405,9 @@ async def update_memory( # Update MCP client user context for this operation original_user_id = self.mcp_client.user_id - self.mcp_client.user_id = user_id or self.user_id # Use the actual Chronicle user's ID + self.mcp_client.user_id = ( + user_id or self.user_id + ) # Use the actual Chronicle user's ID try: success = await self.mcp_client.update_memory( @@ -404,7 +426,10 @@ async def update_memory( self.mcp_client.user_id = original_user_id async def delete_memory( - self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None + self, + memory_id: str, + user_id: Optional[str] = None, + user_email: Optional[str] = None, ) -> bool: """Delete a specific memory by ID. @@ -444,7 +469,9 @@ async def delete_all_user_memories(self, user_id: str) -> int: try: count = await self.mcp_client.delete_all_memories() - memory_logger.info(f"πŸ—‘οΈ Deleted {count} memories for user {user_id} via OpenMemory MCP") + memory_logger.info( + f"πŸ—‘οΈ Deleted {count} memories for user {user_id} via OpenMemory MCP" + ) return count except Exception as e: diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py b/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py index 06fd44fa..54add1f7 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/providers/vector_stores.py @@ -30,11 +30,11 @@ class QdrantVectorStore(VectorStoreBase): """Qdrant vector store implementation. - + Provides high-performance vector storage and similarity search using Qdrant database. Handles memory persistence, user isolation, and semantic search operations. - + Attributes: host: Qdrant server hostname port: Qdrant server port @@ -52,34 +52,35 @@ def __init__(self, config: Dict[str, Any]): async def initialize(self) -> None: """Initialize Qdrant client and collection. - + Creates the collection if it doesn't exist with appropriate vector configuration for cosine similarity search. - + If the collection exists but has different dimensions, it will be recreated with the correct dimensions (data will be lost). 
- + Raises: RuntimeError: If initialization fails """ try: self.client = AsyncQdrantClient(host=self.host, port=self.port) - + # Check if collection exists and get its info collections = await self.client.get_collections() collection_exists = any( - col.name == self.collection_name - for col in collections.collections + col.name == self.collection_name for col in collections.collections ) - + need_create = False - + if collection_exists: # Check if dimensions match try: - collection_info = await self.client.get_collection(self.collection_name) + collection_info = await self.client.get_collection( + self.collection_name + ) existing_dims = collection_info.config.params.vectors.size - + if existing_dims != self.embedding_dims: memory_logger.warning( f"Collection {self.collection_name} exists with {existing_dims} dimensions, " @@ -93,7 +94,9 @@ async def initialize(self) -> None: f"Collection {self.collection_name} exists with correct dimensions ({self.embedding_dims})" ) except Exception as e: - memory_logger.warning(f"Error checking collection info: {e}. Recreating...") + memory_logger.warning( + f"Error checking collection info: {e}. Recreating..." + ) try: await self.client.delete_collection(self.collection_name) except: @@ -101,19 +104,18 @@ async def initialize(self) -> None: need_create = True else: need_create = True - + if need_create: await self.client.create_collection( collection_name=self.collection_name, vectors_config=VectorParams( - size=self.embedding_dims, - distance=Distance.COSINE - ) + size=self.embedding_dims, distance=Distance.COSINE + ), ) memory_logger.info( f"Created Qdrant collection: {self.collection_name} with {self.embedding_dims} dimensions" ) - + except Exception as e: memory_logger.error(f"Qdrant initialization failed: {e}") raise @@ -132,15 +134,14 @@ async def add_memories(self, memories: List[MemoryEntry]) -> List[str]: "content": memory.content, "metadata": memory.metadata, "created_at": memory.created_at or current_time, - "updated_at": memory.updated_at or current_time - } + "updated_at": memory.updated_at or current_time, + }, ) points.append(point) if points: await self.client.upsert( - collection_name=self.collection_name, - points=points + collection_name=self.collection_name, points=points ) return [str(point.id) for point in points] @@ -150,9 +151,15 @@ async def add_memories(self, memories: List[MemoryEntry]) -> List[str]: memory_logger.error(f"Qdrant add memories failed: {e}") return [] - async def search_memories(self, query_embedding: List[float], user_id: str, limit: int, score_threshold: float = 0.0) -> List[MemoryEntry]: + async def search_memories( + self, + query_embedding: List[float], + user_id: str, + limit: int, + score_threshold: float = 0.0, + ) -> List[MemoryEntry]: """Search memories in Qdrant with configurable similarity threshold filtering. 
- + Args: query_embedding: Query vector for similarity search user_id: User identifier to filter results @@ -164,19 +171,18 @@ async def search_memories(self, query_embedding: List[float], user_id: str, limi search_filter = Filter( must=[ FieldCondition( - key="metadata.user_id", - match=MatchValue(value=user_id) + key="metadata.user_id", match=MatchValue(value=user_id) ) ] ) - + # Apply similarity threshold if provided # For cosine similarity, scores range from -1 to 1, where 1 is most similar search_params = { "collection_name": self.collection_name, "query": query_embedding, "query_filter": search_filter, - "limit": limit + "limit": limit, } if score_threshold > 0.0: @@ -194,17 +200,27 @@ async def search_memories(self, query_embedding: List[float], user_id: str, limi # Qdrant returns similarity scores directly (higher = more similar) score=result.score if result.score is not None else None, created_at=result.payload.get("created_at"), - updated_at=result.payload.get("updated_at") + updated_at=result.payload.get("updated_at"), ) memories.append(memory) # Log similarity scores for debugging - score_str = f"{result.score:.3f}" if result.score is not None else "None" - memory_logger.debug(f"Retrieved memory with score {score_str}: {result.payload.get('content', '')[:50]}...") - - threshold_msg = f"threshold {score_threshold}" if score_threshold > 0.0 else "no threshold" - memory_logger.info(f"Found {len(memories)} memories with {threshold_msg} for user {user_id}") + score_str = ( + f"{result.score:.3f}" if result.score is not None else "None" + ) + memory_logger.debug( + f"Retrieved memory with score {score_str}: {result.payload.get('content', '')[:50]}..." + ) + + threshold_msg = ( + f"threshold {score_threshold}" + if score_threshold > 0.0 + else "no threshold" + ) + memory_logger.info( + f"Found {len(memories)} memories with {threshold_msg} for user {user_id}" + ) return memories - + except Exception as e: memory_logger.error(f"Qdrant search failed: {e}") return [] @@ -216,18 +232,17 @@ async def get_memories(self, user_id: str, limit: int) -> List[MemoryEntry]: search_filter = Filter( must=[ FieldCondition( - key="metadata.user_id", - match=MatchValue(value=user_id) + key="metadata.user_id", match=MatchValue(value=user_id) ) ] ) - + results = await self.client.scroll( collection_name=self.collection_name, scroll_filter=search_filter, - limit=limit + limit=limit, ) - + memories = [] for point in results[0]: # results is tuple (points, next_page_offset) memory = MemoryEntry( @@ -235,21 +250,27 @@ async def get_memories(self, user_id: str, limit: int) -> List[MemoryEntry]: content=point.payload.get("content", ""), metadata=point.payload.get("metadata", {}), created_at=point.payload.get("created_at"), - updated_at=point.payload.get("updated_at") + updated_at=point.payload.get("updated_at"), ) memories.append(memory) return memories - + except Exception as e: memory_logger.error(f"Qdrant get memories failed: {e}") return [] - async def delete_memory(self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None) -> bool: + async def delete_memory( + self, + memory_id: str, + user_id: Optional[str] = None, + user_email: Optional[str] = None, + ) -> bool: """Delete a specific memory from Qdrant.""" try: # Convert memory_id to proper format for Qdrant import uuid + try: # Try to parse as UUID first uuid.UUID(memory_id) @@ -263,11 +284,10 @@ async def delete_memory(self, memory_id: str, user_id: Optional[str] = None, use point_id = memory_id await self.client.delete( - 
collection_name=self.collection_name, - points_selector=[point_id] + collection_name=self.collection_name, points_selector=[point_id] ) return True - + except Exception as e: memory_logger.error(f"Qdrant delete memory failed: {e}") return False @@ -278,25 +298,24 @@ async def delete_user_memories(self, user_id: str) -> int: # First count memories to delete memories = await self.get_memories(user_id, limit=10000) count = len(memories) - + if count > 0: # Delete by filter delete_filter = Filter( must=[ FieldCondition( - key="metadata.user_id", - match=MatchValue(value=user_id) + key="metadata.user_id", match=MatchValue(value=user_id) ) ] ) - + await self.client.delete( collection_name=self.collection_name, - points_selector=FilterSelector(filter=delete_filter) + points_selector=FilterSelector(filter=delete_filter), ) - + return count - + except Exception as e: memory_logger.error(f"Qdrant delete user memories failed: {e}") return 0 @@ -308,7 +327,7 @@ async def test_connection(self) -> bool: await self.client.get_collections() return True return False - + except Exception as e: memory_logger.error(f"Qdrant connection test failed: {e}") return False @@ -336,6 +355,7 @@ async def update_memory( # Convert memory_id to proper format for Qdrant # Qdrant accepts either UUID strings or unsigned integers import uuid + try: # Try to parse as UUID first uuid.UUID(memory_id) @@ -370,8 +390,7 @@ async def count_memories(self, user_id: str) -> int: search_filter = Filter( must=[ FieldCondition( - key="metadata.user_id", - match=MatchValue(value=user_id) + key="metadata.user_id", match=MatchValue(value=user_id) ) ] ) @@ -379,8 +398,7 @@ async def count_memories(self, user_id: str) -> int: # Use Qdrant's native count API (documented in qdrant/qdrant/docs) # Count operation: CountPoints -> CountResponse with count result result = await self.client.count( - collection_name=self.collection_name, - count_filter=search_filter + collection_name=self.collection_name, count_filter=search_filter ) return result.count @@ -389,7 +407,9 @@ async def count_memories(self, user_id: str) -> int: memory_logger.error(f"Qdrant count memories failed: {e}") return 0 - async def get_memory(self, memory_id: str, user_id: Optional[str] = None) -> Optional[MemoryEntry]: + async def get_memory( + self, memory_id: str, user_id: Optional[str] = None + ) -> Optional[MemoryEntry]: """Get a specific memory by ID from Qdrant. 
Args: @@ -402,6 +422,7 @@ async def get_memory(self, memory_id: str, user_id: Optional[str] = None) -> Opt try: # Convert memory_id to proper format for Qdrant import uuid + try: # Try to parse as UUID first uuid.UUID(memory_id) @@ -419,7 +440,7 @@ async def get_memory(self, memory_id: str, user_id: Optional[str] = None) -> Opt collection_name=self.collection_name, ids=[point_id], with_payload=True, - with_vectors=False + with_vectors=False, ) if not points: @@ -432,7 +453,9 @@ async def get_memory(self, memory_id: str, user_id: Optional[str] = None) -> Opt if user_id: point_user_id = point.payload.get("metadata", {}).get("user_id") if point_user_id != user_id: - memory_logger.warning(f"Memory {memory_id} does not belong to user {user_id}") + memory_logger.warning( + f"Memory {memory_id} does not belong to user {user_id}" + ) return None # Convert to MemoryEntry @@ -441,7 +464,7 @@ async def get_memory(self, memory_id: str, user_id: Optional[str] = None) -> Opt content=point.payload.get("content", ""), metadata=point.payload.get("metadata", {}), created_at=point.payload.get("created_at"), - updated_at=point.payload.get("updated_at") + updated_at=point.payload.get("updated_at"), ) memory_logger.debug(f"Retrieved memory {memory_id}") @@ -559,4 +582,3 @@ async def get_recent_memories( except Exception as e: memory_logger.error(f"Qdrant get recent memories failed: {e}") return [] - diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/service_factory.py b/backends/advanced/src/advanced_omi_backend/services/memory/service_factory.py index 2c9507f5..3492a154 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/service_factory.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/service_factory.py @@ -34,7 +34,9 @@ def create_memory_service(config: MemoryConfig) -> MemoryServiceBase: ValueError: If unsupported memory provider is specified RuntimeError: If required dependencies are missing """ - memory_logger.info(f"🧠 Creating memory service with provider: {config.memory_provider.value}") + memory_logger.info( + f"🧠 Creating memory service with provider: {config.memory_provider.value}" + ) if config.memory_provider == MemoryProvider.CHRONICLE: # Use the sophisticated Chronicle implementation @@ -50,7 +52,9 @@ def create_memory_service(config: MemoryConfig) -> MemoryServiceBase: raise RuntimeError(f"OpenMemory MCP service not available: {e}") if not config.openmemory_config: - raise ValueError("OpenMemory configuration is required for OPENMEMORY_MCP provider") + raise ValueError( + "OpenMemory configuration is required for OPENMEMORY_MCP provider" + ) return OpenMemoryMCPService(**config.openmemory_config) diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/update_memory_utils.py b/backends/advanced/src/advanced_omi_backend/services/memory/update_memory_utils.py index b0c6c9db..3e7c7144 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/update_memory_utils.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/update_memory_utils.py @@ -1,4 +1,3 @@ - import re import xml.etree.ElementTree as ET from dataclasses import asdict, dataclass @@ -8,6 +7,7 @@ NUMERIC_ID = re.compile(r"^\d+$") ALLOWED_EVENTS = {"ADD", "UPDATE", "DELETE", "NONE"} + @dataclass(frozen=True) class MemoryItem: id: str @@ -15,9 +15,11 @@ class MemoryItem: text: str old_memory: Optional[str] = None + class MemoryXMLParseError(ValueError): pass + def extract_xml_from_content(content: str) -> str: """ Extract XML from content 
that might contain other text. @@ -27,32 +29,34 @@ def extract_xml_from_content(content: str) -> str: import re # Look for <memory>...</memory> block - xml_match = re.search(r'<memory>.*?</memory>', content, re.DOTALL) + xml_match = re.search(r"<memory>.*?</memory>", content, re.DOTALL) if xml_match: return xml_match.group(0) - + # If no <memory> tags found, return the original content return content + def clean_and_validate_xml(xml_str: str) -> str: """ Clean common XML issues and validate structure. """ xml_str = xml_str.strip() - + # Print raw XML for debugging print("Raw XML content:") print("=" * 50) print(repr(xml_str)) print("=" * 50) print("Formatted XML content:") - lines = xml_str.split('\n') + lines = xml_str.split("\n") for i, line in enumerate(lines, 1): print(f"{i:2d}: {line}") print("=" * 50) - + return xml_str + def extract_assistant_xml_from_openai_response(response) -> str: """ Extract XML content from OpenAI ChatCompletion response. @@ -62,7 +66,10 @@ def extract_assistant_xml_from_openai_response(response) -> str: # OpenAI ChatCompletion object structure return response.choices[0].message.content except (AttributeError, IndexError, KeyError) as e: - raise MemoryXMLParseError(f"Could not extract assistant XML from OpenAI response: {e}") from e + raise MemoryXMLParseError( + f"Could not extract assistant XML from OpenAI response: {e}" + ) from e + def parse_memory_xml(xml_str: str) -> List[MemoryItem]: """ @@ -118,9 +125,11 @@ def parse_memory_xml(xml_str: str) -> List[MemoryItem]: # Children text_el = item.find("text") if text_el is None or (text_el.text or "").strip() == "": - raise MemoryXMLParseError(f"<text> is required and non-empty for id {item_id}.") + raise MemoryXMLParseError( + f"<text> is required and non-empty for id {item_id}." + ) text_val = (text_el.text or "").strip() - + # No JSON expansion needed - individual facts are now properly handled by improved prompts old_el = item.find("old_memory") @@ -133,9 +142,13 @@ def parse_memory_xml(xml_str: str) -> List[MemoryItem]: else: # For non-UPDATE, <old_memory> must not appear if old_el is not None: - raise MemoryXMLParseError(f"<old_memory> must only appear for UPDATE (id {item_id}).") + raise MemoryXMLParseError( + f"<old_memory> must only appear for UPDATE (id {item_id})."
+ ) - items.append(MemoryItem(id=item_id, event=event, text=text_val, old_memory=old_val)) + items.append( + MemoryItem(id=item_id, event=event, text=text_val, old_memory=old_val) + ) if not items: raise MemoryXMLParseError("No <item> elements found in <memory>.") @@ -151,4 +164,4 @@ def items_to_json(items: List[MemoryItem]) -> Dict[str, Any]: if it.event == "UPDATE" and it.old_memory: # include only if non-empty obj["old_memory"] = it.old_memory out.append(obj) - return {"memory": out} \ No newline at end of file + return {"memory": out} diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/utils.py b/backends/advanced/src/advanced_omi_backend/services/memory/utils.py index b3c231f7..bdb4d340 100644 --- a/backends/advanced/src/advanced_omi_backend/services/memory/utils.py +++ b/backends/advanced/src/advanced_omi_backend/services/memory/utils.py @@ -38,7 +38,7 @@ def extract_json_from_text(response_text: str) -> Optional[Dict[str, Any]]: think_end = response_text.find("</think>") if think_end != -1: # Extract everything after </think> - json_part = response_text[think_end + 8:].strip() + json_part = response_text[think_end + 8 :].strip() if json_part: try: @@ -56,12 +56,14 @@ def extract_json_from_text(response_text: str) -> Optional[Dict[str, Any]]: # Clean up common LLM response artifacts cleaned_text = response_text # Remove markdown code blocks - cleaned_text = re.sub(r'```(?:json)?\s*(.*?)\s*```', r'\1', cleaned_text, flags=re.DOTALL) + cleaned_text = re.sub( + r"```(?:json)?\s*(.*?)\s*```", r"\1", cleaned_text, flags=re.DOTALL + ) # Remove common prefixes - cleaned_text = re.sub(r'^.*?(?=\{)', '', cleaned_text, flags=re.DOTALL) + cleaned_text = re.sub(r"^.*?(?=\{)", "", cleaned_text, flags=re.DOTALL) # Remove trailing non-JSON content - cleaned_text = re.sub(r'\}.*$', '}', cleaned_text, flags=re.DOTALL) - + cleaned_text = re.sub(r"\}.*$", "}", cleaned_text, flags=re.DOTALL) + # Try parsing the cleaned text try: parsed = json.loads(cleaned_text.strip()) @@ -79,7 +81,7 @@ def extract_json_from_text(response_text: str) -> Optional[Dict[str, Any]]: # Look for any JSON object containing memory or facts r'\{[^{}]*"(?:memory|facts)"[^{}]*\}', # Look for any balanced JSON object - r'\{(?:[^{}]|{[^{}]*})*\}', + r"\{(?:[^{}]|{[^{}]*})*\}", ] for pattern in json_patterns: @@ -100,7 +102,7 @@ def extract_json_from_text(response_text: str) -> Optional[Dict[str, Any]]: except json.JSONDecodeError: continue # Use fallback if we found a valid dict but without preferred keys - if 'fallback' in locals(): + if "fallback" in locals(): return fallback except Exception as e: memory_logger.debug(f"Pattern {pattern} failed: {e}") @@ -115,7 +117,9 @@ def extract_json_from_text(response_text: str) -> Optional[Dict[str, Any]]: array_str = match.group(1) array_data = json.loads(array_str) if isinstance(array_data, list): - memory_logger.debug(f"Successfully extracted {key} array from response") + memory_logger.debug( + f"Successfully extracted {key} array from response" + ) return {key: array_data} except Exception as e: memory_logger.debug(f"{key} array extraction failed: {e}") @@ -131,7 +135,9 @@ def extract_json_from_text(response_text: str) -> Optional[Dict[str, Any]]: try: parsed = json.loads(potential_json) if isinstance(parsed, dict): - memory_logger.debug("Successfully extracted JSON using bracket matching") + memory_logger.debug( + "Successfully extracted JSON using bracket matching" + ) return parsed except json.JSONDecodeError: pass @@ -143,5 +149,3 @@ async def extract_json_from_text(response_text: str) ->
Optional[Dict[str, Any]]: f"Failed to extract JSON from LLM response. Response preview: {response_text[:200]}..." ) return None - - diff --git a/backends/advanced/src/advanced_omi_backend/services/obsidian_service.py b/backends/advanced/src/advanced_omi_backend/services/obsidian_service.py index b02a6fa0..7330a6f7 100644 --- a/backends/advanced/src/advanced_omi_backend/services/obsidian_service.py +++ b/backends/advanced/src/advanced_omi_backend/services/obsidian_service.py @@ -114,7 +114,9 @@ def __init__(self): embed_config = get_model_config(config_data, "embedding") if not embed_config: - raise ValueError("Configuration for 'defaults.embedding' not found in config.yml") + raise ValueError( + "Configuration for 'defaults.embedding' not found in config.yml" + ) # Neo4j Connection - Prefer environment variables passed by Docker Compose neo4j_host = os.getenv("NEO4J_HOST") @@ -142,18 +144,24 @@ def __init__(self): self.neo4j_uri = f"bolt://{neo4j_host}:7687" self.neo4j_user = os.getenv("NEO4J_USER") or env_data.get("NEO4J_USER", "neo4j") - self.neo4j_password = os.getenv("NEO4J_PASSWORD") or env_data.get("NEO4J_PASSWORD", "") + self.neo4j_password = os.getenv("NEO4J_PASSWORD") or env_data.get( + "NEO4J_PASSWORD", "" + ) # Models / API - Loaded strictly from config.yml self.embedding_model = str(resolve_value(embed_config["model_name"])) - self.embedding_dimensions = int(resolve_value(embed_config["embedding_dimensions"])) + self.embedding_dimensions = int( + resolve_value(embed_config["embedding_dimensions"]) + ) self.openai_base_url = str(resolve_value(llm_config["model_url"])) self.openai_api_key = str(resolve_value(llm_config["api_key"])) # Chunking - uses shared spaCy/text fallback utility self.chunk_word_limit = 120 - self.neo4j_client = Neo4jClient(self.neo4j_uri, self.neo4j_user, self.neo4j_password) + self.neo4j_client = Neo4jClient( + self.neo4j_uri, self.neo4j_user, self.neo4j_password + ) self.read_interface = Neo4jReadInterface(self.neo4j_client) self.write_interface = Neo4jWriteInterface(self.neo4j_client) @@ -191,7 +199,9 @@ def _clean_text(text: str) -> str: """Normalize whitespace for embedding inputs.""" return re.sub(r"\s+", " ", text).strip() - def parse_obsidian_note(self, root: str, filename: str, vault_path: str) -> NoteData: + def parse_obsidian_note( + self, root: str, filename: str, vault_path: str + ) -> NoteData: """Parse an Obsidian markdown file and extract metadata. Args: @@ -284,7 +294,9 @@ async def chunking_and_embedding(self, note_data: NoteData) -> List[ChunkPayload model=self.embedding_model, ) except Exception as e: - logger.exception(f"Embedding generation failed for {note_data['path']}: {e}") + logger.exception( + f"Embedding generation failed for {note_data['path']}: {e}" + ) return [] chunk_payloads: List[ChunkPayload] = [] @@ -296,7 +308,9 @@ async def chunking_and_embedding(self, note_data: NoteData) -> List[ChunkPayload return chunk_payloads - def ingest_note_and_chunks(self, note_data: NoteData, chunks: List[ChunkPayload]) -> None: + def ingest_note_and_chunks( + self, note_data: NoteData, chunks: List[ChunkPayload] + ) -> None: """Store note and chunks in Neo4j with relationships to folders, tags, and links. 
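ingest_note_and_chunks persists each parsed note plus its chunks and relationships (HAS_CHUNK, HAS_TAG, LINKS_TO) through the project's Neo4j write interface, which this diff does not show. A rough sketch of the same graph shape using the plain neo4j driver follows; the Chunk label, property names, and helper signature are assumptions for illustration only.

```python
from neo4j import GraphDatabase


def ingest_note(uri: str, auth: tuple, note_name: str, chunks: list, tags: list) -> None:
    """MERGE a Note, its Chunks (HAS_CHUNK) and Tags (HAS_TAG).

    LINKS_TO edges between notes would be merged the same way; the sketch
    assumes non-empty chunk/tag lists.
    """
    cypher = """
    MERGE (n:Note {name: $note_name})
    WITH n
    UNWIND $chunks AS chunk
    MERGE (c:Chunk {id: chunk.id})
    SET c.text = chunk.text, c.embedding = chunk.embedding
    MERGE (n)-[:HAS_CHUNK]->(c)
    WITH DISTINCT n
    UNWIND $tags AS tag
    MERGE (t:Tag {name: tag})
    MERGE (n)-[:HAS_TAG]->(t)
    """
    with GraphDatabase.driver(uri, auth=auth) as driver:
        with driver.session() as session:
            session.run(cypher, note_name=note_name, chunks=chunks, tags=tags)
```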
Args: @@ -416,15 +430,15 @@ async def search_obsidian(self, query: str, limit: int = 5) -> ObsidianSearchRes cypher_query = """ CALL db.index.vector.queryNodes('chunk_embeddings', $limit, $vector) YIELD node AS chunk, score - + // Find the parent Note MATCH (note:Note)-[:HAS_CHUNK]->(chunk) - + // Get graph context: What tags and linked files are around this note? OPTIONAL MATCH (note)-[:HAS_TAG]->(t:Tag) OPTIONAL MATCH (note)-[:LINKS_TO]->(linked:Note) - - RETURN + + RETURN note.name AS source, chunk.text AS content, collect(DISTINCT t.name) AS tags, diff --git a/backends/advanced/src/advanced_omi_backend/services/plugin_assistant.py b/backends/advanced/src/advanced_omi_backend/services/plugin_assistant.py index c5836b89..12565314 100644 --- a/backends/advanced/src/advanced_omi_backend/services/plugin_assistant.py +++ b/backends/advanced/src/advanced_omi_backend/services/plugin_assistant.py @@ -286,14 +286,20 @@ async def _exec_tool(name: str, arguments: dict) -> dict: metadata = await system_controller.get_plugins_metadata() plugin_id = arguments.get("plugin_id") if plugin_id: - plugins = [p for p in metadata.get("plugins", []) if p.get("plugin_id") == plugin_id] + plugins = [ + p + for p in metadata.get("plugins", []) + if p.get("plugin_id") == plugin_id + ] return {"plugins": plugins, "status": "success"} return metadata if name == "apply_plugin_config": plugin_id = arguments["plugin_id"] config = {k: v for k, v in arguments.items() if k != "plugin_id"} - return await system_controller.update_plugin_config_structured(plugin_id, config) + return await system_controller.update_plugin_config_structured( + plugin_id, config + ) if name == "test_plugin_connection": plugin_id = arguments["plugin_id"] @@ -418,7 +424,9 @@ async def generate_response_stream(messages: list[dict]) -> AsyncGenerator[dict, full_messages = [{"role": "system", "content": system_prompt}] + messages for _ in range(MAX_TOOL_ROUNDS): - response = await async_chat_with_tools(full_messages, tools=TOOLS, operation="plugin_assistant") + response = await async_chat_with_tools( + full_messages, tools=TOOLS, operation="plugin_assistant" + ) choice = response.choices[0] # If the model wants to call tools, execute them and loop @@ -441,11 +449,15 @@ async def generate_response_stream(messages: list[dict]) -> AsyncGenerator[dict, if confirmation == "rejected": # User rejected β€” add tool result and let model respond - full_messages.append({ - "role": "tool", - "tool_call_id": tool_call.id, - "content": json.dumps({"rejected": True, "reason": "User declined"}), - }) + full_messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": json.dumps( + {"rejected": True, "reason": "User declined"} + ), + } + ) continue if not confirmation: @@ -473,11 +485,13 @@ async def generate_response_stream(messages: list[dict]) -> AsyncGenerator[dict, yield {"type": "tool_result", "name": fn_name, "result": result} - full_messages.append({ - "role": "tool", - "tool_call_id": tool_call.id, - "content": json.dumps(result, default=str), - }) + full_messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": json.dumps(result, default=str), + } + ) if needs_confirmation: # Stop the loop β€” frontend will re-send with confirmation diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/base.py b/backends/advanced/src/advanced_omi_backend/services/transcription/base.py index 7bc4b7ac..ec08998b 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/base.py +++ 
b/backends/advanced/src/advanced_omi_backend/services/transcription/base.py @@ -34,6 +34,7 @@ class TranscriptionProvider(Enum): """Available transcription providers for audio stream routing.""" + DEEPGRAM = "deepgram" PARAKEET = "parakeet" @@ -93,7 +94,9 @@ def capabilities(self) -> set: return set() @abc.abstractmethod - async def start_stream(self, client_id: str, sample_rate: int = 16000, diarize: bool = False): + async def start_stream( + self, client_id: str, sample_rate: int = 16000, diarize: bool = False + ): """Start a transcription stream for a client. Args: @@ -104,7 +107,9 @@ async def start_stream(self, client_id: str, sample_rate: int = 16000, diarize: pass @abc.abstractmethod - async def process_audio_chunk(self, client_id: str, audio_chunk: bytes) -> Optional[dict]: + async def process_audio_chunk( + self, client_id: str, audio_chunk: bytes + ) -> Optional[dict]: """ Process audio chunk and return partial/final transcription. @@ -127,7 +132,14 @@ def mode(self) -> str: return "batch" @abc.abstractmethod - async def transcribe(self, audio_data: bytes, sample_rate: int, diarize: bool = False, context_info: Optional[str] = None, **kwargs) -> dict: + async def transcribe( + self, + audio_data: bytes, + sample_rate: int, + diarize: bool = False, + context_info: Optional[str] = None, + **kwargs + ) -> dict: """Transcribe audio data. Args: diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/context.py b/backends/advanced/src/advanced_omi_backend/services/transcription/context.py index c042da3d..8184febf 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/context.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/context.py @@ -57,7 +57,9 @@ def to_metadata(self) -> dict: } -async def gather_transcription_context(user_id: Optional[str] = None) -> TranscriptionContext: +async def gather_transcription_context( + user_id: Optional[str] = None, +) -> TranscriptionContext: """Build structured transcription context: static hot words + cached user jargon. 
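base.py above defines the contract concrete providers fill in: streaming providers implement start_stream() and process_audio_chunk(), batch providers implement transcribe(). A toy illustration of the streaming side is below; the abstract class merely mirrors the signatures from this diff and is not the project's actual base class.

```python
import abc
from typing import Optional


class StreamingProvider(abc.ABC):
    """Stand-in mirroring the streaming signatures shown in base.py."""

    @abc.abstractmethod
    async def start_stream(
        self, client_id: str, sample_rate: int = 16000, diarize: bool = False
    ): ...

    @abc.abstractmethod
    async def process_audio_chunk(
        self, client_id: str, audio_chunk: bytes
    ) -> Optional[dict]: ...


class ByteCountingProvider(StreamingProvider):
    """Toy provider that only tracks how much audio each client has sent."""

    def __init__(self) -> None:
        self._seen: dict = {}

    async def start_stream(
        self, client_id: str, sample_rate: int = 16000, diarize: bool = False
    ):
        self._seen[client_id] = 0

    async def process_audio_chunk(
        self, client_id: str, audio_chunk: bytes
    ) -> Optional[dict]:
        self._seen[client_id] += len(audio_chunk)
        # Interim result shaped like the dicts the streaming consumer expects
        return {"text": "", "words": [], "is_final": False}
```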
Args: diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py b/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py index 02e5b37c..520e94ad 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/mock_provider.py @@ -76,32 +76,120 @@ async def transcribe( # Generate mock words with timestamps (spread across audio duration) words = [ - {"word": "This", "start": 0.0, "end": 0.3, "confidence": 0.99, "speaker": 0}, + { + "word": "This", + "start": 0.0, + "end": 0.3, + "confidence": 0.99, + "speaker": 0, + }, {"word": "is", "start": 0.3, "end": 0.5, "confidence": 0.99, "speaker": 0}, {"word": "a", "start": 0.5, "end": 0.6, "confidence": 0.99, "speaker": 0}, - {"word": "mock", "start": 0.6, "end": 0.9, "confidence": 0.99, "speaker": 0}, - {"word": "transcription", "start": 0.9, "end": 1.5, "confidence": 0.98, "speaker": 0}, + { + "word": "mock", + "start": 0.6, + "end": 0.9, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "transcription", + "start": 0.9, + "end": 1.5, + "confidence": 0.98, + "speaker": 0, + }, {"word": "for", "start": 1.5, "end": 1.7, "confidence": 0.99, "speaker": 0}, - {"word": "testing", "start": 1.7, "end": 2.1, "confidence": 0.99, "speaker": 0}, - {"word": "purposes", "start": 2.1, "end": 2.6, "confidence": 0.97, "speaker": 0}, + { + "word": "testing", + "start": 1.7, + "end": 2.1, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "purposes", + "start": 2.1, + "end": 2.6, + "confidence": 0.97, + "speaker": 0, + }, {"word": "It", "start": 2.6, "end": 2.8, "confidence": 0.99, "speaker": 0}, - {"word": "contains", "start": 2.8, "end": 3.2, "confidence": 0.99, "speaker": 0}, - {"word": "enough", "start": 3.2, "end": 3.5, "confidence": 0.99, "speaker": 0}, - {"word": "words", "start": 3.5, "end": 3.8, "confidence": 0.99, "speaker": 0}, + { + "word": "contains", + "start": 2.8, + "end": 3.2, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "enough", + "start": 3.2, + "end": 3.5, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "words", + "start": 3.5, + "end": 3.8, + "confidence": 0.99, + "speaker": 0, + }, {"word": "to", "start": 3.8, "end": 3.9, "confidence": 0.99, "speaker": 0}, - {"word": "meet", "start": 3.9, "end": 4.1, "confidence": 0.99, "speaker": 0}, - {"word": "minimum", "start": 4.1, "end": 4.5, "confidence": 0.98, "speaker": 0}, - {"word": "length", "start": 4.5, "end": 4.8, "confidence": 0.99, "speaker": 0}, - {"word": "requirements", "start": 4.8, "end": 5.4, "confidence": 0.98, "speaker": 0}, + { + "word": "meet", + "start": 3.9, + "end": 4.1, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "minimum", + "start": 4.1, + "end": 4.5, + "confidence": 0.98, + "speaker": 0, + }, + { + "word": "length", + "start": 4.5, + "end": 4.8, + "confidence": 0.99, + "speaker": 0, + }, + { + "word": "requirements", + "start": 4.8, + "end": 5.4, + "confidence": 0.98, + "speaker": 0, + }, {"word": "for", "start": 5.4, "end": 5.6, "confidence": 0.99, "speaker": 0}, - {"word": "automated", "start": 5.6, "end": 6.1, "confidence": 0.98, "speaker": 0}, - {"word": "testing", "start": 6.1, "end": 6.5, "confidence": 0.99, "speaker": 0}, + { + "word": "automated", + "start": 5.6, + "end": 6.1, + "confidence": 0.98, + "speaker": 0, + }, + { + "word": "testing", + "start": 6.1, + "end": 6.5, + "confidence": 0.99, + "speaker": 0, + }, ] # Mock segments (single speaker for 
simplicity) segments = [{"speaker": 0, "start": 0.0, "end": 6.5, "text": mock_transcript}] - return {"text": mock_transcript, "words": words, "segments": segments if diarize else []} + return { + "text": mock_transcript, + "words": words, + "segments": segments if diarize else [], + } async def connect(self, client_id: Optional[str] = None): """Initialize the mock provider (no-op).""" diff --git a/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py b/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py index bc50068f..a9b2b43d 100644 --- a/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py +++ b/backends/advanced/src/advanced_omi_backend/services/transcription/streaming_consumer.py @@ -20,13 +20,12 @@ import redis.asyncio as redis from redis import exceptions as redis_exceptions -from advanced_omi_backend.plugins.events import PluginEvent - from advanced_omi_backend.client_manager import get_client_owner_async from advanced_omi_backend.models.user import get_user_by_id +from advanced_omi_backend.plugins.events import PluginEvent from advanced_omi_backend.plugins.router import PluginRouter -from advanced_omi_backend.speaker_recognition_client import SpeakerRecognitionClient from advanced_omi_backend.services.transcription import get_transcription_provider +from advanced_omi_backend.speaker_recognition_client import SpeakerRecognitionClient from advanced_omi_backend.utils.audio_utils import pcm_to_wav_bytes logger = logging.getLogger(__name__) @@ -75,13 +74,19 @@ def _group_words_into_segments(words: list) -> list: if spk != current_speaker and current_words: # Flush previous segment - segments.append({ - "start": current_words[0].get("start", 0.0), - "end": current_words[-1].get("end", 0.0), - "text": " ".join(cw.get("word", "") for cw in current_words), - "speaker": f"Speaker {current_speaker}" if current_speaker != -1 else "Unknown", - "words": list(current_words), - }) + segments.append( + { + "start": current_words[0].get("start", 0.0), + "end": current_words[-1].get("end", 0.0), + "text": " ".join(cw.get("word", "") for cw in current_words), + "speaker": ( + f"Speaker {current_speaker}" + if current_speaker != -1 + else "Unknown" + ), + "words": list(current_words), + } + ) current_words = [] current_speaker = spk @@ -89,13 +94,17 @@ def _group_words_into_segments(words: list) -> list: # Flush last segment if current_words: - segments.append({ - "start": current_words[0].get("start", 0.0), - "end": current_words[-1].get("end", 0.0), - "text": " ".join(cw.get("word", "") for cw in current_words), - "speaker": f"Speaker {current_speaker}" if current_speaker != -1 else "Unknown", - "words": list(current_words), - }) + segments.append( + { + "start": current_words[0].get("start", 0.0), + "end": current_words[-1].get("end", 0.0), + "text": " ".join(cw.get("word", "") for cw in current_words), + "speaker": ( + f"Speaker {current_speaker}" if current_speaker != -1 else "Unknown" + ), + "words": list(current_words), + } + ) return segments @@ -145,8 +154,8 @@ def __init__( # Check if provider supports streaming diarization self._provider_has_diarization = ( - hasattr(self.provider, 'capabilities') - and 'diarization' in self.provider.capabilities + hasattr(self.provider, "capabilities") + and "diarization" in self.provider.capabilities ) # Stream configuration @@ -160,7 +169,9 @@ def __init__( self.active_streams: Dict[str, Dict] = {} # {stream_name: {"session_id": ...}} # Session 
tracking for WebSocket connections - self.active_sessions: Dict[str, Dict] = {} # {session_id: {"last_activity": timestamp}} + self.active_sessions: Dict[str, Dict] = ( + {} + ) # {session_id: {"last_activity": timestamp}} # Audio buffers for speaker identification (raw PCM bytes per session) self._audio_buffers: Dict[str, bytearray] = {} @@ -180,7 +191,9 @@ async def discover_streams(self) -> list[str]: cursor, match=self.stream_pattern, count=100 ) if keys: - streams.extend([k.decode() if isinstance(k, bytes) else k for k in keys]) + streams.extend( + [k.decode() if isinstance(k, bytes) else k for k in keys] + ) return streams @@ -188,16 +201,15 @@ async def setup_consumer_group(self, stream_name: str): """Create consumer group if it doesn't exist.""" try: await self.redis_client.xgroup_create( - stream_name, - self.group_name, - "0", - mkstream=True + stream_name, self.group_name, "0", mkstream=True ) logger.debug(f"Created consumer group {self.group_name} for {stream_name}") except redis_exceptions.ResponseError as e: if "BUSYGROUP" not in str(e): raise - logger.debug(f"Consumer group {self.group_name} already exists for {stream_name}") + logger.debug( + f"Consumer group {self.group_name} already exists for {stream_name}" + ) async def start_session_stream(self, session_id: str, sample_rate: int = 16000): """ @@ -258,7 +270,10 @@ async def end_session_stream(self, session_id: str): has_word_speakers = ( self._provider_has_diarization and words - and any(isinstance(w, dict) and w.get("speaker") is not None for w in words) + and any( + isinstance(w, dict) and w.get("speaker") is not None + for w in words + ) ) if has_word_speakers: @@ -266,21 +281,28 @@ async def end_session_stream(self, session_id: str): speaker_name = None speaker_confidence = 0.0 else: - speaker_name, speaker_confidence = await self._identify_speaker(session_id) + speaker_name, speaker_confidence = await self._identify_speaker( + session_id + ) if speaker_name: final_result["speaker_name"] = speaker_name final_result["speaker_confidence"] = speaker_confidence await self.publish_to_client( - session_id, final_result, is_final=True, - speaker_name=speaker_name, speaker_confidence=speaker_confidence, + session_id, + final_result, + is_final=True, + speaker_name=speaker_name, + speaker_confidence=speaker_confidence, ) await self.store_final_result(session_id, final_result) # Trigger plugins on final result if self.plugin_router: - await self.trigger_plugins(session_id, final_result, speaker_name=speaker_name) + await self.trigger_plugins( + session_id, final_result, speaker_name=speaker_name + ) self.active_sessions.pop(session_id, None) self._audio_buffers.pop(session_id, None) @@ -288,7 +310,9 @@ async def end_session_stream(self, session_id: str): # Signal that streaming transcription is complete for this session completion_key = f"transcription:complete:{session_id}" await self.redis_client.set(completion_key, "1", ex=300) # 5 min TTL - logger.info(f"Streaming transcription complete for {session_id} (signal set)") + logger.info( + f"Streaming transcription complete for {session_id} (signal set)" + ) except Exception as e: logger.error(f"Error ending stream for {session_id}: {e}", exc_info=True) @@ -300,7 +324,9 @@ async def end_session_stream(self, session_id: str): except Exception: pass # Best effort - async def process_audio_chunk(self, session_id: str, audio_chunk: bytes, chunk_id: str): + async def process_audio_chunk( + self, session_id: str, audio_chunk: bytes, chunk_id: str + ): """ Process a single audio 
chunk through streaming transcription provider. @@ -316,8 +342,7 @@ async def process_audio_chunk(self, session_id: str, audio_chunk: bytes, chunk_i # Send audio chunk to provider WebSocket and get result result = await self.provider.process_audio_chunk( - client_id=session_id, - audio_chunk=audio_chunk + client_id=session_id, audio_chunk=audio_chunk ) # Update last activity @@ -337,7 +362,7 @@ async def process_audio_chunk(self, session_id: str, audio_chunk: bytes, chunk_i # Track transcript at each step logger.info( f"TRANSCRIPT session={session_id}, is_final={is_final}, " - f"words={word_count}, text=\"{text}\"" + f'words={word_count}, text="{text}"' ) if is_final: @@ -345,7 +370,10 @@ async def process_audio_chunk(self, session_id: str, audio_chunk: bytes, chunk_i has_word_speakers = ( self._provider_has_diarization and words - and any(isinstance(w, dict) and w.get("speaker") is not None for w in words) + and any( + isinstance(w, dict) and w.get("speaker") is not None + for w in words + ) ) if has_word_speakers: @@ -355,7 +383,9 @@ async def process_audio_chunk(self, session_id: str, audio_chunk: bytes, chunk_i speaker_confidence = 0.0 else: # Identify speaker from buffered audio (non-diarizing providers) - speaker_name, speaker_confidence = await self._identify_speaker(session_id) + speaker_name, speaker_confidence = await self._identify_speaker( + session_id + ) if speaker_name: result["speaker_name"] = speaker_name @@ -363,26 +393,33 @@ async def process_audio_chunk(self, session_id: str, audio_chunk: bytes, chunk_i # Publish to clients with speaker info await self.publish_to_client( - session_id, result, is_final=True, - speaker_name=speaker_name, speaker_confidence=speaker_confidence, + session_id, + result, + is_final=True, + speaker_name=speaker_name, + speaker_confidence=speaker_confidence, ) logger.info( f"TRANSCRIPT [STORE] session={session_id}, words={word_count}, " f"speaker={speaker_name}, segments={len(result.get('segments', []))}, " - f"text=\"{text}\"" + f'text="{text}"' ) await self.store_final_result(session_id, result, chunk_id=chunk_id) # Trigger plugins on final results only if self.plugin_router: - await self.trigger_plugins(session_id, result, speaker_name=speaker_name) + await self.trigger_plugins( + session_id, result, speaker_name=speaker_name + ) else: # Interim result β€” normalize words but no speaker identification await self.publish_to_client(session_id, result, is_final=False) except Exception as e: - logger.error(f"Error processing audio chunk for {session_id}: {e}", exc_info=True) + logger.error( + f"Error processing audio chunk for {session_id}: {e}", exc_info=True + ) async def _identify_speaker(self, session_id: str) -> tuple[Optional[str], float]: """Identify the speaker from buffered audio via speaker recognition service. 
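_group_words_into_segments, earlier in this file's diff, turns a flat word list into speaker-labelled segments by flushing whenever the word-level speaker changes. A stand-alone version of that logic, using the same word-dict fields (the -1 "Unknown" default is taken from the diff):

```python
def group_words_into_segments(words: list) -> list:
    """Group word dicts (word/start/end/speaker) into per-speaker segments."""
    segments: list = []
    current: list = []
    current_speaker = None

    def flush() -> None:
        if not current:
            return
        segments.append(
            {
                "start": current[0].get("start", 0.0),
                "end": current[-1].get("end", 0.0),
                "text": " ".join(w.get("word", "") for w in current),
                "speaker": (
                    f"Speaker {current_speaker}" if current_speaker != -1 else "Unknown"
                ),
                "words": list(current),
            }
        )

    for word in words:
        speaker = word.get("speaker", -1)
        if speaker != current_speaker and current:
            flush()
            current = []
        current_speaker = speaker
        current.append(word)

    flush()
    return segments
```

For words tagged 0, 0, 1 this yields two segments, the second labelled "Speaker 1".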
@@ -405,7 +442,9 @@ async def _identify_speaker(self, session_id: str) -> tuple[Optional[str], float user_id = await self._get_user_id_from_client_id(session_id) # Convert buffered PCM to WAV - wav_bytes = pcm_to_wav_bytes(bytes(buffer), sample_rate=16000, channels=1, sample_width=2) + wav_bytes = pcm_to_wav_bytes( + bytes(buffer), sample_rate=16000, channels=1, sample_width=2 + ) # Call speaker recognition service result = await self.speaker_client.identify_segment( @@ -460,7 +499,7 @@ async def publish_to_client( "words": result.get("words", []), "segments": result.get("segments", []), "confidence": result.get("confidence", 0.0), - "timestamp": time.time() + "timestamp": time.time(), } # Include speaker info on final results @@ -472,12 +511,18 @@ async def publish_to_client( await self.redis_client.publish(channel, json.dumps(message)) result_type = "FINAL" if is_final else "interim" - logger.debug(f"Published {result_type} result to {channel}: {message['text'][:50]}...") + logger.debug( + f"Published {result_type} result to {channel}: {message['text'][:50]}..." + ) except Exception as e: - logger.error(f"Error publishing to client for {session_id}: {e}", exc_info=True) + logger.error( + f"Error publishing to client for {session_id}: {e}", exc_info=True + ) - async def store_final_result(self, session_id: str, result: Dict, chunk_id: str = None): + async def store_final_result( + self, session_id: str, result: Dict, chunk_id: str = None + ): """ Store final transcription result to Redis Stream. @@ -512,10 +557,14 @@ async def store_final_result(self, session_id: str, result: Dict, chunk_id: str # Write to Redis Stream await self.redis_client.xadd(stream_name, entry) - logger.info(f"Stored final result to {stream_name}: {result.get('text', '')[:50]}... ({len(words)} words)") + logger.info( + f"Stored final result to {stream_name}: {result.get('text', '')[:50]}... 
({len(words)} words)" + ) except Exception as e: - logger.error(f"Error storing final result for {session_id}: {e}", exc_info=True) + logger.error( + f"Error storing final result for {session_id}: {e}", exc_info=True + ) async def _get_user_id_from_client_id(self, client_id: str) -> Optional[str]: """ @@ -582,17 +631,17 @@ async def trigger_plugins( # Don't block plugins on lookup failure plugin_data = { - 'transcript': result.get("text", ""), - 'session_id': session_id, - 'words': result.get("words", []), - 'segments': result.get("segments", []), - 'confidence': result.get("confidence", 0.0), - 'is_final': True, + "transcript": result.get("text", ""), + "session_id": session_id, + "words": result.get("words", []), + "segments": result.get("segments", []), + "confidence": result.get("confidence", 0.0), + "is_final": True, } # Include speaker info if available if speaker_name: - plugin_data['speaker_name'] = speaker_name + plugin_data["speaker_name"] = speaker_name # Dispatch transcript.streaming event logger.info( @@ -604,16 +653,20 @@ async def trigger_plugins( event=PluginEvent.TRANSCRIPT_STREAMING, user_id=user_id, data=plugin_data, - metadata={'client_id': session_id} + metadata={"client_id": session_id}, ) if plugin_results: - logger.info(f"Plugins triggered successfully: {len(plugin_results)} results") + logger.info( + f"Plugins triggered successfully: {len(plugin_results)} results" + ) else: logger.info(f"No plugins triggered (no matching conditions)") except Exception as e: - logger.error(f"Error triggering plugins for {session_id}: {e}", exc_info=True) + logger.error( + f"Error triggering plugins for {session_id}: {e}", exc_info=True + ) async def process_stream(self, stream_name: str): """ @@ -628,7 +681,7 @@ async def process_stream(self, stream_name: str): # Track this stream self.active_streams[stream_name] = { "session_id": session_id, - "started_at": time.time() + "started_at": time.time(), } # Read actual sample rate from the session's audio_format stored in Redis @@ -639,9 +692,13 @@ async def process_stream(self, stream_name: str): if audio_format_raw: audio_format = json.loads(audio_format_raw) sample_rate = int(audio_format.get("rate", 16000)) - logger.info(f"Read sample rate {sample_rate}Hz from session {session_id}") + logger.info( + f"Read sample rate {sample_rate}Hz from session {session_id}" + ) except Exception as e: - logger.warning(f"Failed to read audio_format from Redis for {session_id}: {e}") + logger.warning( + f"Failed to read audio_format from Redis for {session_id}: {e}" + ) # Start WebSocket connection to transcription provider await self.start_session_stream(session_id, sample_rate=sample_rate) @@ -658,44 +715,62 @@ async def process_stream(self, stream_name: str): self.consumer_name, # "streaming-worker-{pid}" {stream_name: ">"}, # Read only new messages count=10, - block=1000 # Block for 1 second + block=1000, # Block for 1 second ) if not messages: # No new messages - check if stream is still alive if session_id not in self.active_sessions: - logger.info(f"Session {session_id} no longer active, ending stream processing") + logger.info( + f"Session {session_id} no longer active, ending stream processing" + ) stream_ended = True continue for stream, stream_messages in messages: - logger.debug(f"Read {len(stream_messages)} messages from {stream_name}") + logger.debug( + f"Read {len(stream_messages)} messages from {stream_name}" + ) for message_id, fields in stream_messages: - msg_id = message_id.decode() if isinstance(message_id, bytes) else message_id 
+ msg_id = ( + message_id.decode() + if isinstance(message_id, bytes) + else message_id + ) # Check for end marker - if fields.get(b'end_marker') or fields.get('end_marker'): + if fields.get(b"end_marker") or fields.get("end_marker"): logger.info(f"End marker received for {session_id}") stream_ended = True # ACK the end marker - await self.redis_client.xack(stream_name, self.group_name, msg_id) + await self.redis_client.xack( + stream_name, self.group_name, msg_id + ) break # Extract audio data (producer sends as 'audio_data', not 'audio_chunk') - audio_chunk = fields.get(b'audio_data') or fields.get('audio_data') + audio_chunk = fields.get(b"audio_data") or fields.get( + "audio_data" + ) if audio_chunk: - logger.debug(f"Processing audio chunk {msg_id} ({len(audio_chunk)} bytes)") + logger.debug( + f"Processing audio chunk {msg_id} ({len(audio_chunk)} bytes)" + ) # Process audio chunk through streaming provider await self.process_audio_chunk( session_id=session_id, audio_chunk=audio_chunk, - chunk_id=msg_id + chunk_id=msg_id, ) else: - logger.warning(f"Message {msg_id} has no audio_data field") + logger.warning( + f"Message {msg_id} has no audio_data field" + ) # ACK the message after processing - await self.redis_client.xack(stream_name, self.group_name, msg_id) + await self.redis_client.xack( + stream_name, self.group_name, msg_id + ) if stream_ended: break @@ -703,14 +778,21 @@ async def process_stream(self, stream_name: str): except redis_exceptions.ResponseError as e: if "NOGROUP" in str(e): # Stream has expired or been deleted - exit gracefully - logger.info(f"Stream {stream_name} expired or deleted, ending processing") + logger.info( + f"Stream {stream_name} expired or deleted, ending processing" + ) stream_ended = True break else: - logger.error(f"Redis error reading from stream {stream_name}: {e}", exc_info=True) + logger.error( + f"Redis error reading from stream {stream_name}: {e}", + exc_info=True, + ) await asyncio.sleep(1) except Exception as e: - logger.error(f"Error reading from stream {stream_name}: {e}", exc_info=True) + logger.error( + f"Error reading from stream {stream_name}: {e}", exc_info=True + ) await asyncio.sleep(1) finally: @@ -727,7 +809,9 @@ async def process_stream(self, stream_name: str): try: await self._try_delete_finished_stream(stream_name) except Exception as e: - logger.debug(f"Stream cleanup check failed for {stream_name} (non-fatal): {e}") + logger.debug( + f"Stream cleanup check failed for {stream_name} (non-fatal): {e}" + ) async def _try_delete_finished_stream(self, stream_name: str): """ @@ -743,7 +827,7 @@ async def _try_delete_finished_stream(self, stream_name: str): if not await self.redis_client.exists(stream_name): return - groups = await self.redis_client.execute_command('XINFO', 'GROUPS', stream_name) + groups = await self.redis_client.execute_command("XINFO", "GROUPS", stream_name) if not groups: return @@ -753,7 +837,9 @@ async def _try_delete_finished_stream(self, stream_name: str): for group in groups: group_dict = {} for i in range(0, len(group), 2): - key = group[i].decode() if isinstance(group[i], bytes) else str(group[i]) + key = ( + group[i].decode() if isinstance(group[i], bytes) else str(group[i]) + ) value = group[i + 1] if isinstance(value, bytes): try: @@ -818,7 +904,9 @@ async def start_consuming(self): session_id = stream_name.replace("audio:stream:", "") completion_key = f"transcription:complete:{session_id}" if await self.redis_client.exists(completion_key): - logger.debug(f"Stream {stream_name} already completed, 
skipping") + logger.debug( + f"Stream {stream_name} already completed, skipping" + ) continue # Setup consumer group (no manual lock needed) @@ -829,7 +917,9 @@ async def start_consuming(self): # Spawn task to process this stream asyncio.create_task(self.process_stream(stream_name)) - logger.info(f"Now consuming from {stream_name} (group: {self.group_name})") + logger.info( + f"Now consuming from {stream_name} (group: {self.group_name})" + ) # Sleep before next discovery cycle (1s for fast discovery) await asyncio.sleep(1) diff --git a/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py b/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py index 98ab5ce5..cb77e630 100644 --- a/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py +++ b/backends/advanced/src/advanced_omi_backend/speaker_recognition_client.py @@ -71,7 +71,9 @@ def __init__(self, service_url: Optional[str] = None): # Disabled in config self.enabled = False self.service_url = None - logger.info("Speaker recognition client disabled (config.yml enabled=false)") + logger.info( + "Speaker recognition client disabled (config.yml enabled=false)" + ) return # Enabled - determine URL (priority: param > config > env var) @@ -83,9 +85,13 @@ def __init__(self, service_url: Optional[str] = None): self.enabled = bool(self.service_url) if self.enabled: - logger.info(f"Speaker recognition client initialized with URL: {self.service_url}") + logger.info( + f"Speaker recognition client initialized with URL: {self.service_url}" + ) else: - logger.info("Speaker recognition client disabled (no service URL configured)") + logger.info( + "Speaker recognition client disabled (no service URL configured)" + ) def calculate_timeout(self, audio_duration: Optional[float]) -> float: """ @@ -100,7 +106,9 @@ def calculate_timeout(self, audio_duration: Optional[float]) -> float: Calculated timeout in seconds """ BASE_TIMEOUT = 30.0 # Minimum timeout for short files - TIMEOUT_MULTIPLIER = 8.0 # Processing speed ratio (e.g., 1 min audio = 8 min timeout) + TIMEOUT_MULTIPLIER = ( + 8.0 # Processing speed ratio (e.g., 1 min audio = 8 min timeout) + ) MAX_TIMEOUT = 600.0 # 10 minute cap for very long files if audio_duration is None or audio_duration <= 0: @@ -121,7 +129,7 @@ async def diarize_identify_match( conversation_id: str, backend_token: str, transcript_data: Dict, - user_id: Optional[str] = None + user_id: Optional[str] = None, ) -> Dict: """ Perform diarization, speaker identification, and word-to-speaker matching. 
@@ -139,7 +147,7 @@ async def diarize_identify_match( Dictionary containing segments with matched text and speaker identification """ # Use mock client if configured - if hasattr(self, '_mock_client'): + if hasattr(self, "_mock_client"): return await self._mock_client.diarize_identify_match( conversation_id, backend_token, transcript_data, user_id ) @@ -150,17 +158,23 @@ async def diarize_identify_match( # Fetch conversation to get audio duration for timeout calculation from advanced_omi_backend.models.conversation import Conversation - conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + + conversation = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) audio_duration = conversation.audio_total_duration if conversation else None # Calculate proportional timeout based on audio duration timeout = self.calculate_timeout(audio_duration) try: - logger.info(f"🎀 Calling speaker service with conversation_id: {conversation_id[:12]}...") + logger.info( + f"🎀 Calling speaker service with conversation_id: {conversation_id[:12]}..." + ) # Read diarization source from config system from advanced_omi_backend.config import get_diarization_settings + config = get_diarization_settings() diarization_source = config.get("diarization_source", "pyannote") @@ -173,51 +187,81 @@ async def diarize_identify_match( if diarization_source == "deepgram": # DEEPGRAM DIARIZATION PATH: We EXPECT transcript has speaker info from Deepgram # Only need speaker identification of existing segments - logger.info("Using Deepgram diarization path - transcript should have speaker segments, identifying speakers") + logger.info( + "Using Deepgram diarization path - transcript should have speaker segments, identifying speakers" + ) # TODO: Implement proper speaker identification for Deepgram segments # For now, use diarize-identify-match as fallback until we implement segment identification - logger.warning("Deepgram segment identification not yet implemented, using diarize-identify-match as fallback") + logger.warning( + "Deepgram segment identification not yet implemented, using diarize-identify-match as fallback" + ) form_data.add_field("transcript_data", json.dumps(transcript_data)) - form_data.add_field("user_id", "1") # TODO: Implement proper user mapping - form_data.add_field("similarity_threshold", str(config.get("similarity_threshold", 0.45))) - form_data.add_field("min_duration", str(config.get("min_duration", 0.5))) + form_data.add_field( + "user_id", "1" + ) # TODO: Implement proper user mapping + form_data.add_field( + "similarity_threshold", + str(config.get("similarity_threshold", 0.45)), + ) + form_data.add_field( + "min_duration", str(config.get("min_duration", 0.5)) + ) # Use /v1/diarize-identify-match endpoint as fallback endpoint = "/v1/diarize-identify-match" else: # pyannote (default) # PYANNOTE PATH: Backend has transcript, need diarization + speaker identification - logger.info("Using Pyannote path - diarizing backend transcript and identifying speakers") + logger.info( + "Using Pyannote path - diarizing backend transcript and identifying speakers" + ) # Send existing transcript for diarization and speaker matching form_data.add_field("transcript_data", json.dumps(transcript_data)) - form_data.add_field("user_id", "1") # TODO: Implement proper user mapping - form_data.add_field("similarity_threshold", str(config.get("similarity_threshold", 0.45))) + form_data.add_field( + "user_id", "1" + ) # TODO: Implement proper user mapping + 
form_data.add_field( + "similarity_threshold", + str(config.get("similarity_threshold", 0.45)), + ) # Add pyannote diarization parameters - form_data.add_field("min_duration", str(config.get("min_duration", 0.5))) + form_data.add_field( + "min_duration", str(config.get("min_duration", 0.5)) + ) form_data.add_field("collar", str(config.get("collar", 2.0))) - form_data.add_field("min_duration_off", str(config.get("min_duration_off", 1.5))) + form_data.add_field( + "min_duration_off", str(config.get("min_duration_off", 1.5)) + ) if config.get("min_speakers"): - form_data.add_field("min_speakers", str(config.get("min_speakers"))) + form_data.add_field( + "min_speakers", str(config.get("min_speakers")) + ) if config.get("max_speakers"): - form_data.add_field("max_speakers", str(config.get("max_speakers"))) + form_data.add_field( + "max_speakers", str(config.get("max_speakers")) + ) # Use /v1/diarize-identify-match endpoint for backend integration endpoint = "/v1/diarize-identify-match" # Make the request to the consolidated endpoint request_url = f"{self.service_url}{endpoint}" - logger.info(f"🎀 DEBUG: Making request to speaker service URL: {request_url}") + logger.info( + f"🎀 DEBUG: Making request to speaker service URL: {request_url}" + ) async with session.post( request_url, data=form_data, timeout=aiohttp.ClientTimeout(total=timeout), ) as response: - logger.info(f"🎀 Speaker service response status: {response.status}") + logger.info( + f"🎀 Speaker service response status: {response.status}" + ) if response.status != 200: response_text = await response.text() @@ -230,7 +274,9 @@ async def diarize_identify_match( # Log basic result info num_segments = len(result.get("segments", [])) - logger.info(f"🎀 Speaker recognition returned {num_segments} segments") + logger.info( + f"🎀 Speaker recognition returned {num_segments} segments" + ) return result @@ -270,20 +316,30 @@ async def identify_segment( ) if not self.enabled: - return {"found": False, "speaker_name": None, "confidence": 0.0, "status": "unknown"} + return { + "found": False, + "speaker_name": None, + "confidence": 0.0, + "status": "unknown", + } try: async with aiohttp.ClientSession() as session: form_data = aiohttp.FormData() form_data.add_field( - "file", audio_wav_bytes, filename="segment.wav", content_type="audio/wav" + "file", + audio_wav_bytes, + filename="segment.wav", + content_type="audio/wav", ) # TODO: Implement proper user mapping between MongoDB ObjectIds and speaker service integer IDs # Speaker service expects integer user_id, not MongoDB ObjectId strings if user_id is not None: form_data.add_field("user_id", "1") if similarity_threshold is not None: - form_data.add_field("similarity_threshold", str(similarity_threshold)) + form_data.add_field( + "similarity_threshold", str(similarity_threshold) + ) async with session.post( f"{self.service_url}/identify", @@ -292,23 +348,50 @@ async def identify_segment( ) as response: if response.status != 200: response_text = await response.text() - logger.warning(f"🎀 /identify returned status {response.status}: {response_text}") - return {"found": False, "speaker_name": None, "confidence": 0.0, "status": "error"} + logger.warning( + f"🎀 /identify returned status {response.status}: {response_text}" + ) + return { + "found": False, + "speaker_name": None, + "confidence": 0.0, + "status": "error", + } return await response.json() except ClientConnectorError as e: logger.error(f"🎀 Failed to connect to speaker service /identify: {e}") - return {"found": False, "speaker_name": None, 
"confidence": 0.0, "status": "error"} + return { + "found": False, + "speaker_name": None, + "confidence": 0.0, + "status": "error", + } except asyncio.TimeoutError: logger.error("🎀 Timeout calling speaker service /identify") - return {"found": False, "speaker_name": None, "confidence": 0.0, "status": "error"} + return { + "found": False, + "speaker_name": None, + "confidence": 0.0, + "status": "error", + } except aiohttp.ClientError as e: logger.warning(f"🎀 Client error during /identify: {e}") - return {"found": False, "speaker_name": None, "confidence": 0.0, "status": "error"} + return { + "found": False, + "speaker_name": None, + "confidence": 0.0, + "status": "error", + } except Exception as e: logger.error(f"🎀 Error during /identify: {e}") - return {"found": False, "speaker_name": None, "confidence": 0.0, "status": "error"} + return { + "found": False, + "speaker_name": None, + "confidence": 0.0, + "status": "error", + } async def identify_provider_segments( self, @@ -339,8 +422,11 @@ async def identify_provider_segments( """ if hasattr(self, "_mock_client"): return await self._mock_client.identify_provider_segments( - conversation_id, segments, user_id, - per_segment=per_segment, min_segment_duration=min_segment_duration, + conversation_id, + segments, + user_id, + per_segment=per_segment, + min_segment_duration=min_segment_duration, ) if not self.enabled: @@ -400,11 +486,15 @@ def _is_non_speech(seg: Dict) -> bool: # For each label, pick top N longest segments >= min_segment_duration label_samples: Dict[str, List[Dict]] = {} for label, segs in label_groups.items(): - eligible = [s for s in segs if (s["end"] - s["start"]) >= min_segment_duration] + eligible = [ + s for s in segs if (s["end"] - s["start"]) >= min_segment_duration + ] eligible.sort(key=lambda s: s["end"] - s["start"], reverse=True) label_samples[label] = eligible[:MAX_SAMPLES_PER_LABEL] if not label_samples[label]: - logger.info(f"🎀 Label '{label}': no segments >= {min_segment_duration}s, skipping identification") + logger.info( + f"🎀 Label '{label}': no segments >= {min_segment_duration}s, skipping identification" + ) # Extract audio and identify concurrently with semaphore semaphore = asyncio.Semaphore(3) @@ -416,11 +506,15 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: conversation_id, seg["start"], seg["end"] ) result = await self.identify_segment( - wav_bytes, user_id=user_id, similarity_threshold=similarity_threshold + wav_bytes, + user_id=user_id, + similarity_threshold=similarity_threshold, ) return result except Exception as e: - logger.warning(f"🎀 Failed to identify segment [{seg['start']:.1f}-{seg['end']:.1f}]: {e}") + logger.warning( + f"🎀 Failed to identify segment [{seg['start']:.1f}-{seg['end']:.1f}]: {e}" + ) return None # Collect identification tasks @@ -456,7 +550,10 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: # Pick name with most votes, break ties by average confidence best_name = max( name_votes.keys(), - key=lambda n: (len(name_votes[n]), sum(name_votes[n]) / len(name_votes[n])), + key=lambda n: ( + len(name_votes[n]), + sum(name_votes[n]) / len(name_votes[n]), + ), ) avg_confidence = sum(name_votes[best_name]) / len(name_votes[best_name]) label_mapping[label] = (best_name, avg_confidence) @@ -465,7 +562,9 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: f"({len(name_votes[best_name])}/{len(tasks)} votes, conf={avg_confidence:.3f})" ) else: - logger.info(f"🎀 Label '{label}' -> no identification (keeping original)") + logger.info( + f"🎀 Label '{label}' -> no 
identification (keeping original)" + ) # Build result segments in same format as diarize_identify_match() # Non-speech segments are kept but not speaker-identified @@ -473,26 +572,30 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: for i, seg in enumerate(segments): label = seg.get("speaker", "Unknown") if i in non_speech_indices: - result_segments.append({ - "start": seg["start"], - "end": seg["end"], - "text": seg.get("text", ""), - "speaker": label, - "identified_as": label, - "confidence": 0.0, - "status": "non_speech", - }) + result_segments.append( + { + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": label, + "confidence": 0.0, + "status": "non_speech", + } + ) else: mapped = label_mapping.get(label) - result_segments.append({ - "start": seg["start"], - "end": seg["end"], - "text": seg.get("text", ""), - "speaker": label, - "identified_as": mapped[0] if mapped else None, - "confidence": mapped[1] if mapped else 0.0, - "status": "identified" if mapped else "unknown", - }) + result_segments.append( + { + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": mapped[0] if mapped else None, + "confidence": mapped[1] if mapped else 0.0, + "status": "identified" if mapped else "unknown", + } + ) identified_count = sum(1 for m in label_mapping.values() if m) logger.info( @@ -548,7 +651,9 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: conversation_id, seg["start"], seg["end"] ) return await self.identify_segment( - wav_bytes, user_id=user_id, similarity_threshold=similarity_threshold + wav_bytes, + user_id=user_id, + similarity_threshold=similarity_threshold, ) except Exception as e: logger.warning( @@ -582,15 +687,17 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: label = seg.get("speaker", "Unknown") if i in non_speech_indices: - result_segments.append({ - "start": seg["start"], - "end": seg["end"], - "text": seg.get("text", ""), - "speaker": label, - "identified_as": label, - "confidence": 0.0, - "status": "non_speech", - }) + result_segments.append( + { + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": label, + "confidence": 0.0, + "status": "non_speech", + } + ) continue # Find the matching task entry @@ -599,15 +706,17 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: if task is None: # Too short for identification - result_segments.append({ - "start": seg["start"], - "end": seg["end"], - "text": seg.get("text", ""), - "speaker": label, - "identified_as": None, - "confidence": 0.0, - "status": "too_short", - }) + result_segments.append( + { + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": None, + "confidence": 0.0, + "status": "too_short", + } + ) continue try: @@ -618,52 +727,60 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: # None result means _identify_one raised an exception (audio reconstruction or service call) if result is None: error_count += 1 - result_segments.append({ - "start": seg["start"], - "end": seg["end"], - "text": seg.get("text", ""), - "speaker": label, - "identified_as": None, - "confidence": 0.0, - "status": "error", - }) + result_segments.append( + { + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": None, + "confidence": 0.0, + "status": "error", + } + ) continue if result.get("found"): name = 
result.get("speaker_name", label) confidence = result.get("confidence", 0.0) - result_segments.append({ - "start": seg["start"], - "end": seg["end"], - "text": seg.get("text", ""), - "speaker": label, - "identified_as": name, - "confidence": confidence, - "status": "identified", - }) + result_segments.append( + { + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": name, + "confidence": confidence, + "status": "identified", + } + ) identified_count += 1 elif result and result.get("status") == "error": # Speaker service returned an error (500, timeout, etc.) error_count += 1 - result_segments.append({ - "start": seg["start"], - "end": seg["end"], - "text": seg.get("text", ""), - "speaker": label, - "identified_as": None, - "confidence": 0.0, - "status": "error", - }) + result_segments.append( + { + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": None, + "confidence": 0.0, + "status": "error", + } + ) else: - result_segments.append({ - "start": seg["start"], - "end": seg["end"], - "text": seg.get("text", ""), - "speaker": label, - "identified_as": None, - "confidence": 0.0, - "status": "unknown", - }) + result_segments.append( + { + "start": seg["start"], + "end": seg["end"], + "text": seg.get("text", ""), + "speaker": label, + "identified_as": None, + "confidence": 0.0, + "status": "unknown", + } + ) logger.info( f"🎀 Per-segment identification complete: " @@ -687,7 +804,10 @@ async def _identify_one(seg: Dict) -> Optional[Dict]: return result async def diarize_and_identify( - self, audio_data: bytes, words: None, user_id: Optional[str] = None # NOT IMPLEMENTED + self, + audio_data: bytes, + words: None, + user_id: Optional[str] = None, # NOT IMPLEMENTED ) -> Dict: """ Perform diarization and speaker identification using the speaker recognition service. 
@@ -715,7 +835,9 @@ async def diarize_and_identify( # Estimate audio duration from data size (assuming 16kHz, 16-bit PCM) # WAV header is typically 44 bytes - estimated_duration = (len(audio_data) - 44) / 32000 # 16000 Hz * 2 bytes per sample + estimated_duration = ( + len(audio_data) - 44 + ) / 32000 # 16000 Hz * 2 bytes per sample timeout = self.calculate_timeout(estimated_duration) # Call the speaker recognition service @@ -733,7 +855,9 @@ async def diarize_and_identify( # Add all diarization parameters for the diarize-and-identify endpoint min_duration = diarization_settings.get("min_duration", 0.5) - similarity_threshold = diarization_settings.get("similarity_threshold", 0.45) + similarity_threshold = diarization_settings.get( + "similarity_threshold", 0.45 + ) collar = diarization_settings.get("collar", 2.0) min_duration_off = diarization_settings.get("min_duration_off", 1.5) @@ -743,9 +867,13 @@ async def diarize_and_identify( form_data.add_field("min_duration_off", str(min_duration_off)) if diarization_settings.get("min_speakers"): - form_data.add_field("min_speakers", str(diarization_settings["min_speakers"])) + form_data.add_field( + "min_speakers", str(diarization_settings["min_speakers"]) + ) if diarization_settings.get("max_speakers"): - form_data.add_field("max_speakers", str(diarization_settings["max_speakers"])) + form_data.add_field( + "max_speakers", str(diarization_settings["max_speakers"]) + ) form_data.add_field("identify_only_enrolled", "false") # TODO: Implement proper user mapping between MongoDB ObjectIds and speaker service integer IDs @@ -776,40 +904,57 @@ async def diarize_and_identify( return {"segments": []} result = await response.json() - segments_count = len(result.get('segments', [])) - logger.info(f"🎀 [DIARIZE] βœ… Speaker service returned {segments_count} segments") + segments_count = len(result.get("segments", [])) + logger.info( + f"🎀 [DIARIZE] βœ… Speaker service returned {segments_count} segments" + ) # Log details about identified speakers if segments_count > 0: identified_names = set() - for seg in result.get('segments', []): - identified_as = seg.get('identified_as') - if identified_as and identified_as != 'Unknown': + for seg in result.get("segments", []): + identified_as = seg.get("identified_as") + if identified_as and identified_as != "Unknown": identified_names.add(identified_as) if identified_names: - logger.info(f"🎀 [DIARIZE] Identified speakers in segments: {identified_names}") + logger.info( + f"🎀 [DIARIZE] Identified speakers in segments: {identified_names}" + ) else: - logger.warning(f"🎀 [DIARIZE] No identified speakers found in {segments_count} segments") + logger.warning( + f"🎀 [DIARIZE] No identified speakers found in {segments_count} segments" + ) return result except ClientConnectorError as e: - logger.error(f"🎀 [DIARIZE] ❌ Failed to connect to speaker recognition service at {self.service_url}: {e}") + logger.error( + f"🎀 [DIARIZE] ❌ Failed to connect to speaker recognition service at {self.service_url}: {e}" + ) return {"error": "connection_failed", "message": str(e), "segments": []} except asyncio.TimeoutError as e: - logger.error(f"🎀 [DIARIZE] ❌ Timeout connecting to speaker recognition service: {e}") + logger.error( + f"🎀 [DIARIZE] ❌ Timeout connecting to speaker recognition service: {e}" + ) return {"error": "timeout", "message": str(e), "segments": []} except aiohttp.ClientError as e: - logger.warning(f"🎀 [DIARIZE] ❌ Client error during speaker recognition: {e}") + logger.warning( + f"🎀 [DIARIZE] ❌ Client error during speaker 
recognition: {e}" + ) return {"error": "client_error", "message": str(e), "segments": []} except Exception as e: - logger.error(f"🎀 [DIARIZE] ❌ Error during speaker diarization and identification: {e}") + logger.error( + f"🎀 [DIARIZE] ❌ Error during speaker diarization and identification: {e}" + ) import traceback + logger.debug(traceback.format_exc()) return {"error": "unknown_error", "message": str(e), "segments": []} - async def identify_speakers(self, audio_path: str, segments: List[Dict]) -> Dict[str, str]: + async def identify_speakers( + self, audio_path: str, segments: List[Dict] + ) -> Dict[str, str]: """ Identify speakers in audio segments using the speaker recognition service. @@ -835,11 +980,14 @@ async def identify_speakers(self, audio_path: str, segments: List[Dict]) -> Dict # Get audio duration for timeout calculation import wave + try: with wave.open(audio_path, "rb") as wav_file: frame_count = wav_file.getnframes() sample_rate = wav_file.getframerate() - audio_duration = frame_count / sample_rate if sample_rate > 0 else None + audio_duration = ( + frame_count / sample_rate if sample_rate > 0 else None + ) except Exception as e: logger.warning(f"Failed to get audio duration from {audio_path}: {e}") audio_duration = None @@ -853,7 +1001,10 @@ async def identify_speakers(self, audio_path: str, segments: List[Dict]) -> Dict with open(audio_path, "rb") as audio_file: form_data = aiohttp.FormData() form_data.add_field( - "file", audio_file, filename=Path(audio_path).name, content_type="audio/wav" + "file", + audio_file, + filename=Path(audio_path).name, + content_type="audio/wav", ) # Get current diarization settings from advanced_omi_backend.config import get_diarization_settings @@ -861,14 +1012,29 @@ async def identify_speakers(self, audio_path: str, segments: List[Dict]) -> Dict _diarization_settings = get_diarization_settings() # Add all diarization parameters for the diarize-and-identify endpoint - form_data.add_field("min_duration", str(_diarization_settings.get("min_duration", 0.5))) - form_data.add_field("similarity_threshold", str(_diarization_settings.get("similarity_threshold", 0.45))) - form_data.add_field("collar", str(_diarization_settings.get("collar", 2.0))) - form_data.add_field("min_duration_off", str(_diarization_settings.get("min_duration_off", 1.5))) + form_data.add_field( + "min_duration", + str(_diarization_settings.get("min_duration", 0.5)), + ) + form_data.add_field( + "similarity_threshold", + str(_diarization_settings.get("similarity_threshold", 0.45)), + ) + form_data.add_field( + "collar", str(_diarization_settings.get("collar", 2.0)) + ) + form_data.add_field( + "min_duration_off", + str(_diarization_settings.get("min_duration_off", 1.5)), + ) if _diarization_settings.get("min_speakers"): - form_data.add_field("min_speakers", str(_diarization_settings["min_speakers"])) + form_data.add_field( + "min_speakers", str(_diarization_settings["min_speakers"]) + ) if _diarization_settings.get("max_speakers"): - form_data.add_field("max_speakers", str(_diarization_settings["max_speakers"])) + form_data.add_field( + "max_speakers", str(_diarization_settings["max_speakers"]) + ) form_data.add_field("identify_only_enrolled", "false") # Make the request @@ -886,7 +1052,9 @@ async def identify_speakers(self, audio_path: str, segments: List[Dict]) -> Dict result = await response.json() # Process the response to create speaker mapping - speaker_mapping = self._process_diarization_result(result, segments) + speaker_mapping = self._process_diarization_result( + 
result, segments + ) if speaker_mapping: logger.info(f"Speaker mapping created: {speaker_mapping}") @@ -941,14 +1109,18 @@ def _process_diarization_result( for seg in segments_for_speaker: identified_name = seg.get("identified_as") if identified_name and identified_name != "Unknown": - name_counts[identified_name] = name_counts.get(identified_name, 0) + 1 + name_counts[identified_name] = ( + name_counts.get(identified_name, 0) + 1 + ) # Assign the most common identified name, or unknown if none found if name_counts: best_name = max(name_counts.items(), key=lambda x: x[1])[0] speaker_mapping[generic_speaker] = best_name else: - speaker_mapping[generic_speaker] = f"unknown_speaker_{unknown_counter}" + speaker_mapping[generic_speaker] = ( + f"unknown_speaker_{unknown_counter}" + ) unknown_counter += 1 logger.info(f"🎀 Speaker mapping: {speaker_mapping}") @@ -978,7 +1150,9 @@ async def get_enrolled_speakers(self, user_id: Optional[str] = None) -> Dict: timeout=aiohttp.ClientTimeout(total=10), ) as response: if response.status != 200: - logger.warning(f"🎀 Failed to get enrolled speakers: status {response.status}") + logger.warning( + f"🎀 Failed to get enrolled speakers: status {response.status}" + ) return {"speakers": []} result = await response.json() @@ -993,7 +1167,9 @@ async def get_enrolled_speakers(self, user_id: Optional[str] = None) -> Dict: logger.error(f"🎀 Error getting enrolled speakers: {e}") return {"speakers": []} - async def get_speaker_by_name(self, speaker_name: str, user_id: int = 1) -> Optional[Dict]: + async def get_speaker_by_name( + self, speaker_name: str, user_id: int = 1 + ) -> Optional[Dict]: """ Look up enrolled speaker by name. @@ -1016,19 +1192,25 @@ async def get_speaker_by_name(self, speaker_name: str, user_id: int = 1) -> Opti timeout=aiohttp.ClientTimeout(total=10), ) as response: if response.status != 200: - logger.warning(f"🎀 Failed to get speakers: status {response.status}") + logger.warning( + f"🎀 Failed to get speakers: status {response.status}" + ) return None result = await response.json() speakers = result.get("speakers", []) - + # Case-insensitive name match for speaker in speakers: if speaker["name"].lower() == speaker_name.lower(): - logger.info(f"🎀 Found speaker '{speaker_name}' with ID: {speaker['id']}") + logger.info( + f"🎀 Found speaker '{speaker_name}' with ID: {speaker['id']}" + ) return speaker - - logger.info(f"🎀 Speaker '{speaker_name}' not found in {len(speakers)} enrolled speakers") + + logger.info( + f"🎀 Speaker '{speaker_name}' not found in {len(speakers)} enrolled speakers" + ) return None except aiohttp.ClientError as e: @@ -1061,8 +1243,10 @@ async def enroll_new_speaker( # Generate speaker ID: user_{user_id}_speaker_{random_hex} speaker_id = f"user_{user_id}_speaker_{uuid.uuid4().hex[:12]}" - - logger.info(f"🎀 Enrolling new speaker '{speaker_name}' with ID: {speaker_id}") + + logger.info( + f"🎀 Enrolling new speaker '{speaker_name}' with ID: {speaker_id}" + ) async with aiohttp.ClientSession() as session: form_data = aiohttp.FormData() @@ -1116,7 +1300,10 @@ async def append_to_speaker(self, speaker_id: str, audio_data: bytes) -> Dict: async with aiohttp.ClientSession() as session: form_data = aiohttp.FormData() form_data.add_field( - "files", audio_data, filename="segment.wav", content_type="audio/wav" + "files", + audio_data, + filename="segment.wav", + content_type="audio/wav", ) form_data.add_field("speaker_id", speaker_id) @@ -1149,7 +1336,7 @@ async def check_if_enrolled_speaker_present( client_id: str, session_id: str, user_id: 
str, - transcription_results: List[dict] + transcription_results: List[dict], ) -> tuple[bool, dict]: """ Check if any enrolled speakers are present in the transcription results. @@ -1173,19 +1360,29 @@ async def check_if_enrolled_speaker_present( extract_audio_for_results, ) - logger.info(f"🎀 [SPEAKER CHECK] Starting speaker check for session {session_id}") + logger.info( + f"🎀 [SPEAKER CHECK] Starting speaker check for session {session_id}" + ) logger.info(f"🎀 [SPEAKER CHECK] Client: {client_id}, User: {user_id}") - logger.info(f"🎀 [SPEAKER CHECK] Transcription results count: {len(transcription_results)}") + logger.info( + f"🎀 [SPEAKER CHECK] Transcription results count: {len(transcription_results)}" + ) # Get enrolled speakers for this user - logger.info(f"🎀 [SPEAKER CHECK] Fetching enrolled speakers for user {user_id}...") + logger.info( + f"🎀 [SPEAKER CHECK] Fetching enrolled speakers for user {user_id}..." + ) enrolled_result = await self.get_enrolled_speakers(user_id) - enrolled_speakers = set(speaker["name"] for speaker in enrolled_result.get("speakers", [])) + enrolled_speakers = set( + speaker["name"] for speaker in enrolled_result.get("speakers", []) + ) logger.info(f"🎀 [SPEAKER CHECK] Enrolled speakers: {enrolled_speakers}") if not enrolled_speakers: - logger.warning("🎀 [SPEAKER CHECK] No enrolled speakers found, allowing conversation") + logger.warning( + "🎀 [SPEAKER CHECK] No enrolled speakers found, allowing conversation" + ) return (True, {}) # If no enrolled speakers, allow all conversations # Extract audio chunks (PCM format) @@ -1194,11 +1391,13 @@ async def check_if_enrolled_speaker_present( redis_client=redis_client, client_id=client_id, session_id=session_id, - transcription_results=transcription_results + transcription_results=transcription_results, ) if not pcm_data: - logger.warning("🎀 [SPEAKER CHECK] No audio data extracted, skipping speaker check") + logger.warning( + "🎀 [SPEAKER CHECK] No audio data extracted, skipping speaker check" + ) return (False, {}) audio_size_kb = len(pcm_data) / 1024 @@ -1211,17 +1410,23 @@ async def check_if_enrolled_speaker_present( from advanced_omi_backend.utils.audio_utils import pcm_to_wav_bytes logger.info(f"🎀 [SPEAKER CHECK] Converting PCM to WAV in memory...") - wav_data = pcm_to_wav_bytes(pcm_data, sample_rate=16000, channels=1, sample_width=2) + wav_data = pcm_to_wav_bytes( + pcm_data, sample_rate=16000, channels=1, sample_width=2 + ) - logger.info(f"🎀 [SPEAKER CHECK] WAV created in memory: {len(wav_data) / 1024 / 1024:.2f} MB") + logger.info( + f"🎀 [SPEAKER CHECK] WAV created in memory: {len(wav_data) / 1024 / 1024:.2f} MB" + ) try: # Run speaker recognition (diarize and identify) with in-memory audio - logger.info(f"🎀 [SPEAKER CHECK] Calling diarize_and_identify with in-memory audio...") + logger.info( + f"🎀 [SPEAKER CHECK] Calling diarize_and_identify with in-memory audio..." + ) result = await self.diarize_and_identify( audio_data=wav_data, # Pass bytes directly, no temp file! 
words=None, - user_id=user_id + user_id=user_id, ) logger.info(f"🎀 [SPEAKER CHECK] Speaker recognition result: {result}") @@ -1229,7 +1434,9 @@ async def check_if_enrolled_speaker_present( # Check if any identified speakers are enrolled identified_speakers = set() segments_count = len(result.get("segments", [])) - logger.info(f"🎀 [SPEAKER CHECK] Processing {segments_count} segments from speaker recognition") + logger.info( + f"🎀 [SPEAKER CHECK] Processing {segments_count} segments from speaker recognition" + ) for idx, segment in enumerate(result.get("segments", [])): identified_name = segment.get("identified_as") @@ -1245,25 +1452,40 @@ async def check_if_enrolled_speaker_present( if identified_name and identified_name != "Unknown": identified_speakers.add(identified_name) - logger.info(f"🎀 [SPEAKER CHECK] Found identified speaker: {identified_name}") + logger.info( + f"🎀 [SPEAKER CHECK] Found identified speaker: {identified_name}" + ) - logger.info(f"🎀 [SPEAKER CHECK] All identified speakers: {identified_speakers}") + logger.info( + f"🎀 [SPEAKER CHECK] All identified speakers: {identified_speakers}" + ) logger.info(f"🎀 [SPEAKER CHECK] Enrolled speakers: {enrolled_speakers}") matches = enrolled_speakers & identified_speakers if matches: - logger.info(f"🎀 [SPEAKER CHECK] βœ… MATCH! Enrolled speaker(s) detected: {matches}") - return (True, result) # Return both boolean and speaker recognition results + logger.info( + f"🎀 [SPEAKER CHECK] βœ… MATCH! Enrolled speaker(s) detected: {matches}" + ) + return ( + True, + result, + ) # Return both boolean and speaker recognition results else: logger.info( f"🎀 [SPEAKER CHECK] ❌ NO MATCH. " f"Identified: {identified_speakers}, Enrolled: {enrolled_speakers}" ) - return (False, result) # Return both boolean and speaker recognition results + return ( + False, + result, + ) # Return both boolean and speaker recognition results except Exception as e: - logger.error(f"🎀 [SPEAKER CHECK] ❌ Speaker recognition check failed: {e}", exc_info=True) + logger.error( + f"🎀 [SPEAKER CHECK] ❌ Speaker recognition check failed: {e}", + exc_info=True, + ) return (False, {}) # Fail closed - don't create conversation on error async def health_check(self) -> bool: @@ -1277,7 +1499,9 @@ async def health_check(self) -> bool: return False try: - logger.debug(f"Performing health check on speaker service: {self.service_url}") + logger.debug( + f"Performing health check on speaker service: {self.service_url}" + ) async with aiohttp.ClientSession() as session: # Use the /health endpoint if available, otherwise try a simple endpoint @@ -1290,12 +1514,18 @@ async def health_check(self) -> bool: timeout=aiohttp.ClientTimeout(total=5), ) as response: if response.status == 200: - logger.debug(f"Speaker service health check passed via {endpoint}") + logger.debug( + f"Speaker service health check passed via {endpoint}" + ) return True else: - logger.debug(f"Health check endpoint {endpoint} returned {response.status}") + logger.debug( + f"Health check endpoint {endpoint} returned {response.status}" + ) except Exception as endpoint_error: - logger.debug(f"Health check failed for {endpoint}: {endpoint_error}") + logger.debug( + f"Health check failed for {endpoint}: {endpoint_error}" + ) continue logger.warning("All health check endpoints failed") diff --git a/backends/advanced/src/advanced_omi_backend/task_manager.py b/backends/advanced/src/advanced_omi_backend/task_manager.py index b93a397d..2fe72c6c 100644 --- a/backends/advanced/src/advanced_omi_backend/task_manager.py +++ 
b/backends/advanced/src/advanced_omi_backend/task_manager.py @@ -68,7 +68,9 @@ async def shutdown(self): if active_tasks: tasks = [info.task for info in active_tasks] try: - await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=30.0) + await asyncio.wait_for( + asyncio.gather(*tasks, return_exceptions=True), timeout=30.0 + ) except asyncio.TimeoutError: logger.warning("Some tasks did not complete within shutdown timeout") @@ -172,7 +174,9 @@ async def _periodic_cleanup(self): long_running.append(f"{task_info.name} ({age:.0f}s)") if long_running: - logger.info(f"Long-running tasks: {', '.join(long_running[:5])}") + logger.info( + f"Long-running tasks: {', '.join(long_running[:5])}" + ) except Exception as e: logger.error(f"Error in periodic cleanup: {e}", exc_info=True) @@ -237,7 +241,9 @@ async def cancel_tasks_for_client(self, client_id: str, timeout: float = 30.0): for task_info in client_tasks: task_type = task_info.metadata.get("type", "") # Check if this is a processing task that should continue - is_processing_task = any(task_type.startswith(pt) for pt in PROCESSING_TASK_TYPES) + is_processing_task = any( + task_type.startswith(pt) for pt in PROCESSING_TASK_TYPES + ) if is_processing_task: tasks_to_preserve.append(task_info) @@ -278,7 +284,9 @@ async def cancel_tasks_for_client(self, client_id: str, timeout: float = 30.0): asyncio.gather(*tasks, return_exceptions=True), timeout=timeout ) except asyncio.TimeoutError: - logger.warning(f"Some tasks for client {client_id} did not complete within timeout") + logger.warning( + f"Some tasks for client {client_id} did not complete within timeout" + ) def get_health_status(self) -> Dict[str, Any]: """Get health status of the task manager.""" @@ -333,5 +341,7 @@ def init_task_manager() -> BackgroundTaskManager: def get_task_manager() -> BackgroundTaskManager: """Get the global task manager instance.""" if _task_manager is None: - raise RuntimeError("BackgroundTaskManager not initialized. Call init_task_manager first.") + raise RuntimeError( + "BackgroundTaskManager not initialized. Call init_task_manager first." + ) return _task_manager diff --git a/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py b/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py index 0e9f4cae..15e968d2 100644 --- a/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py +++ b/backends/advanced/src/advanced_omi_backend/testing/mock_speaker_client.py @@ -33,7 +33,7 @@ class MockSpeakerRecognitionClient: "speaker": 0, "identified_as": "Unknown", "text": "The pumpkin that'll last for forever. Finally. Does it count? Today, we're taking a glass blowing class.", - "confidence": 0.95 + "confidence": 0.95, }, { "start": 10.28, @@ -41,7 +41,7 @@ class MockSpeakerRecognitionClient: "speaker": 0, "identified_as": "Unknown", "text": "I'm sweating already. 
We've worked with a lot of materials before, but we've only scratched the surface", - "confidence": 0.93 + "confidence": 0.93, }, { "start": 20.455, @@ -49,7 +49,7 @@ class MockSpeakerRecognitionClient: "speaker": 1, "identified_as": "Unknown", "text": "when it comes to glass", - "confidence": 0.91 + "confidence": 0.91, }, { "start": 22.095, @@ -57,7 +57,7 @@ class MockSpeakerRecognitionClient: "speaker": 0, "identified_as": "Unknown", "text": "and that's because", - "confidence": 0.94 + "confidence": 0.94, }, { "start": 23.815, @@ -65,7 +65,7 @@ class MockSpeakerRecognitionClient: "speaker": 1, "identified_as": "Unknown", "text": "a little intimidating. We've got about 400 pounds", - "confidence": 0.92 + "confidence": 0.92, }, { "start": 28.335, @@ -73,7 +73,7 @@ class MockSpeakerRecognitionClient: "speaker": 0, "identified_as": "Unknown", "text": "of liquid glass in this furnace right here. Nick's gonna really help us out. Nick, I'm excited and nervous. Me too.", - "confidence": 0.96 + "confidence": 0.96, }, { "start": 43.28, @@ -81,7 +81,7 @@ class MockSpeakerRecognitionClient: "speaker": 1, "identified_as": "Unknown", "text": "So we're gonna", - "confidence": 0.90 + "confidence": 0.90, }, { "start": 44.68, @@ -89,7 +89,7 @@ class MockSpeakerRecognitionClient: "speaker": 0, "identified_as": "Unknown", "text": "make what's called a trumpet", - "confidence": 0.95 + "confidence": 0.95, }, { "start": 46.96, @@ -97,8 +97,8 @@ class MockSpeakerRecognitionClient: "speaker": 0, "identified_as": "Unknown", "text": "flower. We're using gravity as a tool.", - "confidence": 0.93 - } + "confidence": 0.93, + }, ] } @@ -111,7 +111,7 @@ async def diarize_identify_match( conversation_id: str, backend_token: str, transcript_data: Dict, - user_id: Optional[str] = None + user_id: Optional[str] = None, ) -> Dict: """ Return pre-computed mock segments for known test audio files. @@ -125,7 +125,9 @@ async def diarize_identify_match( Returns: Dictionary with 'segments' array matching speaker service format """ - logger.info(f"🎀 Mock speaker client processing conversation: {conversation_id[:12]}...") + logger.info( + f"🎀 Mock speaker client processing conversation: {conversation_id[:12]}..." 
+ ) # Try to identify which test audio this is from the transcript transcript_text = transcript_data.get("text", "").lower() @@ -135,11 +137,15 @@ async def diarize_identify_match( filename = "DIY_Experts_Glass_Blowing_16khz_mono_1min.wav" if filename in self.MOCK_SEGMENTS: segments = self.MOCK_SEGMENTS[filename] - logger.info(f"🎀 Mock returning {len(segments)} segments for DIY Glass Blowing audio") + logger.info( + f"🎀 Mock returning {len(segments)} segments for DIY Glass Blowing audio" + ) return {"segments": segments} # Fallback: Create single generic segment - logger.warning(f"🎀 Mock: No pre-computed segments found, creating generic segment") + logger.warning( + f"🎀 Mock: No pre-computed segments found, creating generic segment" + ) # Get duration from words if available words = transcript_data.get("words", []) @@ -149,14 +155,16 @@ async def diarize_identify_match( duration = 60.0 return { - "segments": [{ - "start": 0.0, - "end": duration, - "speaker": 0, - "identified_as": "Unknown", - "text": transcript_data.get("text", ""), - "confidence": 0.85 - }] + "segments": [ + { + "start": 0.0, + "end": duration, + "speaker": 0, + "identified_as": "Unknown", + "text": transcript_data.get("text", ""), + "confidence": 0.85, + } + ] } async def identify_segment( diff --git a/backends/advanced/src/advanced_omi_backend/utils/audio_chunk_utils.py b/backends/advanced/src/advanced_omi_backend/utils/audio_chunk_utils.py index 52017932..17fb1099 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/audio_chunk_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/audio_chunk_utils.py @@ -51,8 +51,11 @@ async def encode_pcm_to_opus( >>> # opus_bytes is ~30KB vs 320KB PCM (94% reduction) """ # Create temporary files for FFmpeg I/O - with tempfile.NamedTemporaryFile(suffix=".pcm", delete=False) as pcm_file, \ - tempfile.NamedTemporaryFile(suffix=".opus", delete=False) as opus_file: + with tempfile.NamedTemporaryFile( + suffix=".pcm", delete=False + ) as pcm_file, tempfile.NamedTemporaryFile( + suffix=".opus", delete=False + ) as opus_file: pcm_path = Path(pcm_file.name) opus_path = Path(opus_file.name) @@ -72,14 +75,22 @@ async def encode_pcm_to_opus( # -application voip: optimize for speech cmd = [ "ffmpeg", - "-f", "s16le", - "-ar", str(sample_rate), - "-ac", str(channels), - "-i", str(pcm_path), - "-c:a", "libopus", - "-b:a", f"{bitrate}k", - "-vbr", "on", - "-application", "voip", + "-f", + "s16le", + "-ar", + str(sample_rate), + "-ac", + str(channels), + "-i", + str(pcm_path), + "-c:a", + "libopus", + "-b:a", + f"{bitrate}k", + "-vbr", + "on", + "-application", + "voip", "-y", # Overwrite output str(opus_path), ] @@ -140,8 +151,11 @@ async def decode_opus_to_pcm( >>> # pcm_bytes can be played or concatenated """ # Create temporary files for FFmpeg I/O - with tempfile.NamedTemporaryFile(suffix=".opus", delete=False) as opus_file, \ - tempfile.NamedTemporaryFile(suffix=".pcm", delete=False) as pcm_file: + with tempfile.NamedTemporaryFile( + suffix=".opus", delete=False + ) as opus_file, tempfile.NamedTemporaryFile( + suffix=".pcm", delete=False + ) as pcm_file: opus_path = Path(opus_file.name) pcm_path = Path(pcm_file.name) @@ -158,10 +172,14 @@ async def decode_opus_to_pcm( # -ac: convert to target channel count cmd = [ "ffmpeg", - "-i", str(opus_path), - "-f", "s16le", - "-ar", str(sample_rate), - "-ac", str(channels), + "-i", + str(opus_path), + "-f", + "s16le", + "-ar", + str(sample_rate), + "-ac", + str(channels), "-y", # Overwrite output str(pcm_path), ] @@ -322,9 +340,7 @@ 
async def concatenate_chunks_to_pcm( # Append to buffer pcm_buffer.extend(pcm_data) - logger.debug( - f"Concatenated {len(chunks)} chunks β†’ {len(pcm_buffer)} bytes PCM" - ) + logger.debug(f"Concatenated {len(chunks)} chunks β†’ {len(pcm_buffer)} bytes PCM") return bytes(pcm_buffer) @@ -369,9 +385,7 @@ async def reconstruct_wav_from_conversation( ) if not chunks: - raise ValueError( - f"No audio chunks found for conversation {conversation_id}" - ) + raise ValueError(f"No audio chunks found for conversation {conversation_id}") # Get audio format from first chunk sample_rate = chunks[0].sample_rate @@ -439,7 +453,9 @@ async def reconstruct_audio_segments( total_duration = conversation.audio_total_duration or 0.0 if total_duration == 0: - logger.warning(f"Conversation {conversation_id} has zero duration, no segments to yield") + logger.warning( + f"Conversation {conversation_id} has zero duration, no segments to yield" + ) return # Get audio format from first chunk @@ -462,11 +478,17 @@ async def reconstruct_audio_segments( # Get chunks that overlap with this time range # Note: Using start_time and end_time fields from chunks - chunks = await AudioChunkDocument.find( - AudioChunkDocument.conversation_id == conversation_id, - AudioChunkDocument.start_time < end_time, # Chunk starts before segment ends - AudioChunkDocument.end_time > start_time, # Chunk ends after segment starts - ).sort(+AudioChunkDocument.chunk_index).to_list() + chunks = ( + await AudioChunkDocument.find( + AudioChunkDocument.conversation_id == conversation_id, + AudioChunkDocument.start_time + < end_time, # Chunk starts before segment ends + AudioChunkDocument.end_time + > start_time, # Chunk ends after segment starts + ) + .sort(+AudioChunkDocument.chunk_index) + .to_list() + ) if not chunks: logger.warning( @@ -499,9 +521,7 @@ async def reconstruct_audio_segments( async def reconstruct_audio_segment( - conversation_id: str, - start_time: float, - end_time: float + conversation_id: str, start_time: float, end_time: float ) -> bytes: """ Reconstruct audio for a specific time range from MongoDB chunks. @@ -571,11 +591,16 @@ async def reconstruct_audio_segment( channels = first_chunk.channels # Get chunks that overlap with this time range - chunks = await AudioChunkDocument.find( - AudioChunkDocument.conversation_id == conversation_id, - AudioChunkDocument.start_time < end_time, # Chunk starts before segment ends - AudioChunkDocument.end_time > start_time, # Chunk ends after segment starts - ).sort(+AudioChunkDocument.chunk_index).to_list() + chunks = ( + await AudioChunkDocument.find( + AudioChunkDocument.conversation_id == conversation_id, + AudioChunkDocument.start_time + < end_time, # Chunk starts before segment ends + AudioChunkDocument.end_time > start_time, # Chunk ends after segment starts + ) + .sort(+AudioChunkDocument.chunk_index) + .to_list() + ) if not chunks: logger.warning( @@ -654,9 +679,7 @@ async def reconstruct_audio_segment( def filter_transcript_by_time( - transcript_data: dict, - start_time: float, - end_time: float + transcript_data: dict, start_time: float, end_time: float ) -> dict: """ Filter transcript data to only include words within a time range. 
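The chunk-selection query reformatted above is a standard interval-overlap test: a stored chunk belongs to a requested segment when it starts before the segment ends and ends after the segment starts. A minimal standalone sketch of that predicate (the Chunk tuple here is an illustrative stand-in, not the AudioChunkDocument model used in the diff):

from typing import List, NamedTuple


class Chunk(NamedTuple):
    """Illustrative stand-in for an audio chunk record."""
    chunk_index: int
    start_time: float
    end_time: float


def chunks_overlapping(chunks: List[Chunk], start_time: float, end_time: float) -> List[Chunk]:
    """Return chunks overlapping [start_time, end_time), ordered by chunk_index.

    Mirrors the query shape above: chunk.start_time < segment end AND
    chunk.end_time > segment start.
    """
    hits = [c for c in chunks if c.start_time < end_time and c.end_time > start_time]
    return sorted(hits, key=lambda c: c.chunk_index)


if __name__ == "__main__":
    chunks = [Chunk(0, 0.0, 10.0), Chunk(1, 10.0, 20.0), Chunk(2, 20.0, 30.0)]
    # A segment from 8s to 12s touches chunks 0 and 1 but not 2.
    print(chunks_overlapping(chunks, 8.0, 12.0))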
@@ -695,10 +718,7 @@ def filter_transcript_by_time( # Rebuild text from filtered words filtered_text = " ".join(word.get("word", "") for word in filtered_words) - return { - "text": filtered_text, - "words": filtered_words - } + return {"text": filtered_text, "words": filtered_words} async def convert_audio_to_chunks( @@ -785,7 +805,7 @@ async def convert_audio_to_chunks( pcm_data=chunk_pcm, sample_rate=sample_rate, channels=channels, - bitrate=24 # 24kbps for speech + bitrate=24, # 24kbps for speech ) # Create MongoDB document @@ -830,19 +850,29 @@ async def convert_audio_to_chunks( ) if conversation: - compression_ratio = total_compressed_size / total_original_size if total_original_size > 0 else 0.0 + compression_ratio = ( + total_compressed_size / total_original_size + if total_original_size > 0 + else 0.0 + ) - logger.info(f"πŸ” DEBUG: Setting metadata - chunks={chunk_index}, duration={total_duration_seconds:.2f}s, ratio={compression_ratio:.3f}") + logger.info( + f"πŸ” DEBUG: Setting metadata - chunks={chunk_index}, duration={total_duration_seconds:.2f}s, ratio={compression_ratio:.3f}" + ) conversation.audio_chunks_count = chunk_index conversation.audio_total_duration = total_duration_seconds conversation.audio_compression_ratio = compression_ratio - logger.info(f"πŸ” DEBUG: Before save - chunks={conversation.audio_chunks_count}, duration={conversation.audio_total_duration}") + logger.info( + f"πŸ” DEBUG: Before save - chunks={conversation.audio_chunks_count}, duration={conversation.audio_total_duration}" + ) await conversation.save() logger.info(f"πŸ” DEBUG: After save - metadata should be persisted") else: - logger.error(f"❌ Conversation {conversation_id} not found for metadata update!") + logger.error( + f"❌ Conversation {conversation_id} not found for metadata update!" 
+ ) logger.info( f"βœ… Converted audio to {chunk_index} MongoDB chunks: " @@ -899,6 +929,7 @@ async def convert_wav_to_chunks( # Read WAV file import wave + with wave.open(str(wav_file_path), "rb") as wav: sample_rate = wav.getframerate() channels = wav.getnchannels() @@ -952,7 +983,7 @@ async def convert_wav_to_chunks( pcm_data=chunk_pcm, sample_rate=sample_rate, channels=channels, - bitrate=24 # 24kbps for speech + bitrate=24, # 24kbps for speech ) # Create MongoDB document @@ -997,19 +1028,29 @@ async def convert_wav_to_chunks( ) if conversation: - compression_ratio = total_compressed_size / total_original_size if total_original_size > 0 else 0.0 + compression_ratio = ( + total_compressed_size / total_original_size + if total_original_size > 0 + else 0.0 + ) - logger.info(f"πŸ” DEBUG: Setting metadata - chunks={chunk_index}, duration={total_duration_seconds:.2f}s, ratio={compression_ratio:.3f}") + logger.info( + f"πŸ” DEBUG: Setting metadata - chunks={chunk_index}, duration={total_duration_seconds:.2f}s, ratio={compression_ratio:.3f}" + ) conversation.audio_chunks_count = chunk_index conversation.audio_total_duration = total_duration_seconds conversation.audio_compression_ratio = compression_ratio - logger.info(f"πŸ” DEBUG: Before save - chunks={conversation.audio_chunks_count}, duration={conversation.audio_total_duration}") + logger.info( + f"πŸ” DEBUG: Before save - chunks={conversation.audio_chunks_count}, duration={conversation.audio_total_duration}" + ) await conversation.save() logger.info(f"πŸ” DEBUG: After save - metadata should be persisted") else: - logger.error(f"❌ Conversation {conversation_id} not found for metadata update!") + logger.error( + f"❌ Conversation {conversation_id} not found for metadata update!" + ) logger.info( f"βœ… Converted WAV to {chunk_index} MongoDB chunks: " @@ -1058,7 +1099,7 @@ async def wait_for_audio_chunks( chunks = await retrieve_audio_chunks( conversation_id=conversation_id, start_index=0, - limit=1 # Just check if any exist + limit=1, # Just check if any exist ) if len(chunks) >= min_chunks: diff --git a/backends/advanced/src/advanced_omi_backend/utils/audio_extraction.py b/backends/advanced/src/advanced_omi_backend/utils/audio_extraction.py index df999a10..8a140874 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/audio_extraction.py +++ b/backends/advanced/src/advanced_omi_backend/utils/audio_extraction.py @@ -28,10 +28,7 @@ def parse_chunk_range(chunk_id: str) -> Tuple[int, int]: async def extract_audio_for_results( - redis_client, - client_id: str, - session_id: str, - transcription_results: List[dict] + redis_client, client_id: str, session_id: str, transcription_results: List[dict] ) -> bytes: """ Extract audio chunks for transcription results. 
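extract_audio_for_results() collapses the chunk ranges attached to each transcription result into a single [min_chunk, max_chunk] window before pulling audio from the Redis stream. A small standalone illustration of that range math (the "start-end" chunk-id format is an assumption for this sketch; the real format is whatever parse_chunk_range() in this module accepts):

from typing import List, Tuple


def parse_chunk_range(chunk_id: str) -> Tuple[int, int]:
    """Hypothetical parser assuming ids like '12-15' (inclusive range)."""
    start, _, end = chunk_id.partition("-")
    return int(start), int(end or start)


def overall_range(chunk_ids: List[str]) -> Tuple[int, int]:
    """Collapse per-result chunk ranges into one [min_chunk, max_chunk] window."""
    ranges = [parse_chunk_range(cid) for cid in chunk_ids]
    return min(r[0] for r in ranges), max(r[1] for r in ranges)


if __name__ == "__main__":
    print(overall_range(["3-5", "6-6", "9-11"]))  # (3, 11)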
@@ -47,8 +44,12 @@ async def extract_audio_for_results( Returns: Combined audio bytes for all chunks in results """ - logger.info(f"🎡 [AUDIO EXTRACT] Starting audio extraction for session {session_id}") - logger.info(f"🎡 [AUDIO EXTRACT] Client: {client_id}, Results count: {len(transcription_results)}") + logger.info( + f"🎡 [AUDIO EXTRACT] Starting audio extraction for session {session_id}" + ) + logger.info( + f"🎡 [AUDIO EXTRACT] Client: {client_id}, Results count: {len(transcription_results)}" + ) if not transcription_results: logger.warning(f"🎡 [AUDIO EXTRACT] No transcription results provided") @@ -64,7 +65,9 @@ async def extract_audio_for_results( chunk_ranges.append((start, end)) if not chunk_ranges: - logger.warning("🎡 [AUDIO EXTRACT] No chunk ranges found in transcription results") + logger.warning( + "🎡 [AUDIO EXTRACT] No chunk ranges found in transcription results" + ) return b"" # Find overall range @@ -108,7 +111,9 @@ async def extract_audio_for_results( if min_chunk <= chunk_num <= max_chunk: audio_data = fields.get(b"audio_data", b"") audio_chunks[chunk_num] = audio_data - logger.debug(f"🎡 [AUDIO EXTRACT] Collected chunk {chunk_num}: {len(audio_data)} bytes") + logger.debug( + f"🎡 [AUDIO EXTRACT] Collected chunk {chunk_num}: {len(audio_data)} bytes" + ) # Combine chunks in order sorted_chunks = sorted(audio_chunks.items()) @@ -123,7 +128,8 @@ async def extract_audio_for_results( logger.warning(f"🎡 [AUDIO EXTRACT] ⚠️ No audio data collected!") elif len(sorted_chunks) < (max_chunk - min_chunk + 1): missing_chunks = (max_chunk - min_chunk + 1) - len(sorted_chunks) - logger.warning(f"🎡 [AUDIO EXTRACT] ⚠️ Missing {missing_chunks} chunks from expected range") + logger.warning( + f"🎡 [AUDIO EXTRACT] ⚠️ Missing {missing_chunks} chunks from expected range" + ) return combined_audio - diff --git a/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py b/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py index 4abb1d5d..4d309a32 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/audio_utils.py @@ -24,8 +24,12 @@ audio_logger = logging.getLogger("audio_processing") # Import constants from main.py (these are defined there) -MIN_SPEECH_SEGMENT_DURATION = float(os.getenv("MIN_SPEECH_SEGMENT_DURATION", "1.0")) # seconds -CROPPING_CONTEXT_PADDING = float(os.getenv("CROPPING_CONTEXT_PADDING", "0.1")) # seconds +MIN_SPEECH_SEGMENT_DURATION = float( + os.getenv("MIN_SPEECH_SEGMENT_DURATION", "1.0") +) # seconds +CROPPING_CONTEXT_PADDING = float( + os.getenv("CROPPING_CONTEXT_PADDING", "0.1") +) # seconds SUPPORTED_AUDIO_EXTENSIONS = {".wav", ".mp3", ".mp4", ".m4a", ".flac", ".ogg", ".webm"} VIDEO_EXTENSIONS = {".mp4", ".webm"} @@ -33,6 +37,7 @@ class AudioValidationError(Exception): """Exception raised when audio validation fails.""" + pass @@ -42,7 +47,7 @@ async def resample_audio_with_ffmpeg( input_channels: int, input_sample_width: int, target_sample_rate: int, - target_channels: int = 1 + target_channels: int = 1, ) -> bytes: """ Resample audio using FFmpeg with stdin/stdout pipes (no disk I/O). 
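resample_audio_with_ffmpeg() streams raw PCM through FFmpeg over stdin/stdout instead of temp files. A minimal sketch of that pipe pattern with asyncio, assuming 16-bit little-endian PCM in and out (argument names are illustrative, not the module's exact signature; requires ffmpeg on PATH):

import asyncio


async def ffmpeg_resample_s16le(
    pcm: bytes, in_rate: int, in_channels: int, out_rate: int = 16000, out_channels: int = 1
) -> bytes:
    """Pipe raw s16le PCM through FFmpeg and return resampled s16le PCM."""
    cmd = [
        "ffmpeg",
        "-f", "s16le", "-ar", str(in_rate), "-ac", str(in_channels),
        "-i", "pipe:0",  # read PCM from stdin
        "-ar", str(out_rate), "-ac", str(out_channels),
        "-f", "s16le", "pipe:1",  # write PCM to stdout
    ]
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate(input=pcm)
    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {stderr.decode(errors='replace')}")
    return stdout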
@@ -75,13 +80,20 @@ async def resample_audio_with_ffmpeg( # pipe:0 = stdin, pipe:1 = stdout cmd = [ "ffmpeg", - "-f", input_format, - "-ar", str(input_sample_rate), - "-ac", str(input_channels), - "-i", "pipe:0", # Read from stdin - "-ar", str(target_sample_rate), - "-ac", str(target_channels), - "-f", "s16le", # Always output 16-bit + "-f", + input_format, + "-ar", + str(input_sample_rate), + "-ac", + str(input_channels), + "-i", + "pipe:0", # Read from stdin + "-ar", + str(target_sample_rate), + "-ac", + str(target_channels), + "-f", + "s16le", # Always output 16-bit "pipe:1", # Write to stdout ] @@ -133,12 +145,17 @@ async def convert_any_to_wav(file_data: bytes, file_extension: str) -> bytes: cmd = [ "ffmpeg", - "-i", "pipe:0", + "-i", + "pipe:0", "-vn", # Strip video track (no-op for audio-only files) - "-acodec", "pcm_s16le", - "-ar", "16000", - "-ac", "1", - "-f", "wav", + "-acodec", + "pcm_s16le", + "-ar", + "16000", + "-ac", + "1", + "-f", + "wav", "pipe:1", ] @@ -156,9 +173,7 @@ async def convert_any_to_wav(file_data: bytes, file_extension: str) -> bytes: audio_logger.error(f"FFmpeg conversion failed for {ext}: {error_msg}") raise AudioValidationError(f"Failed to convert {ext} file to WAV: {error_msg}") - audio_logger.info( - f"Converted {ext} to WAV: {len(file_data)} β†’ {len(stdout)} bytes" - ) + audio_logger.info(f"Converted {ext} to WAV: {len(file_data)} β†’ {len(stdout)} bytes") return stdout @@ -167,7 +182,7 @@ async def validate_and_prepare_audio( audio_data: bytes, expected_sample_rate: int = 16000, convert_to_mono: bool = True, - auto_resample: bool = False + auto_resample: bool = False, ) -> tuple[bytes, int, int, int, float]: """ Validate WAV audio data and prepare it for processing. @@ -212,7 +227,7 @@ async def validate_and_prepare_audio( input_channels=channels, input_sample_width=sample_width, target_sample_rate=expected_sample_rate, - target_channels=1 if convert_to_mono else channels + target_channels=1 if convert_to_mono else channels, ) # Update metadata after resampling sample_rate = expected_sample_rate @@ -243,7 +258,9 @@ async def validate_and_prepare_audio( # Reshape to separate channels and average audio_array = audio_array.reshape(-1, 2) - processed_audio = np.mean(audio_array, axis=1).astype(audio_array.dtype).tobytes() + processed_audio = ( + np.mean(audio_array, axis=1).astype(audio_array.dtype).tobytes() + ) channels = 1 audio_logger.debug( @@ -301,8 +318,9 @@ async def write_audio_file( # Validate and prepare audio if needed if validate: - audio_data, sample_rate, sample_width, channels, duration = \ + audio_data, sample_rate, sample_width, channels, duration = ( await validate_and_prepare_audio(raw_audio_data) + ) else: # For WebSocket/streaming path - audio is already processed PCM audio_data = raw_audio_data @@ -325,7 +343,7 @@ async def write_audio_file( # If output_dir is a subdirectory of CHUNK_DIR, include the folder prefix try: relative_path_parts = output_dir.relative_to(CHUNK_DIR) - if str(relative_path_parts) != '.': + if str(relative_path_parts) != ".": relative_audio_path = f"{relative_path_parts}/{wav_filename}" else: relative_audio_path = wav_filename @@ -338,15 +356,12 @@ async def write_audio_file( file_path=str(file_path), sample_rate=int(sample_rate), channels=int(channels), - sample_width=int(sample_width) + sample_width=int(sample_width), ) await sink.open() audio_chunk = AudioChunk( - rate=sample_rate, - width=sample_width, - channels=channels, - audio=audio_data + rate=sample_rate, width=sample_width, channels=channels, 
audio=audio_data ) await sink.write(audio_chunk) await sink.close() @@ -364,7 +379,7 @@ async def process_audio_chunk( user_id: str, user_email: str, audio_format: dict, - client_state: Optional["ClientState"] = None + client_state: Optional["ClientState"] = None, ) -> None: """Process a single audio chunk through Redis Streams pipeline. @@ -396,11 +411,7 @@ async def process_audio_chunk( # Create AudioChunk with format details chunk = AudioChunk( - audio=audio_data, - rate=rate, - width=width, - channels=channels, - timestamp=timestamp + audio=audio_data, rate=rate, width=width, channels=channels, timestamp=timestamp ) # Publish audio chunk to Redis Streams @@ -411,7 +422,7 @@ async def process_audio_chunk( user_email=user_email, audio_chunk=chunk, audio_uuid=None, # Will be generated by worker - timestamp=timestamp + timestamp=timestamp, ) # Update client state if provided @@ -420,10 +431,7 @@ async def process_audio_chunk( def pcm_to_wav_bytes( - pcm_data: bytes, - sample_rate: int = 16000, - channels: int = 1, - sample_width: int = 2 + pcm_data: bytes, sample_rate: int = 16000, channels: int = 1, sample_width: int = 2 ) -> bytes: """ Convert raw PCM audio data to WAV format in memory. @@ -448,7 +456,7 @@ def pcm_to_wav_bytes( # Use BytesIO to create WAV in memory wav_buffer = io.BytesIO() - with wave.open(wav_buffer, 'wb') as wav_file: + with wave.open(wav_buffer, "wb") as wav_file: wav_file.setnchannels(channels) wav_file.setsampwidth(sample_width) wav_file.setframerate(sample_rate) @@ -467,7 +475,7 @@ def write_pcm_to_wav( output_path: str, sample_rate: int = 16000, channels: int = 1, - sample_width: int = 2 + sample_width: int = 2, ) -> None: """ Write raw PCM audio data to a WAV file. @@ -487,7 +495,7 @@ def write_pcm_to_wav( ) try: - with wave.open(output_path, 'wb') as wav_file: + with wave.open(output_path, "wb") as wav_file: wav_file.setnchannels(channels) wav_file.setsampwidth(sample_width) wav_file.setframerate(sample_rate) diff --git a/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py b/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py index 63036ce1..5dc32bae 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/conversation_utils.py @@ -43,7 +43,10 @@ def is_meaningful_speech(combined_results: dict) -> bool: if not combined_results.get("text"): return False - transcript_data = {"text": combined_results["text"], "words": combined_results.get("words", [])} + transcript_data = { + "text": combined_results["text"], + "words": combined_results.get("words", []), + } speech_analysis = analyze_speech(transcript_data) return speech_analysis["has_speech"] @@ -83,19 +86,25 @@ def analyze_speech(transcript_data: dict) -> dict: settings = get_speech_detection_settings() words = transcript_data.get("words", []) - logger.info(f"πŸ”¬ analyze_speech: words_list_length={len(words)}, settings={settings}") + logger.info( + f"πŸ”¬ analyze_speech: words_list_length={len(words)}, settings={settings}" + ) if words and len(words) > 0: logger.info(f"πŸ“ First 3 words: {words[:3]}") # Method 1: Word-level analysis (preferred - has confidence scores and timing) if words: # Filter by confidence threshold - valid_words = [w for w in words if (w.get("confidence") or 0) >= settings["min_confidence"]] + valid_words = [ + w for w in words if (w.get("confidence") or 0) >= settings["min_confidence"] + ] if len(valid_words) < settings["min_words"]: # Not enough valid words in 
word-level data - fall through to text-only analysis # This handles cases where word-level data is incomplete or low confidence - logger.debug(f"Only {len(valid_words)} valid words, falling back to text-only analysis") + logger.debug( + f"Only {len(valid_words)} valid words, falling back to text-only analysis" + ) # Continue to Method 2 (don't return early) else: # Calculate speech duration from word timing @@ -113,12 +122,16 @@ def analyze_speech(transcript_data: dict) -> dict: # If no timing data (duration = 0), fall back to text-only analysis # This happens with some streaming transcription services if speech_duration == 0: - logger.debug("Word timing data missing, falling back to text-only analysis") + logger.debug( + "Word timing data missing, falling back to text-only analysis" + ) # Continue to Method 2 (text-only fallback) else: # Check minimum duration threshold when we have timing data min_duration = settings.get("min_duration", 10.0) - logger.info(f"πŸ“ Comparing duration {speech_duration:.1f}s vs threshold {min_duration:.1f}s") + logger.info( + f"πŸ“ Comparing duration {speech_duration:.1f}s vs threshold {min_duration:.1f}s" + ) if speech_duration < min_duration: return { "has_speech": False, @@ -222,7 +235,9 @@ async def generate_title_and_summary( "{conversation_text}" """ - response = await async_generate(prompt, operation="title_summary", langfuse_session_id=langfuse_session_id) + response = await async_generate( + prompt, operation="title_summary", langfuse_session_id=langfuse_session_id + ) # Parse response for Title: and Summary: lines title = None @@ -244,12 +259,13 @@ async def generate_title_and_summary( # Fallback words = text.split()[:6] fallback_title = " ".join(words) - fallback_title = fallback_title[:40] + "..." if len(fallback_title) > 40 else fallback_title + fallback_title = ( + fallback_title[:40] + "..." if len(fallback_title) > 40 else fallback_title + ) fallback_summary = text[:120] + "..." if len(text) > 120 else text return fallback_title or "Conversation", fallback_summary or "No content" - async def generate_detailed_summary( text: str, segments: Optional[list] = None, @@ -330,8 +346,15 @@ async def generate_detailed_summary( "{conversation_text}" """ - summary = await async_generate(prompt, operation="detailed_summary", langfuse_session_id=langfuse_session_id) - return summary.strip().strip('"').strip("'") or "No meaningful content to summarize" + summary = await async_generate( + prompt, + operation="detailed_summary", + langfuse_session_id=langfuse_session_id, + ) + return ( + summary.strip().strip('"').strip("'") + or "No meaningful content to summarize" + ) except Exception as e: logger.warning(f"Failed to generate detailed summary: {e}") @@ -350,7 +373,6 @@ async def generate_detailed_summary( # ============================================================================ - def extract_speakers_from_segments(segments: list) -> List[str]: """ Extract unique speaker names from segments. 
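The speech-analysis changes earlier in this file keep the same two-step logic: drop low-confidence words, then derive speech duration from word timings, falling back to text-only analysis when timing is missing. A simplified standalone sketch of the word-level step (thresholds here are illustrative defaults, not the backend's configured settings):

from typing import Dict, List


def word_level_speech_stats(
    words: List[Dict], min_confidence: float = 0.5, min_words: int = 3
) -> Dict:
    """Filter words by confidence and estimate speech duration from timings."""
    valid = [w for w in words if (w.get("confidence") or 0) >= min_confidence]
    if len(valid) < min_words:
        return {"usable": False, "reason": "too_few_valid_words"}

    starts = [w["start"] for w in valid if "start" in w]
    ends = [w["end"] for w in valid if "end" in w]
    duration = (max(ends) - min(starts)) if starts and ends else 0.0
    if duration == 0:
        return {"usable": False, "reason": "no_timing_data"}

    return {"usable": True, "valid_words": len(valid), "speech_duration": duration}


if __name__ == "__main__":
    words = [
        {"word": "hello", "confidence": 0.9, "start": 0.2, "end": 0.6},
        {"word": "there", "confidence": 0.8, "start": 0.7, "end": 1.1},
        {"word": "uh", "confidence": 0.2, "start": 1.2, "end": 1.3},
        {"word": "friend", "confidence": 0.95, "start": 1.4, "end": 1.9},
    ]
    print(word_level_speech_stats(words))  # duration of 1.7s from 3 valid words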
@@ -364,14 +386,21 @@ def extract_speakers_from_segments(segments: list) -> List[str]: speakers = [] if segments: for seg in segments: - speaker = seg.get("speaker", "Unknown") if isinstance(seg, dict) else (seg.speaker or "Unknown") + speaker = ( + seg.get("speaker", "Unknown") + if isinstance(seg, dict) + else (seg.speaker or "Unknown") + ) if speaker and speaker != "Unknown" and speaker not in speakers: speakers.append(speaker) return speakers async def track_speech_activity( - speech_analysis: Dict[str, Any], last_word_count: int, conversation_id: str, redis_client + speech_analysis: Dict[str, Any], + last_word_count: int, + conversation_id: str, + redis_client, ) -> tuple[float, int]: """ Track new speech activity and update last speech timestamp using audio timestamps. @@ -475,7 +504,9 @@ async def update_job_progress_metadata( "conversation_id": conversation_id, "client_id": client_id, # Ensure client_id is always present "transcript": ( - combined["text"][:500] + "..." if len(combined["text"]) > 500 else combined["text"] + combined["text"][:500] + "..." + if len(combined["text"]) > 500 + else combined["text"] ), # First 500 chars "transcript_length": len(combined["text"]), "speakers": speakers, @@ -506,7 +537,9 @@ async def mark_conversation_deleted(conversation_id: str, deletion_reason: str) f"πŸ—‘οΈ Marking conversation {conversation_id} as deleted - reason: {deletion_reason}" ) - conversation = await Conversation.find_one(Conversation.conversation_id == conversation_id) + conversation = await Conversation.find_one( + Conversation.conversation_id == conversation_id + ) if conversation: conversation.deleted = True conversation.deletion_reason = deletion_reason diff --git a/backends/advanced/src/advanced_omi_backend/utils/file_utils.py b/backends/advanced/src/advanced_omi_backend/utils/file_utils.py index 0b499420..c65d1a39 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/file_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/file_utils.py @@ -15,13 +15,14 @@ class ZipExtractionError(Exception): """Exception raised when zip extraction fails.""" + pass def extract_zip( zip_path: Union[str, Path], extract_dir: Union[str, Path], - create_extract_dir: bool = True + create_extract_dir: bool = True, ) -> Path: """ Extract a zip file to a specified directory. 
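A hedged usage sketch of extract_zip() as reformatted above (paths are illustrative; the import path follows this file's location in the repo):

from pathlib import Path

from advanced_omi_backend.utils.file_utils import ZipExtractionError, extract_zip

try:
    # Extract an uploaded archive; the target directory is created if missing.
    out_dir = extract_zip(Path("/tmp/upload.zip"), Path("/tmp/upload_extracted"))
    print(f"Extracted to {out_dir}")
except ZipExtractionError as e:
    print(f"Could not extract archive: {e}")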
@@ -59,7 +60,7 @@ def extract_zip( # Extract zip file try: - with zipfile.ZipFile(zip_path, 'r') as zf: + with zipfile.ZipFile(zip_path, "r") as zf: zf.extractall(extract_dir) logger.info(f"Successfully extracted {zip_path} to {extract_dir}") return extract_dir @@ -79,4 +80,3 @@ def extract_zip( error_msg = f"Error extracting zip file {zip_path}: {e}" logger.error(error_msg) raise ZipExtractionError(error_msg) from e - diff --git a/backends/advanced/src/advanced_omi_backend/utils/gdrive_audio_utils.py b/backends/advanced/src/advanced_omi_backend/utils/gdrive_audio_utils.py index d9e39163..7e86331e 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/gdrive_audio_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/gdrive_audio_utils.py @@ -17,7 +17,6 @@ FOLDER_MIMETYPE = "application/vnd.google-apps.folder" - async def download_and_wrap_drive_file(service, file_item): file_id = file_item["id"] name = file_item["name"] @@ -36,7 +35,7 @@ async def download_and_wrap_drive_file(service, file_item): if not content: raise AudioValidationError(f"Downloaded Google Drive file '{name}' was empty") - tmp_file = tempfile.SpooledTemporaryFile(max_size=10*1024*1024) # 10 MB + tmp_file = tempfile.SpooledTemporaryFile(max_size=10 * 1024 * 1024) # 10 MB tmp_file.write(content) tmp_file.seek(0) upload_file = StarletteUploadFile(filename=name, file=tmp_file) @@ -54,10 +53,13 @@ def wrapped_close(): return upload_file + # ------------------------------------------------------------- # LIST + DOWNLOAD FILES IN FOLDER (OAUTH) # ------------------------------------------------------------- -async def download_audio_files_from_drive(folder_id: str, user_id: str) -> List[StarletteUploadFile]: +async def download_audio_files_from_drive( + folder_id: str, user_id: str +) -> List[StarletteUploadFile]: if not folder_id: raise AudioValidationError("Google Drive folder ID is required.") @@ -67,18 +69,21 @@ async def download_audio_files_from_drive(folder_id: str, user_id: str) -> List[ escaped_folder_id = folder_id.replace("\\", "\\\\").replace("'", "\\'") query = f"'{escaped_folder_id}' in parents and trashed = false" - response = service.files().list( - q=query, - fields="files(id, name, mimeType)", - includeItemsFromAllDrives=False, - supportsAllDrives=False, - ).execute() + response = ( + service.files() + .list( + q=query, + fields="files(id, name, mimeType)", + includeItemsFromAllDrives=False, + supportsAllDrives=False, + ) + .execute() + ) all_files = response.get("files", []) audio_files_metadata = [ - f for f in all_files - if f["name"].lower().endswith(AUDIO_EXTENSIONS) + f for f in all_files if f["name"].lower().endswith(AUDIO_EXTENSIONS) ] if not audio_files_metadata: @@ -86,15 +91,15 @@ async def download_audio_files_from_drive(folder_id: str, user_id: str) -> List[ wrapped_files = [] skipped_count = 0 - + for item in audio_files_metadata: - file_id = item["id"] # Get the Google Drive File ID + file_id = item["id"] # Get the Google Drive File ID # Check if the file is already processed (check Conversation by external_source_id and user_id) existing = await Conversation.find_one( Conversation.external_source_id == file_id, Conversation.external_source_type == "gdrive", - Conversation.user_id == user_id + Conversation.user_id == user_id, ) if existing: @@ -107,9 +112,11 @@ async def download_audio_files_from_drive(folder_id: str, user_id: str) -> List[ # Attach the file_id to the UploadFile object for later use (for external_source_id) wrapped_file.file_id = file_id 
wrapped_files.append(wrapped_file) - + if not wrapped_files and skipped_count > 0: - raise AudioValidationError(f"All {skipped_count} files in the folder have already been processed.") + raise AudioValidationError( + f"All {skipped_count} files in the folder have already been processed." + ) return wrapped_files @@ -117,5 +124,3 @@ async def download_audio_files_from_drive(folder_id: str, user_id: str) -> List[ if isinstance(e, AudioValidationError): raise raise AudioValidationError(f"Google Drive API Error: {e}") from e - - diff --git a/backends/advanced/src/advanced_omi_backend/utils/logging_utils.py b/backends/advanced/src/advanced_omi_backend/utils/logging_utils.py index d0c69e34..a8b31304 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/logging_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/logging_utils.py @@ -4,33 +4,46 @@ Provides functions to mask sensitive information in logs to prevent accidental exposure of credentials, tokens, and other secrets. """ + import re from typing import Any, Dict, List, Optional, Set, Union # Common patterns for identifying secret field names SECRET_KEYWORDS = [ - 'PASSWORD', 'PASSWD', 'PWD', - 'TOKEN', 'AUTH', 'AUTHORIZATION', - 'KEY', 'APIKEY', 'API_KEY', 'SECRET', - 'CREDENTIAL', 'CRED', - 'PRIVATE', 'CERTIFICATE', 'CERT' + "PASSWORD", + "PASSWD", + "PWD", + "TOKEN", + "AUTH", + "AUTHORIZATION", + "KEY", + "APIKEY", + "API_KEY", + "SECRET", + "CREDENTIAL", + "CRED", + "PRIVATE", + "CERTIFICATE", + "CERT", ] # Default mask for secrets -SECRET_MASK = 'β€’β€’β€’β€’β€’β€’β€’β€’' +SECRET_MASK = "β€’β€’β€’β€’β€’β€’β€’β€’" -def is_secret_field(field_name: str, additional_keywords: Optional[List[str]] = None) -> bool: +def is_secret_field( + field_name: str, additional_keywords: Optional[List[str]] = None +) -> bool: """ Check if a field name indicates a secret value. - + Args: field_name: The field/key name to check additional_keywords: Optional additional keywords to check for - + Returns: True if field name matches secret patterns - + Examples: >>> is_secret_field('smtp_password') True @@ -40,18 +53,18 @@ def is_secret_field(field_name: str, additional_keywords: Optional[List[str]] = True """ field_upper = field_name.upper() - + # Check default keywords for keyword in SECRET_KEYWORDS: if keyword in field_upper: return True - + # Check additional keywords if provided if additional_keywords: for keyword in additional_keywords: if keyword.upper() in field_upper: return True - + return False @@ -59,38 +72,37 @@ def mask_dict( data: Dict[str, Any], mask: str = SECRET_MASK, secret_fields: Optional[Set[str]] = None, - additional_keywords: Optional[List[str]] = None + additional_keywords: Optional[List[str]] = None, ) -> Dict[str, Any]: """ Mask secret values in a dictionary for safe logging. 
- + Args: data: Dictionary to mask mask: String to use for masked values (default: 'β€’β€’β€’β€’β€’β€’β€’β€’') secret_fields: Explicit set of field names to mask (case-insensitive) additional_keywords: Additional keywords to identify secret fields - + Returns: New dictionary with secrets masked - + Examples: >>> config = {'smtp_host': 'smtp.gmail.com', 'smtp_password': 'secret123'} >>> mask_dict(config) {'smtp_host': 'smtp.gmail.com', 'smtp_password': 'β€’β€’β€’β€’β€’β€’β€’β€’'} - + >>> mask_dict({'token': 'abc123'}, secret_fields={'token'}) {'token': 'β€’β€’β€’β€’β€’β€’β€’β€’'} """ masked = {} secret_fields_lower = {f.lower() for f in (secret_fields or set())} - + for key, value in data.items(): # Check if this is a secret field - is_secret = ( - key.lower() in secret_fields_lower or - is_secret_field(key, additional_keywords) + is_secret = key.lower() in secret_fields_lower or is_secret_field( + key, additional_keywords ) - + if is_secret and value: # Mask non-empty secret values masked[key] = mask @@ -100,60 +112,61 @@ def mask_dict( elif isinstance(value, list): # Handle lists of dictionaries masked[key] = [ - mask_dict(item, mask, secret_fields, additional_keywords) - if isinstance(item, dict) else item + ( + mask_dict(item, mask, secret_fields, additional_keywords) + if isinstance(item, dict) + else item + ) for item in value ] else: # Keep non-secret values as-is masked[key] = value - + return masked def mask_string( - text: str, - patterns: Optional[List[str]] = None, - mask: str = SECRET_MASK + text: str, patterns: Optional[List[str]] = None, mask: str = SECRET_MASK ) -> str: """ Mask sensitive patterns in strings (e.g., tokens in error messages). - + Args: text: String to mask patterns: List of regex patterns to match and mask mask: String to use for masked values - + Returns: String with matched patterns masked - + Examples: >>> mask_string('Token: abc123def456', patterns=[r'Token: \w+']) 'Token: β€’β€’β€’β€’β€’β€’β€’β€’' - + >>> mask_string('password=secret123', patterns=[r'password=\S+']) 'password=β€’β€’β€’β€’β€’β€’β€’β€’' """ if not patterns: # Default patterns for common secret formats patterns = [ - r'password[=:]\s*\S+', - r'token[=:]\s*\S+', - r'key[=:]\s*\S+', - r'secret[=:]\s*\S+', - r'api[_-]?key[=:]\s*\S+', + r"password[=:]\s*\S+", + r"token[=:]\s*\S+", + r"key[=:]\s*\S+", + r"secret[=:]\s*\S+", + r"api[_-]?key[=:]\s*\S+", ] - + masked_text = text for pattern in patterns: # Replace the value part after the = or : with mask masked_text = re.sub( pattern, - lambda m: re.sub(r'([=:])\s*\S+', r'\1' + mask, m.group(0)), + lambda m: re.sub(r"([=:])\s*\S+", r"\1" + mask, m.group(0)), masked_text, - flags=re.IGNORECASE + flags=re.IGNORECASE, ) - + return masked_text @@ -162,21 +175,21 @@ def safe_log_config( name: str = "Configuration", mask: str = SECRET_MASK, secret_fields: Optional[Set[str]] = None, - additional_keywords: Optional[List[str]] = None + additional_keywords: Optional[List[str]] = None, ) -> str: """ Create a safe log message for configuration with masked secrets. 
- + Args: config: Configuration dictionary name: Name for the configuration (e.g., "SMTP Config") mask: String to use for masked values secret_fields: Explicit set of field names to mask additional_keywords: Additional keywords to identify secret fields - + Returns: Formatted string safe for logging - + Examples: >>> config = {'host': 'smtp.gmail.com', 'password': 'secret', 'port': 587} >>> safe_log_config(config, "SMTP") @@ -189,73 +202,71 @@ def safe_log_config( def mask_connection_string(connection_string: str, mask: str = SECRET_MASK) -> str: """ Mask credentials in connection strings (URLs, DSNs). - + Args: connection_string: Connection string that may contain credentials mask: String to use for masked values - + Returns: Connection string with credentials masked - + Examples: >>> mask_connection_string('mongodb://user:pass123@localhost:27017/db') 'mongodb://user:β€’β€’β€’β€’β€’β€’β€’β€’@localhost:27017/db' - + >>> mask_connection_string('postgresql://admin:secret@db.example.com/mydb') 'postgresql://admin:β€’β€’β€’β€’β€’β€’β€’β€’@db.example.com/mydb' """ # Pattern: protocol://username:password@host return re.sub( - r'([a-zA-Z][a-zA-Z0-9+.-]*://[^:]+:)[^@]+(@)', - r'\1' + mask + r'\2', - connection_string + r"([a-zA-Z][a-zA-Z0-9+.-]*://[^:]+:)[^@]+(@)", + r"\1" + mask + r"\2", + connection_string, ) def create_masked_repr( - obj: Any, - secret_attrs: Set[str], - mask: str = SECRET_MASK + obj: Any, secret_attrs: Set[str], mask: str = SECRET_MASK ) -> str: """ Create a string representation of an object with masked secret attributes. - + Useful for __repr__ methods in classes that contain secrets. - + Args: obj: Object to represent secret_attrs: Set of attribute names that are secrets mask: String to use for masked values - + Returns: String representation with secrets masked - + Examples: >>> class Config: ... def __init__(self): ... self.host = 'smtp.gmail.com' ... self.password = 'secret123' - >>> + >>> >>> config = Config() >>> create_masked_repr(config, {'password'}) "Config(host='smtp.gmail.com', password='β€’β€’β€’β€’β€’β€’β€’β€’')" """ class_name = obj.__class__.__name__ attrs = [] - + for key in dir(obj): # Skip private/magic attributes and methods - if key.startswith('_') or callable(getattr(obj, key)): + if key.startswith("_") or callable(getattr(obj, key)): continue - + value = getattr(obj, key) - + # Mask secret attributes if key in secret_attrs: value_repr = f"'{mask}'" else: value_repr = repr(value) - + attrs.append(f"{key}={value_repr}") - + return f"{class_name}({', '.join(attrs)})" diff --git a/backends/advanced/src/advanced_omi_backend/utils/model_utils.py b/backends/advanced/src/advanced_omi_backend/utils/model_utils.py index 96042ba0..59adb0d9 100644 --- a/backends/advanced/src/advanced_omi_backend/utils/model_utils.py +++ b/backends/advanced/src/advanced_omi_backend/utils/model_utils.py @@ -11,22 +11,22 @@ def get_model_config(config_data: Dict[str, Any], model_role: str) -> Dict[str, Any]: """ Get model configuration for a given role from config.yml data. - + This function looks up the default model name for the given role in the 'defaults' section, then finds the corresponding model definition in the 'models' section. - + Args: config_data: Parsed config.yml data (dict with 'defaults' and 'models' keys) model_role: The role to look up (e.g., 'llm', 'embedding', 'stt', 'tts') - + Returns: Model configuration dictionary if found - + Raises: ValueError: If the default for the role is not found or the model definition is not found in the models list. 
- + Example: >>> from advanced_omi_backend.services.memory.config import load_config_yml >>> from advanced_omi_backend.utils.model_utils import get_model_config @@ -34,13 +34,16 @@ def get_model_config(config_data: Dict[str, Any], model_role: str) -> Dict[str, >>> llm_config = get_model_config(config_data, 'llm') >>> print(llm_config['model_name']) """ - default_name = config_data.get('defaults', {}).get(model_role) + default_name = config_data.get("defaults", {}).get(model_role) if not default_name: - raise ValueError(f"Configuration for 'defaults.{model_role}' not found in config.yml") - - for model in config_data.get('models', []): - if model.get('name') == default_name: + raise ValueError( + f"Configuration for 'defaults.{model_role}' not found in config.yml" + ) + + for model in config_data.get("models", []): + if model.get("name") == default_name: return model - - raise ValueError(f"Model '{default_name}' for role '{model_role}' not found in config.yml models list") + raise ValueError( + f"Model '{default_name}' for role '{model_role}' not found in config.yml models list" + ) diff --git a/backends/advanced/src/advanced_omi_backend/workers/__init__.py b/backends/advanced/src/advanced_omi_backend/workers/__init__.py index d4792805..d2439887 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/__init__.py +++ b/backends/advanced/src/advanced_omi_backend/workers/__init__.py @@ -32,50 +32,33 @@ from advanced_omi_backend.models.job import _ensure_beanie_initialized # Import from audio_jobs -from .audio_jobs import ( - audio_streaming_persistence_job, -) +from .audio_jobs import audio_streaming_persistence_job # Import from conversation_jobs -from .conversation_jobs import ( - open_conversation_job, -) +from .conversation_jobs import open_conversation_job # Import from memory_jobs -from .memory_jobs import ( - enqueue_memory_processing, - process_memory_job, -) +from .memory_jobs import enqueue_memory_processing, process_memory_job # Import from speaker_jobs -from .speaker_jobs import ( - check_enrolled_speakers_job, - recognise_speakers_job, -) +from .speaker_jobs import check_enrolled_speakers_job, recognise_speakers_job # Import from transcription_jobs -from .transcription_jobs import ( - stream_speech_detection_job, - transcribe_full_audio_job, -) +from .transcription_jobs import stream_speech_detection_job, transcribe_full_audio_job __all__ = [ # Transcription jobs "transcribe_full_audio_job", "stream_speech_detection_job", - # Speaker jobs "check_enrolled_speakers_job", "recognise_speakers_job", - # Conversation jobs "open_conversation_job", "audio_streaming_persistence_job", - # Memory jobs "process_memory_job", "enqueue_memory_processing", - # Queue utils "get_queue", "get_job_stats", diff --git a/backends/advanced/src/advanced_omi_backend/workers/audio_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/audio_jobs.py index 757bf29b..02448123 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/audio_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/audio_jobs.py @@ -30,7 +30,7 @@ async def audio_streaming_persistence_job( client_id: str, always_persist: bool = False, *, - redis_client=None + redis_client=None, ) -> Dict[str, Any]: """ Long-running RQ job that stores audio chunks in MongoDB with Opus compression. @@ -57,7 +57,9 @@ async def audio_streaming_persistence_job( cross-process config cache issues. 
""" - logger.info(f"🎡 Starting MongoDB audio persistence for session {session_id} (always_persist={always_persist})") + logger.info( + f"🎡 Starting MongoDB audio persistence for session {session_id} (always_persist={always_persist})" + ) # Setup audio persistence consumer group (separate from transcription consumer) audio_stream_name = f"audio:stream:{client_id}" @@ -66,12 +68,11 @@ async def audio_streaming_persistence_job( try: await redis_client.xgroup_create( - audio_stream_name, - audio_group_name, - "0", - mkstream=True + audio_stream_name, audio_group_name, "0", mkstream=True + ) + logger.info( + f"πŸ“¦ Created audio persistence consumer group for {audio_stream_name}" ) - logger.info(f"πŸ“¦ Created audio persistence consumer group for {audio_stream_name}") except Exception as e: if "BUSYGROUP" not in str(e): logger.warning(f"Failed to create audio consumer group: {e}") @@ -90,6 +91,7 @@ async def audio_streaming_persistence_job( if existing_conversation_id: existing_id_str = existing_conversation_id.decode() from advanced_omi_backend.models.conversation import Conversation + existing_conv = await Conversation.find_one( Conversation.conversation_id == existing_id_str ) @@ -118,15 +120,13 @@ async def audio_streaming_persistence_job( transcript_versions=[], memory_versions=[], processing_status="pending_transcription", - always_persist=True + always_persist=True, ) await conversation.insert() # Set conversation:current Redis key await redis_client.set( - conversation_key, - conversation.conversation_id, - ex=3600 # 1 hour expiry + conversation_key, conversation.conversation_id, ex=3600 # 1 hour expiry ) logger.info( @@ -185,6 +185,7 @@ async def audio_streaming_persistence_job( from rq import get_current_job from advanced_omi_backend.utils.job_utils import check_job_alive + current_job = get_current_job() async def flush_pcm_buffer() -> bool: @@ -207,7 +208,7 @@ async def flush_pcm_buffer() -> bool: pcm_data=bytes(pcm_buffer), sample_rate=SAMPLE_RATE, channels=CHANNELS, - bitrate=24 # 24kbps for speech + bitrate=24, # 24kbps for speech ) # Calculate chunk metadata @@ -247,7 +248,9 @@ async def flush_pcm_buffer() -> bool: # Calculate running totals chunk_count = chunk_index + 1 total_duration = end_time - compression_ratio = compressed_size / original_size if original_size > 0 else 0.0 + compression_ratio = ( + compressed_size / original_size if original_size > 0 else 0.0 + ) # Update conversation fields conversation.audio_chunks_count = chunk_count @@ -271,7 +274,9 @@ async def flush_pcm_buffer() -> bool: return True except Exception as e: - logger.error(f"❌ Failed to save audio chunk {chunk_index}: {e}", exc_info=True) + logger.error( + f"❌ Failed to save audio chunk {chunk_index}: {e}", exc_info=True + ) return False while True: @@ -303,7 +308,7 @@ async def flush_pcm_buffer() -> bool: audio_consumer_name, {audio_stream_name: ">"}, count=50, - block=500 + block=500, ) if final_messages: @@ -322,9 +327,13 @@ async def flush_pcm_buffer() -> bool: chunk_index += 1 chunk_start_time += CHUNK_DURATION_SECONDS - await redis_client.xack(audio_stream_name, audio_group_name, message_id) + await redis_client.xack( + audio_stream_name, audio_group_name, message_id + ) - logger.info(f"πŸ“¦ Final read processed {len(final_messages[0][1])} messages") + logger.info( + f"πŸ“¦ Final read processed {len(final_messages[0][1])} messages" + ) except Exception as e: logger.debug(f"Final audio read error (non-fatal): {e}") @@ -377,7 +386,11 @@ async def flush_pcm_buffer() -> bool: if 
current_conversation_id and len(pcm_buffer) > 0: # Flush final partial chunk await flush_pcm_buffer() - duration = (time.time() - conversation_start_time) if conversation_start_time else 0 + duration = ( + (time.time() - conversation_start_time) + if conversation_start_time + else 0 + ) logger.info( f"βœ… Conversation {current_conversation_id[:12]} ended: " f"{chunk_index + 1} chunks, {duration:.1f}s" @@ -399,7 +412,7 @@ async def flush_pcm_buffer() -> bool: audio_consumer_name, {audio_stream_name: ">"}, count=20, # Read up to 20 chunks at a time - block=100 # 100ms timeout + block=100, # 100ms timeout ) if audio_messages: @@ -429,13 +442,17 @@ async def flush_pcm_buffer() -> bool: chunk_start_time += CHUNK_DURATION_SECONDS # ACK the message - await redis_client.xack(audio_stream_name, audio_group_name, message_id) + await redis_client.xack( + audio_stream_name, audio_group_name, message_id + ) else: # No new messages if end_signal_received: consecutive_empty_reads += 1 - logger.info(f"πŸ“­ No new messages ({consecutive_empty_reads}/{max_empty_reads})") + logger.info( + f"πŸ“­ No new messages ({consecutive_empty_reads}/{max_empty_reads})" + ) if consecutive_empty_reads >= max_empty_reads: logger.info(f"βœ… Stream empty after END signal - stopping") @@ -455,7 +472,9 @@ async def flush_pcm_buffer() -> bool: # Calculate total duration if total_pcm_bytes > 0: duration = total_pcm_bytes / BYTES_PER_SECOND - compression_ratio = total_compressed_bytes / total_pcm_bytes if total_pcm_bytes > 0 else 0.0 + compression_ratio = ( + total_compressed_bytes / total_pcm_bytes if total_pcm_bytes > 0 else 0.0 + ) else: logger.warning(f"⚠️ No audio chunks written for session {session_id}") duration = 0.0 @@ -486,7 +505,7 @@ async def flush_pcm_buffer() -> bool: "total_compressed_bytes": total_compressed_bytes, "compression_ratio": compression_ratio, "duration_seconds": duration, - "runtime_seconds": runtime_seconds + "runtime_seconds": runtime_seconds, } diff --git a/backends/advanced/src/advanced_omi_backend/workers/audio_stream_worker.py b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_worker.py index 026a3e8e..3328ba2f 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/audio_stream_worker.py +++ b/backends/advanced/src/advanced_omi_backend/workers/audio_stream_worker.py @@ -25,8 +25,7 @@ from advanced_omi_backend.speaker_recognition_client import SpeakerRecognitionClient logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" + level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s" ) logger = logging.getLogger(__name__) @@ -35,16 +34,16 @@ async def main(): """Main worker entry point.""" logger.info("πŸš€ Starting streaming transcription worker") - logger.info("πŸ“‹ Provider configuration loaded from config.yml (defaults.stt_stream)") + logger.info( + "πŸ“‹ Provider configuration loaded from config.yml (defaults.stt_stream)" + ) redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") # Create Redis client try: redis_client = await redis.from_url( - redis_url, - encoding="utf-8", - decode_responses=False + redis_url, encoding="utf-8", decode_responses=False ) logger.info(f"βœ… Connected to Redis: {redis_url}") @@ -59,15 +58,21 @@ async def main(): try: plugin_router = init_plugin_router() if plugin_router: - logger.info(f"βœ… Plugin router initialized with {len(plugin_router.plugins)} plugins") + logger.info( + f"βœ… Plugin router initialized with {len(plugin_router.plugins)} plugins" + ) # 
Initialize async plugins for plugin_id, plugin in plugin_router.plugins.items(): try: await plugin.initialize() - logger.info(f"✅ Plugin '{plugin_id}' initialized in streaming worker") + logger.info( + f"✅ Plugin '{plugin_id}' initialized in streaming worker" + ) except Exception as e: - logger.exception(f"Failed to initialize plugin '{plugin_id}' in streaming worker: {e}") + logger.exception( + f"Failed to initialize plugin '{plugin_id}' in streaming worker: {e}" + ) else: logger.warning("No plugin router available - plugins will not be triggered") except Exception as e: @@ -78,9 +83,13 @@ async def main(): try: speaker_client = SpeakerRecognitionClient() if speaker_client.enabled: - logger.info(f"Speaker recognition client initialized: {speaker_client.service_url}") + logger.info( + f"Speaker recognition client initialized: {speaker_client.service_url}" + ) else: - logger.info("Speaker recognition disabled — streaming speaker identification off") + logger.info( + "Speaker recognition disabled — streaming speaker identification off" + ) speaker_client = None except Exception as e: logger.warning(f"Failed to initialize speaker recognition client: {e}") @@ -95,8 +104,12 @@ async def main(): ) logger.info("Streaming transcription consumer created") except Exception as e: - logger.error(f"Failed to create streaming transcription consumer: {e}", exc_info=True) - logger.error("Ensure config.yml has defaults.stt_stream configured with valid provider") + logger.error( + f"Failed to create streaming transcription consumer: {e}", exc_info=True + ) + logger.error( + "Ensure config.yml has defaults.stt_stream configured with valid provider" + ) await redis_client.aclose() sys.exit(1) @@ -111,7 +124,9 @@ def signal_handler(signum, frame): try: logger.info("✅ Streaming transcription worker ready") logger.info("📡 Listening for audio streams on audio:stream:* pattern") - logger.info("📢 Publishing interim results to transcription:interim:{session_id}") + logger.info( + "📢 Publishing interim results to transcription:interim:{session_id}" + ) logger.info("💾 Publishing final results to transcription:results:{session_id}") # This blocks until consumer is stopped diff --git a/backends/advanced/src/advanced_omi_backend/workers/cleanup_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/cleanup_jobs.py index 99ea5869..acc76d2f 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/cleanup_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/cleanup_jobs.py @@ -4,6 +4,7 @@ Provides manual cleanup of soft-deleted conversations and chunks. Auto-cleanup is controlled via admin API settings (stored in /app/data/cleanup_config.json). """ + import logging from datetime import datetime, timedelta from typing import Optional @@ -19,8 +20,7 @@ @async_job(redis=False, beanie=True, timeout=1800)  # 30 minute timeout async def purge_old_deleted_conversations( - retention_days: Optional[int] = None, - dry_run: bool = False + retention_days: Optional[int] = None, dry_run: bool = False ) -> dict: """ Permanently delete conversations that have been soft-deleted for longer than retention period. 
@@ -35,16 +35,17 @@ async def purge_old_deleted_conversations( # Get retention period from config if not specified if retention_days is None: settings_dict = get_cleanup_settings() - retention_days = settings_dict['retention_days'] + retention_days = settings_dict["retention_days"] cutoff_date = datetime.utcnow() - timedelta(days=retention_days) - logger.info(f"{'[DRY RUN] ' if dry_run else ''}Purging conversations deleted before {cutoff_date.isoformat()}") + logger.info( + f"{'[DRY RUN] ' if dry_run else ''}Purging conversations deleted before {cutoff_date.isoformat()}" + ) # Find soft-deleted conversations older than cutoff old_deleted = await Conversation.find( - Conversation.deleted == True, - Conversation.deleted_at < cutoff_date + Conversation.deleted == True, Conversation.deleted_at < cutoff_date ).to_list() purged_conversations = 0 @@ -129,7 +130,7 @@ def schedule_cleanup_job(retention_days: Optional[int] = None) -> Optional[str]: """ # Check if auto-cleanup is enabled settings_dict = get_cleanup_settings() - if not settings_dict['auto_cleanup_enabled']: + if not settings_dict["auto_cleanup_enabled"]: logger.info("Auto-cleanup is disabled (auto_cleanup_enabled=false)") return None @@ -137,7 +138,7 @@ def schedule_cleanup_job(retention_days: Optional[int] = None) -> Optional[str]: from advanced_omi_backend.controllers.queue_controller import get_queue if retention_days is None: - retention_days = settings_dict['retention_days'] + retention_days = settings_dict["retention_days"] queue = get_queue("default") job = queue.enqueue( @@ -146,10 +147,11 @@ def schedule_cleanup_job(retention_days: Optional[int] = None) -> Optional[str]: dry_run=False, job_timeout="30m", ) - logger.info(f"Scheduled cleanup job {job.id} with {retention_days}-day retention") + logger.info( + f"Scheduled cleanup job {job.id} with {retention_days}-day retention" + ) return job.id except Exception as e: logger.error(f"Failed to schedule cleanup job: {e}") return None - diff --git a/backends/advanced/src/advanced_omi_backend/workers/finetuning_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/finetuning_jobs.py index 9b6691c9..332b1b96 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/finetuning_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/finetuning_jobs.py @@ -39,6 +39,7 @@ # Job 1: Speaker Fine-tuning # --------------------------------------------------------------------------- + async def run_speaker_finetuning_job() -> dict: """Process applied diarization annotations and send to speaker recognition service. 
@@ -136,13 +137,17 @@ async def run_speaker_finetuning_job() -> dict: # Mark as trained annotation.processed_by = ( - f"{annotation.processed_by},training" if annotation.processed_by else "training" + f"{annotation.processed_by},training" + if annotation.processed_by + else "training" ) annotation.updated_at = datetime.now(timezone.utc) await annotation.save() except Exception as e: - logger.error(f"Speaker finetuning: error processing annotation {annotation.id}: {e}") + logger.error( + f"Speaker finetuning: error processing annotation {annotation.id}: {e}" + ) failed += 1 total = enrolled + appended @@ -150,7 +155,13 @@ async def run_speaker_finetuning_job() -> dict: f"Speaker finetuning complete: {total} processed " f"({enrolled} new, {appended} appended, {failed} failed, {cleaned} orphaned cleaned)" ) - return {"enrolled": enrolled, "appended": appended, "failed": failed, "cleaned": cleaned, "processed": total} + return { + "enrolled": enrolled, + "appended": appended, + "failed": failed, + "cleaned": cleaned, + "processed": total, + } # --------------------------------------------------------------------------- @@ -174,12 +185,14 @@ def _build_vibevoice_label(conversation) -> dict: segments = [] for seg in transcript.segments: speaker_id = speaker_map.setdefault(seg.speaker, len(speaker_map)) - segments.append({ - "speaker": speaker_id, - "text": seg.text, - "start": round(seg.start, 2), - "end": round(seg.end, 2), - }) + segments.append( + { + "speaker": speaker_id, + "text": seg.text, + "start": round(seg.start, 2), + "end": round(seg.end, 2), + } + ) return { "audio_path": f"{conversation.conversation_id}.wav", @@ -198,31 +211,49 @@ async def run_asr_finetuning_job() -> dict: from advanced_omi_backend.model_registry import get_models_registry from advanced_omi_backend.models.annotation import Annotation, AnnotationType from advanced_omi_backend.models.conversation import Conversation - from advanced_omi_backend.utils.audio_chunk_utils import reconstruct_wav_from_conversation + from advanced_omi_backend.utils.audio_chunk_utils import ( + reconstruct_wav_from_conversation, + ) # Resolve STT service URL from model registry (same URL used for transcription) registry = get_models_registry() stt_model = registry.get_default("stt") if registry else None if not stt_model or not stt_model.model_url: logger.warning("ASR finetuning: no STT model configured in registry, skipping") - return {"conversations_exported": 0, "annotations_consumed": 0, "message": "No STT model configured"} + return { + "conversations_exported": 0, + "annotations_consumed": 0, + "message": "No STT model configured", + } vibevoice_url = stt_model.model_url.rstrip("/") # Find applied annotations (TRANSCRIPT and DIARIZATION) not yet consumed by ASR training annotations = await Annotation.find( - {"annotation_type": {"$in": [AnnotationType.TRANSCRIPT.value, AnnotationType.DIARIZATION.value]}}, + { + "annotation_type": { + "$in": [ + AnnotationType.TRANSCRIPT.value, + AnnotationType.DIARIZATION.value, + ] + } + }, Annotation.processed == True, ).to_list() ready = [ - a for a in annotations + a + for a in annotations if not a.processed_by or _ASR_TRAINING_MARKER not in a.processed_by ] if not ready: logger.info("ASR finetuning: no annotations ready for export") - return {"conversations_exported": 0, "annotations_consumed": 0, "message": "No annotations ready"} + return { + "conversations_exported": 0, + "annotations_consumed": 0, + "message": "No annotations ready", + } # Group annotations by conversation_id by_conversation: 
dict[str, list[Annotation]] = {} @@ -245,36 +276,52 @@ async def run_asr_finetuning_job() -> dict: Conversation.conversation_id == conv_id ) if not conversation or not conversation.active_transcript: - logger.warning(f"ASR finetuning: conversation {conv_id} not found or no transcript") + logger.warning( + f"ASR finetuning: conversation {conv_id} not found or no transcript" + ) errors += 1 continue if not conversation.active_transcript.segments: - logger.info(f"ASR finetuning: conversation {conv_id} has no segments, skipping") + logger.info( + f"ASR finetuning: conversation {conv_id} has no segments, skipping" + ) continue # Reconstruct full WAV audio wav_data = await reconstruct_wav_from_conversation(conv_id) if not wav_data: - logger.warning(f"ASR finetuning: no audio for conversation {conv_id}") + logger.warning( + f"ASR finetuning: no audio for conversation {conv_id}" + ) errors += 1 continue # Build training label label = _build_vibevoice_label(conversation) if not label.get("segments"): - logger.info(f"ASR finetuning: no segments in label for {conv_id}, skipping") + logger.info( + f"ASR finetuning: no segments in label for {conv_id}, skipping" + ) continue # Try to add jargon context from Redis cache if conversation.user_id: - jargon = await redis_client.get(f"asr:jargon:{conversation.user_id}") + jargon = await redis_client.get( + f"asr:jargon:{conversation.user_id}" + ) if jargon: - label["customized_context"] = [t.strip() for t in jargon.split(",") if t.strip()] + label["customized_context"] = [ + t.strip() for t in jargon.split(",") if t.strip() + ] # POST to VibeVoice /fine-tune endpoint files = { - "audio_files": (f"{conv_id}.wav", io.BytesIO(wav_data), "audio/wav"), + "audio_files": ( + f"{conv_id}.wav", + io.BytesIO(wav_data), + "audio/wav", + ), } data = {"labels": json.dumps([label])} @@ -307,7 +354,9 @@ async def run_asr_finetuning_job() -> dict: consumed += 1 except Exception as e: - logger.error(f"ASR finetuning: error processing conversation {conv_id}: {e}") + logger.error( + f"ASR finetuning: error processing conversation {conv_id}: {e}" + ) errors += 1 finally: @@ -328,6 +377,7 @@ async def run_asr_finetuning_job() -> dict: # Job 3: ASR Jargon Extraction # --------------------------------------------------------------------------- + async def run_asr_jargon_extraction_job() -> dict: """Extract jargon from recent memories for all users and cache in Redis.""" from advanced_omi_backend.models.user import User @@ -344,7 +394,9 @@ async def run_asr_jargon_extraction_job() -> dict: try: jargon = await _extract_jargon_for_user(user_id) if jargon: - await redis_client.set(f"asr:jargon:{user_id}", jargon, ex=JARGON_CACHE_TTL) + await redis_client.set( + f"asr:jargon:{user_id}", jargon, ex=JARGON_CACHE_TTL + ) processed += 1 logger.debug(f"Cached jargon for user {user_id}: {jargon[:80]}...") else: @@ -398,7 +450,9 @@ async def _extract_jargon_for_user(user_id: str) -> Optional[str]: # Use LLM to extract jargon registry = get_prompt_registry() - prompt_template = await registry.get_prompt("asr.jargon_extraction", memories=memory_text) + prompt_template = await registry.get_prompt( + "asr.jargon_extraction", memories=memory_text + ) result = await async_generate(prompt_template) diff --git a/backends/advanced/src/advanced_omi_backend/workers/obsidian_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/obsidian_jobs.py index 8c67616d..69a550a4 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/obsidian_jobs.py +++ 
b/backends/advanced/src/advanced_omi_backend/workers/obsidian_jobs.py @@ -26,7 +26,7 @@ def count_markdown_files(vault_path: str) -> int: @async_job(redis=True, beanie=False) -async def ingest_obsidian_vault_job(job_id: str, vault_path: str, redis_client=None) -> dict: # type: ignore +async def ingest_obsidian_vault_job(job_id: str, vault_path: str, redis_client=None) -> dict: # type: ignore """ Long-running ingestion job enqueued on the default RQ queue. """ @@ -80,16 +80,18 @@ async def ingest_obsidian_vault_job(job_id: str, vault_path: str, redis_client=N return {"status": "canceled"} try: - note_data = obsidian_service.parse_obsidian_note(root, filename, vault_path) + note_data = obsidian_service.parse_obsidian_note( + root, filename, vault_path + ) chunks = await obsidian_service.chunking_and_embedding(note_data) if chunks: obsidian_service.ingest_note_and_chunks(note_data, chunks) - + processed += 1 job.meta["processed"] = processed job.meta["last_file"] = os.path.join(root, filename) job.save_meta() - + except Exception as exc: logger.error("Processing %s failed: %s", filename, exc) errors.append(f"{filename}: {exc}") @@ -103,5 +105,5 @@ async def ingest_obsidian_vault_job(job_id: str, vault_path: str, redis_client=N "status": "finished", "processed": processed, "total": total, - "errors": errors + "errors": errors, } diff --git a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/health_monitor.py b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/health_monitor.py index 3cec3810..c97812cc 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/health_monitor.py +++ b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/health_monitor.py @@ -105,7 +105,9 @@ async def _monitor_loop(self): raise except Exception as e: logger.error(f"Health monitor loop error: {e}", exc_info=True) - self.running = False # Mark monitor as stopped so callers know it's not active + self.running = ( + False # Mark monitor as stopped so callers know it's not active + ) raise # Re-raise to ensure the monitor task fails properly async def _check_health(self): @@ -311,7 +313,9 @@ def _handle_registration_loss(self): if success: logger.info("βœ… Bulk restart completed - workers should re-register soon") else: - logger.error("❌ Bulk restart encountered errors - check individual worker logs") + logger.error( + "❌ Bulk restart encountered errors - check individual worker logs" + ) def _restart_all_rq_workers(self) -> bool: """ diff --git a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/process_manager.py b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/process_manager.py index 5448b96f..909d9b0c 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/process_manager.py +++ b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/process_manager.py @@ -270,9 +270,7 @@ def restart_worker(self, name: str, timeout: int = 30) -> bool: worker.state = WorkerState.FAILED return False - logger.info( - f"{name}: Stopped in {stop_duration:.2f}s (timeout was {timeout}s)" - ) + logger.info(f"{name}: Stopped in {stop_duration:.2f}s (timeout was {timeout}s)") # START phase with timing start_start = time.time() diff --git a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/worker_registry.py b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/worker_registry.py index 3bbf93cc..011c1de4 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/orchestrator/worker_registry.py 
+++ b/backends/advanced/src/advanced_omi_backend/workers/orchestrator/worker_registry.py @@ -103,14 +103,19 @@ def build_worker_definitions() -> List[WorkerDefinition]: # Log worker configuration try: from advanced_omi_backend.model_registry import get_models_registry + registry = get_models_registry() if registry: stt_stream = registry.get_default("stt_stream") stt_batch = registry.get_default("stt") if stt_stream: - logger.info(f"Streaming STT configured: {stt_stream.name} ({stt_stream.model_provider})") + logger.info( + f"Streaming STT configured: {stt_stream.name} ({stt_stream.model_provider})" + ) if stt_batch: - logger.info(f"Batch STT configured: {stt_batch.name} ({stt_batch.model_provider}) - handled by RQ workers") + logger.info( + f"Batch STT configured: {stt_batch.name} ({stt_batch.model_provider}) - handled by RQ workers" + ) except Exception as e: logger.warning(f"Failed to log STT configuration: {e}") @@ -119,9 +124,7 @@ def build_worker_definitions() -> List[WorkerDefinition]: logger.info(f"Total workers configured: {len(workers)}") logger.info(f"Enabled workers: {len(enabled_workers)}") - logger.info( - f"Enabled worker names: {', '.join([w.name for w in enabled_workers])}" - ) + logger.info(f"Enabled worker names: {', '.join([w.name for w in enabled_workers])}") if disabled_workers: logger.info( diff --git a/backends/advanced/src/advanced_omi_backend/workers/prompt_optimization_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/prompt_optimization_jobs.py index c98285a3..6824f4a4 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/prompt_optimization_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/prompt_optimization_jobs.py @@ -52,9 +52,7 @@ async def run_prompt_optimization_job() -> dict: langfuse_client = registry._get_client() if langfuse_client is None: - logger.warning( - "Prompt optimization: LangFuse not configured β€” skipping" - ) + logger.warning("Prompt optimization: LangFuse not configured β€” skipping") return {"skipped": True, "reason": "LangFuse not available"} total_users = 0 @@ -154,9 +152,7 @@ async def run_prompt_optimization_job() -> dict: "optimized_at": datetime.now(timezone.utc).isoformat(), }, ) - logger.info( - f"Created new LangFuse prompt '{user_prompt_name}'" - ) + logger.info(f"Created new LangFuse prompt '{user_prompt_name}'") except Exception as e: err_msg = str(e).lower() if "already exists" in err_msg or "409" in err_msg: diff --git a/backends/advanced/src/advanced_omi_backend/workers/waveform_jobs.py b/backends/advanced/src/advanced_omi_backend/workers/waveform_jobs.py index f58387cd..1726d61c 100644 --- a/backends/advanced/src/advanced_omi_backend/workers/waveform_jobs.py +++ b/backends/advanced/src/advanced_omi_backend/workers/waveform_jobs.py @@ -53,20 +53,24 @@ async def generate_waveform_data( waveform_gen_time = 0.0 try: - logger.info(f"🎡 Generating waveform for conversation {conversation_id[:12]}... (sample_rate={sample_rate} samples/sec)") + logger.info( + f"🎡 Generating waveform for conversation {conversation_id[:12]}... 
(sample_rate={sample_rate} samples/sec)" + ) # Retrieve all audio chunks for conversation fetch_start = time.time() chunks = await retrieve_audio_chunks(conversation_id=conversation_id) fetch_time = time.time() - fetch_start - logger.info(f"πŸ“¦ Fetched {len(chunks) if chunks else 0} chunks from MongoDB in {fetch_time:.2f}s") + logger.info( + f"πŸ“¦ Fetched {len(chunks) if chunks else 0} chunks from MongoDB in {fetch_time:.2f}s" + ) if not chunks: logger.warning(f"No audio chunks found for conversation {conversation_id}") return { "success": False, - "error": "No audio chunks found for this conversation" + "error": "No audio chunks found for this conversation", } # Get audio format from first chunk @@ -162,29 +166,30 @@ async def generate_waveform_data( samples=waveform_samples, sample_rate=sample_rate, duration_seconds=total_duration, - processing_time_seconds=processing_time + processing_time_seconds=processing_time, ) await waveform_doc.insert() - logger.info(f"πŸ’Ύ Saved waveform to MongoDB for conversation {conversation_id[:12]}") + logger.info( + f"πŸ’Ύ Saved waveform to MongoDB for conversation {conversation_id[:12]}" + ) return { "success": True, "samples": waveform_samples, "sample_rate": sample_rate, "duration_seconds": total_duration, - "processing_time_seconds": processing_time + "processing_time_seconds": processing_time, } except Exception as e: processing_time = time.time() - start_time logger.error( - f"❌ Waveform generation failed for {conversation_id}: {e}", - exc_info=True + f"❌ Waveform generation failed for {conversation_id}: {e}", exc_info=True ) return { "success": False, "error": str(e), - "processing_time_seconds": processing_time + "processing_time_seconds": processing_time, } diff --git a/backends/advanced/src/scripts/cleanup_state.py b/backends/advanced/src/scripts/cleanup_state.py index 49bfd332..ad3aba3f 100644 --- a/backends/advanced/src/scripts/cleanup_state.py +++ b/backends/advanced/src/scripts/cleanup_state.py @@ -35,7 +35,13 @@ from qdrant_client.models import Distance, VectorParams from rich.console import Console from rich.panel import Panel - from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn + from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, + ) from rich.prompt import Confirm from rich.table import Table from rich.text import Text @@ -62,12 +68,18 @@ # Helpers # --------------------------------------------------------------------------- + def get_qdrant_collection_name() -> str: """Get Qdrant collection name from memory service configuration.""" try: memory_config = build_memory_config_from_env() - if hasattr(memory_config, "vector_store_config") and memory_config.vector_store_config: - return memory_config.vector_store_config.get("collection_name", "chronicle_memories") + if ( + hasattr(memory_config, "vector_store_config") + and memory_config.vector_store_config + ): + return memory_config.vector_store_config.get( + "collection_name", "chronicle_memories" + ) except Exception: pass return "chronicle_memories" @@ -93,6 +105,7 @@ def _human_size(nbytes: int) -> str: # Stats # --------------------------------------------------------------------------- + class Stats: """Track counts across the system.""" @@ -177,7 +190,9 @@ async def gather_stats( if langfuse_client: try: prompts_response = langfuse_client.prompts.list(limit=100) - s.langfuse_prompts = len(prompts_response.data) if hasattr(prompts_response, "data") else 0 + s.langfuse_prompts = ( + 
len(prompts_response.data) if hasattr(prompts_response, "data") else 0 + ) except Exception: pass @@ -205,13 +220,21 @@ def render_stats_table(stats: Stats, title: str = "Current State") -> Table: def row(label, value, style="white"): table.add_row(label, f"[{style}]{value}[/{style}]") - row("Conversations", str(stats.conversations), "green" if stats.conversations else "dim") + row( + "Conversations", + str(stats.conversations), + "green" if stats.conversations else "dim", + ) row( " with transcripts", str(stats.conversations_with_transcript), "green" if stats.conversations_with_transcript else "dim", ) - row("Audio Chunks", str(stats.audio_chunks), "green" if stats.audio_chunks else "dim") + row( + "Audio Chunks", + str(stats.audio_chunks), + "green" if stats.audio_chunks else "dim", + ) row("Waveforms", str(stats.waveforms), "dim") row("Chat Sessions", str(stats.chat_sessions), "dim") row("Chat Messages", str(stats.chat_messages), "dim") @@ -220,7 +243,11 @@ def row(label, value, style="white"): row("Memories (Qdrant)", str(stats.memories), "yellow" if stats.memories else "dim") row("Neo4j Nodes", str(stats.neo4j_nodes), "dim") row("Neo4j Relationships", str(stats.neo4j_relationships), "dim") - row("LangFuse Prompts", str(stats.langfuse_prompts), "yellow" if stats.langfuse_prompts else "dim") + row( + "LangFuse Prompts", + str(stats.langfuse_prompts), + "yellow" if stats.langfuse_prompts else "dim", + ) table.add_section() row("Redis Jobs", str(stats.redis_jobs), "dim") row("Legacy WAV Files", str(stats.legacy_wav), "dim") @@ -234,6 +261,7 @@ def row(label, value, style="white"): # Backup # --------------------------------------------------------------------------- + class BackupResult: """Track which backup exports succeeded or failed.""" @@ -241,7 +269,13 @@ def __init__(self): self.exports: dict[str, dict] = {} # name -> {ok, path, size, sha256, error} def record(self, name: str, path: Optional[Path], ok: bool, error: str = ""): - entry = {"ok": ok, "error": error, "path": str(path) if path else None, "size": 0, "sha256": ""} + entry = { + "ok": ok, + "error": error, + "path": str(path) if path else None, + "size": 0, + "sha256": "", + } if ok and path and path.exists(): entry["size"] = path.stat().st_size entry["sha256"] = _file_sha256(path) @@ -255,10 +289,16 @@ def all_ok(self) -> bool: def critical_ok(self) -> bool: """conversations, audio_metadata, and annotations are critical.""" critical = ("conversations", "audio_metadata", "annotations") - return all(self.exports.get(n, {}).get("ok", False) for n in critical if n in self.exports) + return all( + self.exports.get(n, {}).get("ok", False) + for n in critical + if n in self.exports + ) def render_table(self) -> Table: - table = Table(title="Backup Verification", border_style="dim", title_style="bold white") + table = Table( + title="Backup Verification", border_style="dim", title_style="bold white" + ) table.add_column("Export", style="white", min_width=24) table.add_column("Status", justify="center", min_width=8) table.add_column("Size", justify="right", min_width=10) @@ -285,7 +325,14 @@ def total_size(self) -> int: class BackupManager: """Export data to a timestamped backup directory.""" - def __init__(self, backup_dir: str, export_audio: bool, mongo_db: Any, neo4j_driver: Any = None, langfuse_client: Any = None): + def __init__( + self, + backup_dir: str, + export_audio: bool, + mongo_db: Any, + neo4j_driver: Any = None, + langfuse_client: Any = None, + ): self.backup_dir = Path(backup_dir) self.export_audio = export_audio 
self.mongo_db = mongo_db @@ -323,7 +370,9 @@ async def run( steps.append(("audio_wav", self._export_audio_wav)) if qdrant_client: - steps.append(("memories", lambda r: self._export_memories(qdrant_client, r))) + steps.append( + ("memories", lambda r: self._export_memories(qdrant_client, r)) + ) if self.neo4j_driver: steps.append(("neo4j_graph", self._export_neo4j)) @@ -336,7 +385,11 @@ async def run( for name, func in steps: progress.update(task, description=f"Exporting {name}...") try: - path = await func(result) if asyncio.iscoroutinefunction(func) else func(result) + path = ( + await func(result) + if asyncio.iscoroutinefunction(func) + else func(result) + ) if not result.exports.get(name): # func didn't record itself - record success result.record(name, path, True) @@ -384,19 +437,21 @@ async def _export_audio_metadata(self, result: BackupResult) -> Path: cursor = collection.find({}) data = [] async for chunk in cursor: - data.append({ - "conversation_id": chunk.get("conversation_id"), - "chunk_index": chunk.get("chunk_index"), - "start_time": chunk.get("start_time"), - "end_time": chunk.get("end_time"), - "duration": chunk.get("duration"), - "original_size": chunk.get("original_size"), - "compressed_size": chunk.get("compressed_size"), - "sample_rate": chunk.get("sample_rate", 16000), - "channels": chunk.get("channels", 1), - "has_speech": chunk.get("has_speech"), - "created_at": str(chunk.get("created_at", "")), - }) + data.append( + { + "conversation_id": chunk.get("conversation_id"), + "chunk_index": chunk.get("chunk_index"), + "start_time": chunk.get("start_time"), + "end_time": chunk.get("end_time"), + "duration": chunk.get("duration"), + "original_size": chunk.get("original_size"), + "compressed_size": chunk.get("compressed_size"), + "sample_rate": chunk.get("sample_rate", 16000), + "channels": chunk.get("channels", 1), + "has_speech": chunk.get("has_speech"), + "created_at": str(chunk.get("created_at", "")), + } + ) path = self.backup_path / "audio_chunks_metadata.json" with open(path, "w") as f: json.dump(data, f, indent=2, default=str) @@ -417,14 +472,16 @@ async def _export_chat_sessions(self, result: BackupResult) -> Path: cursor = collection.find({}) data = [] async for session in cursor: - data.append({ - "session_id": session.get("session_id"), - "user_id": session.get("user_id"), - "title": session.get("title"), - "created_at": str(session.get("created_at", "")), - "updated_at": str(session.get("updated_at", "")), - "metadata": session.get("metadata", {}), - }) + data.append( + { + "session_id": session.get("session_id"), + "user_id": session.get("user_id"), + "title": session.get("title"), + "created_at": str(session.get("created_at", "")), + "updated_at": str(session.get("updated_at", "")), + "metadata": session.get("metadata", {}), + } + ) path = self.backup_path / "chat_sessions.json" with open(path, "w") as f: json.dump(data, f, indent=2, default=str) @@ -436,16 +493,18 @@ async def _export_chat_messages(self, result: BackupResult) -> Path: cursor = collection.find({}) data = [] async for msg in cursor: - data.append({ - "message_id": msg.get("message_id"), - "session_id": msg.get("session_id"), - "user_id": msg.get("user_id"), - "role": msg.get("role"), - "content": msg.get("content"), - "timestamp": str(msg.get("timestamp", "")), - "memories_used": msg.get("memories_used", []), - "metadata": msg.get("metadata", {}), - }) + data.append( + { + "message_id": msg.get("message_id"), + "session_id": msg.get("session_id"), + "user_id": msg.get("user_id"), + "role": 
msg.get("role"), + "content": msg.get("content"), + "timestamp": str(msg.get("timestamp", "")), + "memories_used": msg.get("memories_used", []), + "metadata": msg.get("metadata", {}), + } + ) path = self.backup_path / "chat_messages.json" with open(path, "w") as f: json.dump(data, f, indent=2, default=str) @@ -479,7 +538,9 @@ async def _export_audio_wav(self, result: BackupResult) -> Optional[Path]: for conv in conversations: try: - ok = await self._export_conversation_audio(conv.conversation_id, audio_dir) + ok = await self._export_conversation_audio( + conv.conversation_id, audio_dir + ) if ok: exported += 1 except Exception as e: @@ -491,11 +552,17 @@ async def _export_audio_wav(self, result: BackupResult) -> Optional[Path]: result.record("audio_wav", audio_dir, ok, error) return audio_dir - async def _export_conversation_audio(self, conversation_id: str, audio_dir: Path) -> bool: + async def _export_conversation_audio( + self, conversation_id: str, audio_dir: Path + ) -> bool: """Decode Opus chunks to WAV for a single conversation. Returns True if audio was exported.""" - chunks = await AudioChunkDocument.find( - AudioChunkDocument.conversation_id == conversation_id - ).sort("+chunk_index").to_list() + chunks = ( + await AudioChunkDocument.find( + AudioChunkDocument.conversation_id == conversation_id + ) + .sort("+chunk_index") + .to_list() + ) if not chunks: return False @@ -543,7 +610,9 @@ async def _export_conversation_audio(self, conversation_id: str, audio_dir: Path return True - async def _export_memories(self, qdrant_client: AsyncQdrantClient, result: BackupResult) -> Path: + async def _export_memories( + self, qdrant_client: AsyncQdrantClient, result: BackupResult + ) -> Path: collection_name = get_qdrant_collection_name() collections = await qdrant_client.get_collections() exists = any(c.name == collection_name for c in collections.collections) @@ -568,7 +637,9 @@ async def _export_memories(self, qdrant_client: AsyncQdrantClient, result: Backu if not points: break for pt in points: - data.append({"id": str(pt.id), "vector": pt.vector, "payload": pt.payload}) + data.append( + {"id": str(pt.id), "vector": pt.vector, "payload": pt.payload} + ) if next_offset is None: break offset = next_offset @@ -583,7 +654,9 @@ def _export_neo4j(self, result: BackupResult) -> Path: try: with self.neo4j_driver.session() as session: nodes_data = [] - for record in session.run("MATCH (n) RETURN n, labels(n) AS labels, elementId(n) AS eid"): + for record in session.run( + "MATCH (n) RETURN n, labels(n) AS labels, elementId(n) AS eid" + ): node = dict(record["n"]) node["_labels"] = record["labels"] node["_element_id"] = record["eid"] @@ -594,15 +667,24 @@ def _export_neo4j(self, result: BackupResult) -> Path: "MATCH (a)-[r]->(b) RETURN elementId(a) AS src, type(r) AS rel_type, " "properties(r) AS props, elementId(b) AS dst" ): - rels_data.append({ - "source": record["src"], - "type": record["rel_type"], - "properties": dict(record["props"]) if record["props"] else {}, - "target": record["dst"], - }) + rels_data.append( + { + "source": record["src"], + "type": record["rel_type"], + "properties": ( + dict(record["props"]) if record["props"] else {} + ), + "target": record["dst"], + } + ) with open(path, "w") as f: - json.dump({"nodes": nodes_data, "relationships": rels_data}, f, indent=2, default=str) + json.dump( + {"nodes": nodes_data, "relationships": rels_data}, + f, + indent=2, + default=str, + ) result.record("neo4j_graph", path, True) except Exception as e: result.record("neo4j_graph", None, 
False, str(e)) @@ -654,6 +736,7 @@ def _export_langfuse_prompts(self, result: BackupResult) -> Path: # Cleanup # --------------------------------------------------------------------------- + class CleanupManager: """Delete data across all services.""" @@ -768,6 +851,7 @@ def _cleanup_legacy_wav(self, stats: Stats): # Connection setup # --------------------------------------------------------------------------- + async def connect_services(): """Initialize all service connections. Returns (mongo_db, redis_conn, qdrant_client, neo4j_driver, langfuse_client).""" # MongoDB @@ -777,7 +861,13 @@ async def connect_services(): mongo_db = mongo_client[mongodb_database] await init_beanie( database=mongo_db, - document_models=[Conversation, AudioChunkDocument, WaveformData, User, Annotation], + document_models=[ + Conversation, + AudioChunkDocument, + WaveformData, + User, + Annotation, + ], ) # Redis @@ -827,6 +917,7 @@ async def connect_services(): # Display helpers # --------------------------------------------------------------------------- + def print_header(): console.print() console.print( @@ -841,11 +932,19 @@ def print_header(): def print_dry_run(stats: Stats, args): console.print() - console.print(Panel("[bold yellow]DRY-RUN MODE[/bold yellow] - no changes will be made", border_style="yellow")) + console.print( + Panel( + "[bold yellow]DRY-RUN MODE[/bold yellow] - no changes will be made", + border_style="yellow", + ) + ) console.print() if args.backup or args.backup_only: - console.print("[cyan]Would create backup at:[/cyan]", str(Path(args.backup_dir) / f"backup_...")) + console.print( + "[cyan]Would create backup at:[/cyan]", + str(Path(args.backup_dir) / f"backup_..."), + ) if args.export_audio: audio_note = f"(from {stats.conversations_with_transcript} conversations with transcripts)" console.print(f"[cyan]Would export audio WAV files[/cyan] {audio_note}") @@ -887,18 +986,26 @@ def print_confirmation(stats: Stats, args) -> bool: console.print() if args.backup or args.backup_only: - console.print(Panel( - f"[green]Backup will be created at:[/green] {args.backup_dir}\n" - + ("[green]Audio WAV export included[/green]" if args.export_audio else "[dim]Audio WAV export: off[/dim]"), - title="Backup", - border_style="green", - )) + console.print( + Panel( + f"[green]Backup will be created at:[/green] {args.backup_dir}\n" + + ( + "[green]Audio WAV export included[/green]" + if args.export_audio + else "[dim]Audio WAV export: off[/dim]" + ), + title="Backup", + border_style="green", + ) + ) elif not args.backup_only: - console.print(Panel( - "[bold red]No backup will be created![/bold red]\nData will be permanently lost.", - title="Warning", - border_style="red", - )) + console.print( + Panel( + "[bold red]No backup will be created![/bold red]\nData will be permanently lost.", + title="Warning", + border_style="red", + ) + ) if not args.backup_only: items = [ @@ -911,18 +1018,22 @@ def print_confirmation(stats: Stats, args) -> bool: f" {stats.memories} memories", ] if stats.neo4j_nodes: - items.append(f" {stats.neo4j_nodes} Neo4j nodes + {stats.neo4j_relationships} relationships") + items.append( + f" {stats.neo4j_nodes} Neo4j nodes + {stats.neo4j_relationships} relationships" + ) items.append(f" {stats.redis_jobs} Redis jobs") if args.include_wav: items.append(f" {stats.legacy_wav} legacy WAV files") if args.delete_users: items.append(f" [bold red]{stats.users} users (DANGEROUS)[/bold red]") - console.print(Panel( - "\n".join(items), - title="[bold red]Will Delete[/bold red]", - 
border_style="red", - )) + console.print( + Panel( + "\n".join(items), + title="[bold red]Will Delete[/bold red]", + border_style="red", + ) + ) console.print() return Confirm.ask("[bold]Proceed?[/bold]", default=False) @@ -932,6 +1043,7 @@ def print_confirmation(stats: Stats, args) -> bool: # Main # --------------------------------------------------------------------------- + async def main(): parser = argparse.ArgumentParser( description="Chronicle Cleanup & Backup Tool", @@ -947,14 +1059,37 @@ async def main(): """, ) - parser.add_argument("--backup", action="store_true", help="Create backup before cleaning") - parser.add_argument("--backup-only", action="store_true", help="Create backup WITHOUT cleaning (safe)") - parser.add_argument("--export-audio", action="store_true", help="Include audio WAV export in backup (conversations with transcripts only)") - parser.add_argument("--include-wav", action="store_true", help="Include legacy WAV file cleanup") - parser.add_argument("--dry-run", action="store_true", help="Preview without making changes") + parser.add_argument( + "--backup", action="store_true", help="Create backup before cleaning" + ) + parser.add_argument( + "--backup-only", + action="store_true", + help="Create backup WITHOUT cleaning (safe)", + ) + parser.add_argument( + "--export-audio", + action="store_true", + help="Include audio WAV export in backup (conversations with transcripts only)", + ) + parser.add_argument( + "--include-wav", action="store_true", help="Include legacy WAV file cleanup" + ) + parser.add_argument( + "--dry-run", action="store_true", help="Preview without making changes" + ) parser.add_argument("--force", action="store_true", help="Skip confirmation prompt") - parser.add_argument("--backup-dir", type=str, default="/app/data/backups", help="Backup directory (default: /app/data/backups)") - parser.add_argument("--delete-users", action="store_true", help="DANGEROUS: Also delete user accounts") + parser.add_argument( + "--backup-dir", + type=str, + default="/app/data/backups", + help="Backup directory (default: /app/data/backups)", + ) + parser.add_argument( + "--delete-users", + action="store_true", + help="DANGEROUS: Also delete user accounts", + ) args = parser.parse_args() @@ -968,11 +1103,15 @@ async def main(): # Connect with console.status("[bold cyan]Connecting to services...", spinner="dots"): - mongo_db, redis_conn, qdrant_client, neo4j_driver, langfuse_client = await connect_services() + mongo_db, redis_conn, qdrant_client, neo4j_driver, langfuse_client = ( + await connect_services() + ) # Gather stats with console.status("[bold cyan]Gathering statistics...", spinner="dots"): - stats = await gather_stats(mongo_db, redis_conn, qdrant_client, neo4j_driver, langfuse_client) + stats = await gather_stats( + mongo_db, redis_conn, qdrant_client, neo4j_driver, langfuse_client + ) console.print() console.print(render_stats_table(stats, "Current Backend State")) @@ -993,7 +1132,9 @@ async def main(): do_backup = args.backup or args.backup_only if do_backup: console.print() - backup_mgr = BackupManager(args.backup_dir, args.export_audio, mongo_db, neo4j_driver, langfuse_client) + backup_mgr = BackupManager( + args.backup_dir, args.export_audio, mongo_db, neo4j_driver, langfuse_client + ) result = await backup_mgr.run(qdrant_client, stats) console.print() @@ -1006,44 +1147,61 @@ async def main(): if not result.critical_ok: console.print() - console.print(Panel( - "[bold red]Critical backup exports failed![/bold red]\n" - "Conversations or audio metadata 
could not be exported.\n" - "Cleanup will NOT proceed to protect your data.", - title="Backup Verification Failed", - border_style="red", - )) + console.print( + Panel( + "[bold red]Critical backup exports failed![/bold red]\n" + "Conversations or audio metadata could not be exported.\n" + "Cleanup will NOT proceed to protect your data.", + title="Backup Verification Failed", + border_style="red", + ) + ) sys.exit(1) if not result.all_ok: console.print() - console.print("[yellow]Some non-critical exports failed (see table above).[/yellow]") + console.print( + "[yellow]Some non-critical exports failed (see table above).[/yellow]" + ) # If backup-only, we're done if args.backup_only: console.print() - console.print(Panel( - "[bold green]Backup completed successfully![/bold green]\n" - "No data was deleted.", - border_style="green", - )) + console.print( + Panel( + "[bold green]Backup completed successfully![/bold green]\n" + "No data was deleted.", + border_style="green", + ) + ) return # Cleanup console.print() cleanup_mgr = CleanupManager( - mongo_db, redis_conn, qdrant_client, args.include_wav, args.delete_users, neo4j_driver + mongo_db, + redis_conn, + qdrant_client, + args.include_wav, + args.delete_users, + neo4j_driver, ) success = await cleanup_mgr.run(stats) if not success: - console.print(Panel("[bold red]Cleanup encountered errors![/bold red]", border_style="red")) + console.print( + Panel( + "[bold red]Cleanup encountered errors![/bold red]", border_style="red" + ) + ) sys.exit(1) # Verify console.print() with console.status("[bold cyan]Verifying cleanup...", spinner="dots"): - final_stats = await gather_stats(mongo_db, redis_conn, qdrant_client, neo4j_driver, langfuse_client) + final_stats = await gather_stats( + mongo_db, redis_conn, qdrant_client, neo4j_driver, langfuse_client + ) console.print(render_stats_table(final_stats, "After Cleanup")) diff --git a/backends/advanced/start-k8s.sh b/backends/advanced/start-k8s.sh index 847e3a6e..2274ed99 100755 --- a/backends/advanced/start-k8s.sh +++ b/backends/advanced/start-k8s.sh @@ -289,4 +289,4 @@ echo "⚠️ One service exited, stopping all services..." wait echo "πŸ”„ All services stopped" -exit 1 \ No newline at end of file +exit 1 diff --git a/backends/advanced/start.sh b/backends/advanced/start.sh index b5eaa2a0..d066896f 100755 --- a/backends/advanced/start.sh +++ b/backends/advanced/start.sh @@ -70,4 +70,4 @@ echo "⚠️ One service exited, stopping all services..." 
wait echo "πŸ”„ All services stopped" -exit 1 \ No newline at end of file +exit 1 diff --git a/backends/advanced/tests/test_audio_persistence_mongodb.py b/backends/advanced/tests/test_audio_persistence_mongodb.py index 30b5212e..49e4e889 100644 --- a/backends/advanced/tests/test_audio_persistence_mongodb.py +++ b/backends/advanced/tests/test_audio_persistence_mongodb.py @@ -15,26 +15,26 @@ from pathlib import Path import pytest +from beanie import init_beanie from bson import Binary from motor.motor_asyncio import AsyncIOMotorClient -from beanie import init_beanie from advanced_omi_backend.models.audio_chunk import AudioChunkDocument from advanced_omi_backend.models.conversation import Conversation from advanced_omi_backend.utils.audio_chunk_utils import ( - encode_pcm_to_opus, - decode_opus_to_pcm, build_wav_from_pcm, - retrieve_audio_chunks, concatenate_chunks_to_pcm, - reconstruct_wav_from_conversation, convert_wav_to_chunks, + decode_opus_to_pcm, + encode_pcm_to_opus, + reconstruct_wav_from_conversation, + retrieve_audio_chunks, wait_for_audio_chunks, ) - # Test configuration + def get_mongodb_url(): """Get MongoDB URL from environment or pytest args.""" return os.getenv("MONGODB_URI", "mongodb://localhost:27018") @@ -66,10 +66,7 @@ async def init_db(mongodb_client): """Initialize Beanie with test database.""" db = mongodb_client[get_test_db_name()] - await init_beanie( - database=db, - document_models=[AudioChunkDocument, Conversation] - ) + await init_beanie(database=db, document_models=[AudioChunkDocument, Conversation]) yield db @@ -88,6 +85,7 @@ async def clean_db(init_db): # Test data generators + def generate_pcm_data(duration_seconds=1, sample_rate=16000): """Generate sample PCM audio data.""" num_samples = int(sample_rate * duration_seconds) @@ -112,6 +110,7 @@ def create_wav_file(pcm_data, output_path, sample_rate=16000): # Integration Tests + @pytest.mark.asyncio class TestOpusCodecIntegration: """Test Opus encoding/decoding with real data.""" @@ -239,11 +238,7 @@ async def test_retrieve_chunks_with_pagination(self, clean_db): await chunk.insert() # Retrieve chunks 5-7 (3 chunks starting at index 5) - chunks = await retrieve_audio_chunks( - conversation_id, - start_index=5, - limit=3 - ) + chunks = await retrieve_audio_chunks(conversation_id, start_index=5, limit=3) assert len(chunks) == 3 assert chunks[0].chunk_index == 5 @@ -341,7 +336,7 @@ async def test_convert_wav_to_chunks(self, clean_db, tmp_path): conversation_id=conversation_id, audio_uuid="test-audio-001", user_id="test-user", - client_id="test-client" + client_id="test-client", ) await conversation.insert() @@ -376,7 +371,7 @@ async def test_convert_long_wav_creates_multiple_chunks(self, clean_db, tmp_path conversation_id=conversation_id, audio_uuid="test-audio-002", user_id="test-user", - client_id="test-client" + client_id="test-client", ) await conversation.insert() @@ -419,10 +414,7 @@ async def test_wait_for_chunks_immediate_success(self, clean_db): async def test_wait_for_chunks_timeout(self, clean_db): """Test wait times out when chunks don't exist.""" - result = await wait_for_audio_chunks( - "nonexistent-conv", - max_wait_seconds=1 - ) + result = await wait_for_audio_chunks("nonexistent-conv", max_wait_seconds=1) assert result is False diff --git a/backends/advanced/tests/test_conversation_models.py b/backends/advanced/tests/test_conversation_models.py index 4652e14a..bd24e105 100644 --- a/backends/advanced/tests/test_conversation_models.py +++ b/backends/advanced/tests/test_conversation_models.py @@ 
-19,7 +19,7 @@ def test_speaker_segment_model(self): end=15.8, text="Hello, how are you today?", speaker="Speaker A", - confidence=0.95 + confidence=0.95, ) assert segment.start == 10.5 @@ -31,8 +31,12 @@ def test_speaker_segment_model(self): def test_transcript_version_model(self): """Test TranscriptVersion model.""" segments = [ - Conversation.SpeakerSegment(start=0.0, end=5.0, text="Hello", speaker="Speaker A"), - Conversation.SpeakerSegment(start=5.1, end=10.0, text="Hi there", speaker="Speaker B") + Conversation.SpeakerSegment( + start=0.0, end=5.0, text="Hello", speaker="Speaker A" + ), + Conversation.SpeakerSegment( + start=5.1, end=10.0, text="Hi there", speaker="Speaker B" + ), ] version = Conversation.TranscriptVersion( @@ -43,7 +47,7 @@ def test_transcript_version_model(self): model="nova-3", created_at=datetime.now(), processing_time_seconds=12.5, - metadata={"confidence": 0.9} + metadata={"confidence": 0.9}, ) assert version.version_id == "trans-v1" @@ -64,7 +68,7 @@ def test_memory_version_model(self): model="gpt-4o-mini", created_at=datetime.now(), processing_time_seconds=45.2, - metadata={"extraction_quality": "high"} + metadata={"extraction_quality": "high"}, ) assert version.version_id == "mem-v1" @@ -82,12 +86,7 @@ def test_provider_enums(self): def test_word_model(self): """Test Word model.""" - word = Conversation.Word( - word="hello", - start=0.0, - end=0.5, - confidence=0.98 - ) + word = Conversation.Word(word="hello", start=0.0, end=0.5, confidence=0.98) assert word.word == "hello" assert word.start == 0.0 assert word.end == 0.5 @@ -96,10 +95,7 @@ def test_word_model(self): def test_speaker_segment_defaults(self): """Test SpeakerSegment default values.""" segment = Conversation.SpeakerSegment( - start=0.0, - end=1.0, - text="Test", - speaker="Speaker 0" + start=0.0, end=1.0, text="Test", speaker="Speaker 0" ) assert segment.confidence is None assert segment.identified_as is None diff --git a/backends/advanced/tests/test_memory_entry.py b/backends/advanced/tests/test_memory_entry.py index fc8ae223..af46d969 100644 --- a/backends/advanced/tests/test_memory_entry.py +++ b/backends/advanced/tests/test_memory_entry.py @@ -4,6 +4,7 @@ """ import time + from advanced_omi_backend.services.memory.base import MemoryEntry @@ -14,10 +15,7 @@ def test_memory_entry_auto_initializes_timestamps(self): """Test that MemoryEntry auto-initializes created_at and updated_at when not provided.""" before_creation = int(time.time()) - entry = MemoryEntry( - id="test-123", - content="Test memory content" - ) + entry = MemoryEntry(id="test-123", content="Test memory content") after_creation = int(time.time()) @@ -34,24 +32,32 @@ def test_memory_entry_auto_initializes_timestamps(self): updated_timestamp = int(entry.updated_at) # Timestamps should be within reasonable range (during test execution) - assert before_creation <= created_timestamp <= after_creation, "created_at should be within test execution time" - assert before_creation <= updated_timestamp <= after_creation, "updated_at should be within test execution time" + assert ( + before_creation <= created_timestamp <= after_creation + ), "created_at should be within test execution time" + assert ( + before_creation <= updated_timestamp <= after_creation + ), "updated_at should be within test execution time" # Both should be equal since they're created at the same time - assert entry.created_at == entry.updated_at, "created_at and updated_at should be equal for new entries" + assert ( + entry.created_at == entry.updated_at + ), 
"created_at and updated_at should be equal for new entries" def test_memory_entry_with_created_at_only(self): """Test that updated_at defaults to created_at when only created_at is provided.""" custom_timestamp = "1234567890" entry = MemoryEntry( - id="test-123", - content="Test memory content", - created_at=custom_timestamp + id="test-123", content="Test memory content", created_at=custom_timestamp ) - assert entry.created_at == custom_timestamp, "created_at should match provided value" - assert entry.updated_at == custom_timestamp, "updated_at should default to created_at" + assert ( + entry.created_at == custom_timestamp + ), "created_at should match provided value" + assert ( + entry.updated_at == custom_timestamp + ), "updated_at should default to created_at" def test_memory_entry_with_both_timestamps(self): """Test that both timestamps are preserved when explicitly provided.""" @@ -62,19 +68,25 @@ def test_memory_entry_with_both_timestamps(self): id="test-123", content="Test memory content", created_at=created_timestamp, - updated_at=updated_timestamp + updated_at=updated_timestamp, ) - assert entry.created_at == created_timestamp, "created_at should match provided value" - assert entry.updated_at == updated_timestamp, "updated_at should match provided value" - assert entry.created_at != entry.updated_at, "timestamps should be different when explicitly set" + assert ( + entry.created_at == created_timestamp + ), "created_at should match provided value" + assert ( + entry.updated_at == updated_timestamp + ), "updated_at should match provided value" + assert ( + entry.created_at != entry.updated_at + ), "timestamps should be different when explicitly set" def test_memory_entry_to_dict_includes_timestamps(self): """Test that to_dict() serialization includes both timestamp fields.""" entry = MemoryEntry( id="test-123", content="Test memory content", - metadata={"user_id": "user-456"} + metadata={"user_id": "user-456"}, ) entry_dict = entry.to_dict() @@ -86,15 +98,25 @@ def test_memory_entry_to_dict_includes_timestamps(self): assert "created_at" in entry_dict, "Dict should contain 'created_at'" assert "updated_at" in entry_dict, "Dict should contain 'updated_at'" assert "metadata" in entry_dict, "Dict should contain 'metadata'" - assert "user_id" in entry_dict, "Dict should contain 'user_id' (extracted from metadata)" + assert ( + "user_id" in entry_dict + ), "Dict should contain 'user_id' (extracted from metadata)" # Verify timestamp values are present and correct - assert entry_dict["created_at"] == entry.created_at, "Serialized created_at should match entry" - assert entry_dict["updated_at"] == entry.updated_at, "Serialized updated_at should match entry" + assert ( + entry_dict["created_at"] == entry.created_at + ), "Serialized created_at should match entry" + assert ( + entry_dict["updated_at"] == entry.updated_at + ), "Serialized updated_at should match entry" # Verify frontend compatibility - assert entry_dict["memory"] == entry.content, "memory field should match content for frontend" - assert entry_dict["content"] == entry.content, "content field should match content" + assert ( + entry_dict["memory"] == entry.content + ), "memory field should match content for frontend" + assert ( + entry_dict["content"] == entry.content + ), "content field should match content" def test_memory_entry_with_none_timestamps(self): """Test that None timestamps are properly initialized.""" @@ -102,13 +124,19 @@ def test_memory_entry_with_none_timestamps(self): id="test-123", content="Test memory 
content", created_at=None, - updated_at=None + updated_at=None, ) # Both should be auto-initialized even when explicitly set to None - assert entry.created_at is not None, "created_at should be auto-initialized from None" - assert entry.updated_at is not None, "updated_at should be auto-initialized from None" - assert entry.created_at == entry.updated_at, "Both timestamps should be equal when auto-initialized" + assert ( + entry.created_at is not None + ), "created_at should be auto-initialized from None" + assert ( + entry.updated_at is not None + ), "updated_at should be auto-initialized from None" + assert ( + entry.created_at == entry.updated_at + ), "Both timestamps should be equal when auto-initialized" def test_memory_entry_with_all_fields(self): """Test MemoryEntry with all fields populated.""" @@ -119,7 +147,7 @@ def test_memory_entry_with_all_fields(self): embedding=[0.1, 0.2, 0.3], score=0.95, created_at="1234567890", - updated_at="1234567900" + updated_at="1234567900", ) # Verify all fields are preserved @@ -138,10 +166,7 @@ def test_memory_entry_with_all_fields(self): def test_memory_entry_timestamp_format(self): """Test that timestamps are in the expected format (Unix timestamp strings).""" - entry = MemoryEntry( - id="test-123", - content="Test memory content" - ) + entry = MemoryEntry(id="test-123", content="Test memory content") # Timestamps should be strings representing Unix timestamps assert entry.created_at.isdigit(), "created_at should be a numeric string" diff --git a/backends/advanced/tests/test_memory_providers.py b/backends/advanced/tests/test_memory_providers.py index 945d4eb9..5eb2b6dc 100644 --- a/backends/advanced/tests/test_memory_providers.py +++ b/backends/advanced/tests/test_memory_providers.py @@ -5,8 +5,11 @@ """ import time -from advanced_omi_backend.services.memory.providers.openmemory_mcp import OpenMemoryMCPService + from advanced_omi_backend.services.memory.base import MemoryEntry +from advanced_omi_backend.services.memory.providers.openmemory_mcp import ( + OpenMemoryMCPService, +) class TestOpenMemoryMCPProviderTimestamps: @@ -25,7 +28,7 @@ def test_mcp_result_to_memory_entry_with_both_timestamps(self): "content": "Test memory content", "created_at": "1704067200", # 2024-01-01 00:00:00 UTC "updated_at": "1704153600", # 2024-01-02 00:00:00 UTC - "metadata": {"source": "test"} + "metadata": {"source": "test"}, } # Convert to MemoryEntry @@ -64,7 +67,9 @@ def test_mcp_result_to_memory_entry_with_missing_updated_at(self): assert entry is not None, "MemoryEntry should be created" assert entry.created_at is not None, "created_at should be present" assert entry.updated_at is not None, "updated_at should default to created_at" - assert entry.created_at == entry.updated_at, "updated_at should equal created_at when missing" + assert ( + entry.created_at == entry.updated_at + ), "updated_at should equal created_at when missing" def test_mcp_result_to_memory_entry_with_alternate_timestamp_fields(self): """Test that OpenMemory MCP provider handles alternate timestamp field names.""" @@ -85,9 +90,13 @@ def test_mcp_result_to_memory_entry_with_alternate_timestamp_fields(self): # Verify conversion handles alternate field names assert entry is not None, "MemoryEntry should be created" - assert entry.content == "Test memory content", "Should extract from 'memory' field" + assert ( + entry.content == "Test memory content" + ), "Should extract from 'memory' field" assert entry.created_at == "1704067200", "Should extract from 'timestamp' field" - assert 
entry.updated_at == "1704153600", "Should extract from 'modified_at' field" + assert ( + entry.updated_at == "1704153600" + ), "Should extract from 'modified_at' field" def test_mcp_result_with_no_timestamps(self): """Test that OpenMemory MCP provider generates timestamps when none provided.""" @@ -116,8 +125,12 @@ def test_mcp_result_with_no_timestamps(self): # Verify timestamps are current (within test execution window) created_int = int(entry.created_at) updated_int = int(entry.updated_at) - assert before_conversion <= created_int <= after_conversion, "Timestamp should be current" - assert before_conversion <= updated_int <= after_conversion, "Timestamp should be current" + assert ( + before_conversion <= created_int <= after_conversion + ), "Timestamp should be current" + assert ( + before_conversion <= updated_int <= after_conversion + ), "Timestamp should be current" class TestProviderTimestampConsistency: @@ -139,8 +152,18 @@ def test_all_providers_return_memory_entry_with_timestamps(self): # Verify all return MemoryEntry instances with both timestamp fields for entry, provider_name in [(mcp_entry, "OpenMemory MCP")]: - assert isinstance(entry, MemoryEntry), f"{provider_name} should return MemoryEntry" - assert hasattr(entry, "created_at"), f"{provider_name} entry should have created_at" - assert hasattr(entry, "updated_at"), f"{provider_name} entry should have updated_at" - assert entry.created_at is not None, f"{provider_name} created_at should not be None" - assert entry.updated_at is not None, f"{provider_name} updated_at should not be None" + assert isinstance( + entry, MemoryEntry + ), f"{provider_name} should return MemoryEntry" + assert hasattr( + entry, "created_at" + ), f"{provider_name} entry should have created_at" + assert hasattr( + entry, "updated_at" + ), f"{provider_name} entry should have updated_at" + assert ( + entry.created_at is not None + ), f"{provider_name} created_at should not be None" + assert ( + entry.updated_at is not None + ), f"{provider_name} updated_at should not be None" diff --git a/backends/advanced/tests/test_obsidian_service.py b/backends/advanced/tests/test_obsidian_service.py index 0daafc1a..47ef3601 100644 --- a/backends/advanced/tests/test_obsidian_service.py +++ b/backends/advanced/tests/test_obsidian_service.py @@ -1,41 +1,56 @@ -import unittest import asyncio -from unittest.mock import MagicMock, patch, AsyncMock -import sys import os +import sys +import unittest +from unittest.mock import AsyncMock, MagicMock, patch -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../src"))) from advanced_omi_backend.services.obsidian_service import ( - ObsidianService, ObsidianSearchError, + ObsidianService, ) + class TestObsidianService(unittest.TestCase): def setUp(self): # Patch load_root_config - self.config_patcher = patch('advanced_omi_backend.services.obsidian_service.load_root_config') + self.config_patcher = patch( + "advanced_omi_backend.services.obsidian_service.load_root_config" + ) self.mock_load_config = self.config_patcher.start() self.mock_load_config.return_value = { - 'defaults': {'llm': 'gpt-4', 'embedding': 'text-embedding-3-small'}, - 'models': [ - {'name': 'gpt-4', 'model_url': 'https://api.openai.com/v1', 'api_key': 'sk-test'}, - {'name': 'text-embedding-3-small', 'model_name': 'text-embedding-3-small', 'embedding_dimensions': 1536, 'model_url': 'https://api.openai.com/v1', 'api_key': 'sk-test'} - ] + "defaults": {"llm": 
"gpt-4", "embedding": "text-embedding-3-small"}, + "models": [ + { + "name": "gpt-4", + "model_url": "https://api.openai.com/v1", + "api_key": "sk-test", + }, + { + "name": "text-embedding-3-small", + "model_name": "text-embedding-3-small", + "embedding_dimensions": 1536, + "model_url": "https://api.openai.com/v1", + "api_key": "sk-test", + }, + ], } self.addCleanup(self.config_patcher.stop) # Patch embedding helper self.embedding_patcher = patch( - 'advanced_omi_backend.services.obsidian_service.generate_openai_embeddings', - new_callable=AsyncMock + "advanced_omi_backend.services.obsidian_service.generate_openai_embeddings", + new_callable=AsyncMock, ) self.mock_generate_embeddings = self.embedding_patcher.start() self.addCleanup(self.embedding_patcher.stop) # Patch GraphDatabase - self.graph_db_patcher = patch('advanced_omi_backend.services.neo4j_client.GraphDatabase') + self.graph_db_patcher = patch( + "advanced_omi_backend.services.neo4j_client.GraphDatabase" + ) self.mock_graph_db = self.graph_db_patcher.start() self.mock_driver = MagicMock() self.mock_session = MagicMock() @@ -44,14 +59,17 @@ def setUp(self): self.addCleanup(self.graph_db_patcher.stop) # Patch environment variables - self.env_patcher = patch.dict(os.environ, { - "NEO4J_HOST": "localhost", - "NEO4J_USER": "neo4j", - "NEO4J_PASSWORD": "password" - }) + self.env_patcher = patch.dict( + os.environ, + { + "NEO4J_HOST": "localhost", + "NEO4J_USER": "neo4j", + "NEO4J_PASSWORD": "password", + }, + ) self.env_patcher.start() self.addCleanup(self.env_patcher.stop) - + # Initialize Service self.service = ObsidianService() @@ -59,45 +77,45 @@ def test_search_obsidian_success(self): # Setup mock embedding response mock_embedding = [0.1, 0.2, 0.3] self.mock_generate_embeddings.return_value = [mock_embedding] - + # Setup mock Neo4j results mock_record1 = { - 'source': 'Note1', - 'content': 'Content of chunk 1', - 'tags': ['tag1', 'tag2'], - 'outgoing_links': ['Note2'], - 'score': 0.95 + "source": "Note1", + "content": "Content of chunk 1", + "tags": ["tag1", "tag2"], + "outgoing_links": ["Note2"], + "score": 0.95, } mock_record2 = { - 'source': 'Note2', - 'content': 'Content of chunk 2', - 'tags': [], - 'outgoing_links': [], - 'score': 0.90 + "source": "Note2", + "content": "Content of chunk 2", + "tags": [], + "outgoing_links": [], + "score": 0.90, } - + # The session.run returns an iterable of records self.mock_session.run.return_value = [mock_record1, mock_record2] - + # Execute search response = asyncio.run(self.service.search_obsidian("test query", limit=2)) - + # Assertions # 1. Check embedding call self.mock_generate_embeddings.assert_awaited_once() - + # 2. Check Neo4j query execution self.mock_session.run.assert_called_once() args, kwargs = self.mock_session.run.call_args self.assertIn("CALL db.index.vector.queryNodes", args[0]) - self.assertEqual(kwargs['vector'], mock_embedding) - self.assertEqual(kwargs['limit'], 2) - + self.assertEqual(kwargs["vector"], mock_embedding) + self.assertEqual(kwargs["limit"], 2) + # 3. 
Check results formatting - self.assertEqual(len(response['results']), 2) - + self.assertEqual(len(response["results"]), 2) + # Check first result format - first_entry = response['results'][0] + first_entry = response["results"][0] self.assertIn("SOURCE: Note1", first_entry) self.assertIn("TAGS: tag1, tag2", first_entry) self.assertIn("RELATED NOTES: Note2", first_entry) @@ -105,18 +123,18 @@ def test_search_obsidian_success(self): def test_setup_database(self): self.service.setup_database() - + # Verify constraints and index creation calls self.assertTrue(self.mock_session.run.called) # It should run at least 3 queries: Note constraint, Chunk constraint, Vector Index self.assertGreaterEqual(self.mock_session.run.call_count, 3) - + calls = [call[0][0] for call in self.mock_session.run.call_args_list] self.assertTrue(any("CREATE CONSTRAINT note_path" in c for c in calls)) self.assertTrue(any("CREATE CONSTRAINT chunk_id" in c for c in calls)) self.assertTrue(any("CREATE VECTOR INDEX chunk_embeddings" in c for c in calls)) - @patch('advanced_omi_backend.services.obsidian_service.chunk_text_with_spacy') + @patch("advanced_omi_backend.services.obsidian_service.chunk_text_with_spacy") def test_chunking_and_embedding_uses_shared_chunker(self, mock_chunker): mock_chunker.return_value = ["part1"] self.mock_generate_embeddings.return_value = [[0.1, 0.2]] @@ -130,7 +148,9 @@ def test_chunking_and_embedding_uses_shared_chunker(self, mock_chunker): "tags": [], } chunks = asyncio.run(self.service.chunking_and_embedding(note_data)) - mock_chunker.assert_called_once_with("sample", max_tokens=self.service.chunk_word_limit) + mock_chunker.assert_called_once_with( + "sample", max_tokens=self.service.chunk_word_limit + ) self.mock_generate_embeddings.assert_awaited_once() self.assertEqual(len(chunks), 1) @@ -142,21 +162,19 @@ def test_ingest_note_and_chunks(self): "content": "some content", "wordcount": 2, "links": ["OtherNote"], - "tags": ["tag1"] + "tags": ["tag1"], } - chunks = [ - {"text": "chunk1", "embedding": [0.1, 0.2]} - ] - + chunks = [{"text": "chunk1", "embedding": [0.1, 0.2]}] + self.service.ingest_note_and_chunks(note_data, chunks) - + # Verify DB calls # 1. Note + Folder merge # 2. Chunk merge # 3. Tag merge # 4. 
Link merge self.assertGreaterEqual(self.mock_session.run.call_count, 4) - + calls = [call[0][0] for call in self.mock_session.run.call_args_list] self.assertTrue(any("MERGE (f:Folder" in c for c in calls)) self.assertTrue(any("MERGE (c:Chunk" in c for c in calls)) @@ -166,10 +184,10 @@ def test_ingest_note_and_chunks(self): def test_search_obsidian_embedding_fail(self): # Mock embedding failure (raises exception) self.mock_generate_embeddings.side_effect = Exception("API Error") - + with self.assertRaises(ObsidianSearchError) as ctx: asyncio.run(self.service.search_obsidian("test query")) - + self.assertEqual(ctx.exception.stage, "embedding") self.assertIn("API Error", str(ctx.exception)) self.mock_session.run.assert_not_called() @@ -178,13 +196,13 @@ def test_search_obsidian_db_fail(self): # Setup mock embedding mock_embedding = [0.1] self.mock_generate_embeddings.return_value = [mock_embedding] - + # Mock DB failure self.mock_session.run.side_effect = Exception("DB Connection Failed") - + with self.assertRaises(ObsidianSearchError) as ctx: asyncio.run(self.service.search_obsidian("test query")) - + self.assertEqual(ctx.exception.stage, "database") self.assertIn("DB Connection Failed", str(ctx.exception)) @@ -192,13 +210,14 @@ def test_search_obsidian_empty_results(self): # Setup mock embedding mock_embedding = [0.1] self.mock_generate_embeddings.return_value = [mock_embedding] - + # Mock empty DB results self.mock_session.run.return_value = [] - + response = asyncio.run(self.service.search_obsidian("test query")) - - self.assertEqual(response['results'], []) -if __name__ == '__main__': + self.assertEqual(response["results"], []) + + +if __name__ == "__main__": unittest.main() diff --git a/backends/advanced/upload_files.py b/backends/advanced/upload_files.py index 77a001f3..6943c5eb 100755 --- a/backends/advanced/upload_files.py +++ b/backends/advanced/upload_files.py @@ -18,33 +18,36 @@ # Configure colored logging class ColoredFormatter(logging.Formatter): """Custom formatter with colors for different log levels.""" - + # ANSI color codes COLORS = { - 'DEBUG': '\033[36m', # Cyan - 'INFO': '\033[32m', # Green - 'WARNING': '\033[33m', # Yellow - 'ERROR': '\033[31m', # Red - 'CRITICAL': '\033[35m', # Magenta + "DEBUG": "\033[36m", # Cyan + "INFO": "\033[32m", # Green + "WARNING": "\033[33m", # Yellow + "ERROR": "\033[31m", # Red + "CRITICAL": "\033[35m", # Magenta } - RESET = '\033[0m' # Reset color - + RESET = "\033[0m" # Reset color + def format(self, record): # Add color to the log level if record.levelname in self.COLORS: - record.levelname = f"{self.COLORS[record.levelname]}{record.levelname}{self.RESET}" + record.levelname = ( + f"{self.COLORS[record.levelname]}{record.levelname}{self.RESET}" + ) return super().format(record) + # Configure logging with colors logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler()] + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], ) # Apply colored formatter for handler in logging.getLogger().handlers: - handler.setFormatter(ColoredFormatter('%(asctime)s - %(levelname)s - %(message)s')) + handler.setFormatter(ColoredFormatter("%(asctime)s - %(levelname)s - %(message)s")) logger = logging.getLogger(__name__) @@ -55,46 +58,43 @@ def load_env_variables() -> Optional[str]: if not env_file.exists(): logger.error(".env file not found. 
Please create it with ADMIN_PASSWORD.") return None - + admin_password = None - with open(env_file, 'r') as f: + with open(env_file, "r") as f: for line in f: line = line.strip() - if line.startswith('ADMIN_PASSWORD='): - admin_password = line.split('=', 1)[1].strip('"\'') + if line.startswith("ADMIN_PASSWORD="): + admin_password = line.split("=", 1)[1].strip("\"'") break - + if not admin_password: logger.error("ADMIN_PASSWORD not found in .env file.") return None - + return admin_password -def get_admin_token(password: str, base_url: str = "http://localhost:8000") -> Optional[str]: +def get_admin_token( + password: str, base_url: str = "http://localhost:8000" +) -> Optional[str]: """Authenticate and get admin token.""" logger.info("Requesting admin token...") - + auth_url = f"{base_url}/auth/jwt/login" - + try: response = requests.post( auth_url, - data={ - 'username': 'admin@example.com', - 'password': password - }, - headers={ - 'Content-Type': 'application/x-www-form-urlencoded' - }, - timeout=10 + data={"username": "admin@example.com", "password": password}, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=10, ) - + logger.info(f"Auth response status: {response.status_code}") - + if response.status_code == 200: data = response.json() - token = data.get('access_token') + token = data.get("access_token") if token: logger.info("Admin token obtained.") return token @@ -111,7 +111,7 @@ def get_admin_token(password: str, base_url: str = "http://localhost:8000") -> O logger.error(f"Failed to parse error response as JSON: {json_error}") logger.error(f"Response text: {response.text}") return None - + except requests.exceptions.RequestException as e: logger.error(f"Request failed: {e}") return None @@ -132,7 +132,7 @@ def get_audio_duration(file_path: str) -> float: def validate_audio_format(file_path: str) -> tuple[bool, str]: """Validate that audio file is 16kHz, 16-bit mono format. 
- + Returns: tuple: (is_valid, error_message) """ @@ -141,42 +141,44 @@ def validate_audio_format(file_path: str) -> tuple[bool, str]: channels = wav_file.getnchannels() sample_rate = wav_file.getframerate() sample_width = wav_file.getsampwidth() - + errors = [] - + if channels != 1: errors.append(f"Expected mono (1 channel), got {channels} channels") - + if sample_rate != 16000: errors.append(f"Expected 16kHz sample rate, got {sample_rate}Hz") - + if sample_width != 2: # 2 bytes = 16 bits errors.append(f"Expected 16-bit audio, got {sample_width * 8}-bit") - + if errors: return False, "; ".join(errors) - + return True, "" - + except Exception as e: return False, f"Error reading WAV file: {e}" -def collect_wav_files(audio_dir: str, filter_list: Optional[list[str]] = None) -> list[str]: +def collect_wav_files( + audio_dir: str, filter_list: Optional[list[str]] = None +) -> list[str]: """Collect all .wav files from the specified directory with duration checking.""" logger.info(f"Collecting .wav files from {audio_dir} ...") - + audio_path = Path(audio_dir).expanduser() if not audio_path.exists(): logger.error(f"Directory {audio_path} does not exist.") return [] - + wav_files = list(audio_path.glob("*.wav")) - + if not wav_files: logger.warning(f"No .wav files found in {audio_path}") return [] - + # Filter files if filter_list is provided, otherwise accept all if filter_list is None: candidate_files = wav_files @@ -187,45 +189,57 @@ def collect_wav_files(audio_dir: str, filter_list: Optional[list[str]] = None) - candidate_files.append(f) else: logger.info(f"Skipping file (not in filter): {f.name}") - + # Check duration and filter out files over 20 minutes selected_files = [] total_duration = 0.0 - + for file_path in candidate_files: # First validate audio format is_valid, format_error = validate_audio_format(str(file_path)) if not is_valid: - logger.error(f"πŸ”΄ SKIPPING: {file_path.name} - Invalid format: {format_error}") + logger.error( + f"πŸ”΄ SKIPPING: {file_path.name} - Invalid format: {format_error}" + ) continue - + duration = get_audio_duration(str(file_path)) duration_minutes = duration / 60.0 - - + selected_files.append(file_path) total_duration += duration - logger.info(f"βœ… Added file: {file_path.name} (duration: {duration_minutes:.1f} minutes)") - + logger.info( + f"βœ… Added file: {file_path.name} (duration: {duration_minutes:.1f} minutes)" + ) + total_minutes = total_duration / 60.0 - logger.info(f"πŸ“Š Total files to upload: {len(selected_files)} (total duration: {total_minutes:.1f} minutes)") - + logger.info( + f"πŸ“Š Total files to upload: {len(selected_files)} (total duration: {total_minutes:.1f} minutes)" + ) + return [str(f) for f in selected_files] -def upload_files_async(files: list[str], token: str, base_url: str = "http://localhost:8000") -> bool: +def upload_files_async( + files: list[str], token: str, base_url: str = "http://localhost:8000" +) -> bool: """Upload files to the backend for async processing with real-time progress tracking.""" if not files: logger.error("No files to upload.") return False - + logger.info(f"πŸš€ Starting async upload to {base_url}/api/audio/upload ...") # Prepare files for upload files_data = [] for file_path in files: try: - files_data.append(('files', (os.path.basename(file_path), open(file_path, 'rb'), 'audio/wav'))) + files_data.append( + ( + "files", + (os.path.basename(file_path), open(file_path, "rb"), "audio/wav"), + ) + ) except IOError as e: logger.error(f"Error opening file {file_path}: {e}") continue @@ -239,17 +253,15 @@ def 
upload_files_async(files: list[str], token: str, base_url: str = "http://loc response = requests.post( f"{base_url}/api/audio/upload", files=files_data, - data={'device_name': 'file_upload_batch'}, - headers={ - 'Authorization': f'Bearer {token}' - }, - timeout=60 # Short timeout for job submission + data={"device_name": "file_upload_batch"}, + headers={"Authorization": f"Bearer {token}"}, + timeout=60, # Short timeout for job submission ) - + # Close all file handles for _, file_tuple in files_data: file_tuple[1].close() - + if response.status_code != 200: logger.error(f"Failed to start async processing: {response.status_code}") try: @@ -258,19 +270,19 @@ def upload_files_async(files: list[str], token: str, base_url: str = "http://loc except: logger.error(f"Response text: {response.text}") return False - + # Get job ID job_data = response.json() job_id = job_data.get("job_id") total_files = job_data.get("total_files", 0) - + logger.info(f"βœ… Job started successfully: {job_id}") logger.info(f"πŸ“Š Processing {total_files} files...") logger.info(f"πŸ”— Status URL: {job_data.get('status_url', 'N/A')}") - + # Poll for job completion return poll_job_status(job_id, token, base_url, total_files) - + except requests.exceptions.Timeout: logger.error("Job submission timed out.") return False @@ -289,38 +301,42 @@ def upload_files_async(files: list[str], token: str, base_url: str = "http://loc def poll_job_status(job_id: str, token: str, base_url: str, total_files: int) -> bool: """Poll job status until completion with progress updates.""" status_url = f"{base_url}/api/queue/jobs/{job_id}" - headers = {'Authorization': f'Bearer {token}'} - + headers = {"Authorization": f"Bearer {token}"} + start_time = time.time() last_progress = -1 last_current_file = None - + logger.info("πŸ”„ Polling job status...") - + while True: try: response = requests.get(status_url, headers=headers, timeout=30) - + if response.status_code != 200: logger.error(f"Failed to get job status: {response.status_code}") return False - + job_status = response.json() status = job_status.get("status") progress = job_status.get("progress_percent", 0) current_file = job_status.get("current_file") processed_files = job_status.get("processed_files", 0) - + # Show progress updates if progress != last_progress or current_file != last_current_file: elapsed = time.time() - start_time if current_file: - logger.info(f"πŸ“ˆ Progress: {progress:.1f}% ({processed_files}/{total_files}) - Processing: {current_file} (elapsed: {elapsed:.0f}s)") + logger.info( + f"πŸ“ˆ Progress: {progress:.1f}% ({processed_files}/{total_files}) - Processing: {current_file} (elapsed: {elapsed:.0f}s)" + ) else: - logger.info(f"πŸ“ˆ Progress: {progress:.1f}% ({processed_files}/{total_files}) (elapsed: {elapsed:.0f}s)") + logger.info( + f"πŸ“ˆ Progress: {progress:.1f}% ({processed_files}/{total_files}) (elapsed: {elapsed:.0f}s)" + ) last_progress = progress last_current_file = current_file - + # Check completion status (RQ standard: "finished") if status == "finished": elapsed = time.time() - start_time @@ -331,14 +347,14 @@ def poll_job_status(job_id: str, token: str, base_url: str, total_files: int) -> completed = len([f for f in files if f.get("status") == "finished"]) failed = len([f for f in files if f.get("status") == "failed"]) skipped = len([f for f in files if f.get("status") == "skipped"]) - + logger.info(f"πŸ“Š Final Summary:") logger.info(f" βœ… Completed: {completed}") if failed > 0: logger.error(f" ❌ Failed: {failed}") if skipped > 0: logger.warning(f" ⏭️ 
Skipped: {skipped}") - + # Show failed files for file_info in files: if file_info.get("status") == "failed": @@ -346,16 +362,18 @@ def poll_job_status(job_id: str, token: str, base_url: str, total_files: int) -> logger.error(f" ❌ {file_info.get('filename')}: {error_msg}") elif file_info.get("status") == "skipped": error_msg = file_info.get("error_message", "Skipped") - logger.warning(f" ⏭️ {file_info.get('filename')}: {error_msg}") - + logger.warning( + f" ⏭️ {file_info.get('filename')}: {error_msg}" + ) + return completed > 0 # Success if at least one file completed - + elif status == "failed": elapsed = time.time() - start_time error_msg = job_status.get("error_message", "Unknown error") logger.error(f"πŸ’₯ Job failed after {elapsed:.0f}s: {error_msg}") return False - + elif status in ["queued", "processing"]: # Continue polling time.sleep(5) # Poll every 5 seconds @@ -364,7 +382,7 @@ def poll_job_status(job_id: str, token: str, base_url: str, total_files: int) -> logger.warning(f"Unknown job status: {status}") time.sleep(5) continue - + except requests.exceptions.RequestException as e: logger.error(f"Error polling job status: {e}") time.sleep(10) # Wait longer on error @@ -376,16 +394,18 @@ def poll_job_status(job_id: str, token: str, base_url: str, total_files: int) -> def parse_args(): """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="Upload audio files to Chronicle backend") + parser = argparse.ArgumentParser( + description="Upload audio files to Chronicle backend" + ) parser.add_argument( "files", nargs="*", - help="Audio files to upload. If none provided, uses default test file." + help="Audio files to upload. If none provided, uses default test file.", ) parser.add_argument( "--base-url", default="http://localhost:8000", - help="Backend base URL (default: http://localhost:8000)" + help="Backend base URL (default: http://localhost:8000)", ) return parser.parse_args() @@ -393,20 +413,20 @@ def parse_args(): def main(): """Main function to orchestrate the upload process.""" args = parse_args() - + logger.info("Chronicle Audio File Upload Tool") logger.info("=" * 40) - + # Load environment variables admin_password = load_env_variables() if not admin_password: sys.exit(1) - + # Get admin token token = get_admin_token(admin_password, args.base_url) if not token: sys.exit(1) - + # Determine files to upload if args.files: # Use provided files @@ -422,26 +442,33 @@ def main(): else: # Use default test file (committed to git, used in tests) project_root = Path(__file__).parent.parent.parent - specific_file = project_root / "extras" / "test-audios" / "DIY Experts Glass Blowing_16khz_mono_4min.wav" - + specific_file = ( + project_root + / "extras" + / "test-audios" + / "DIY Experts Glass Blowing_16khz_mono_4min.wav" + ) + if specific_file.exists(): wav_files = [str(specific_file)] logger.info(f"Using default test file: {specific_file}") else: logger.error(f"Default test file not found: {specific_file}") - logger.info("Please provide file paths as arguments or ensure test file exists") + logger.info( + "Please provide file paths as arguments or ensure test file exists" + ) sys.exit(1) - + if not wav_files: logger.error("No files to upload") sys.exit(1) - + logger.info(f"Uploading {len(wav_files)} files:") for f in wav_files: logger.info(f"- {os.path.basename(f)}") - + success = upload_files_async(wav_files, token, args.base_url) - + if success: logger.info("πŸŽ‰ Upload process completed successfully!") sys.exit(0) @@ -451,4 +478,4 @@ def main(): if __name__ == 
"__main__": - main() \ No newline at end of file + main() diff --git a/backends/advanced/webui/nginx.conf b/backends/advanced/webui/nginx.conf index eee76bad..bc75b9a5 100644 --- a/backends/advanced/webui/nginx.conf +++ b/backends/advanced/webui/nginx.conf @@ -6,7 +6,7 @@ server { # Basic settings client_max_body_size 100M; - + # Gzip compression gzip on; gzip_vary on; @@ -44,4 +44,4 @@ server { add_header X-Content-Type-Options nosniff; add_header X-XSS-Protection "1; mode=block"; add_header Referrer-Policy "strict-origin-when-cross-origin"; -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/components/ConversationVersionDropdown.tsx b/backends/advanced/webui/src/components/ConversationVersionDropdown.tsx index ed21f69c..cfd782d0 100644 --- a/backends/advanced/webui/src/components/ConversationVersionDropdown.tsx +++ b/backends/advanced/webui/src/components/ConversationVersionDropdown.tsx @@ -252,4 +252,4 @@ export default function ConversationVersionDropdown({ )} ) -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/components/ConversationVersionHeader.tsx b/backends/advanced/webui/src/components/ConversationVersionHeader.tsx index 55627c4f..4b22cb8a 100644 --- a/backends/advanced/webui/src/components/ConversationVersionHeader.tsx +++ b/backends/advanced/webui/src/components/ConversationVersionHeader.tsx @@ -110,4 +110,4 @@ export default function ConversationVersionHeader({ conversationId, versionInfo, ); -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/components/audio/DebugPanel.tsx b/backends/advanced/webui/src/components/audio/DebugPanel.tsx index a3785f1d..9424a91e 100644 --- a/backends/advanced/webui/src/components/audio/DebugPanel.tsx +++ b/backends/advanced/webui/src/components/audio/DebugPanel.tsx @@ -20,14 +20,14 @@ export default function DebugPanel({ recording }: DebugPanelProps) { Attempts: {recording.debugStats.connectionAttempts}

- +

Audio Chunks

Sent: {recording.debugStats.chunksSent}

- Rate: {recording.debugStats.chunksSent > 0 && recording.debugStats.sessionStartTime ? + Rate: {recording.debugStats.chunksSent > 0 && recording.debugStats.sessionStartTime ? Math.round(recording.debugStats.chunksSent / ((Date.now() - recording.debugStats.sessionStartTime.getTime()) / 1000)) : 0}/s

@@ -45,7 +45,7 @@ export default function DebugPanel({ recording }: DebugPanelProps) {

Session

- Duration: {recording.debugStats.sessionStartTime ? + Duration: {recording.debugStats.sessionStartTime ? Math.round((Date.now() - recording.debugStats.sessionStartTime.getTime()) / 1000) + 's' : 'N/A'}

@@ -72,4 +72,4 @@ export default function DebugPanel({ recording }: DebugPanelProps) {

) -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/components/audio/RecordingStatus.tsx b/backends/advanced/webui/src/components/audio/RecordingStatus.tsx index b208beaa..03a9a19f 100644 --- a/backends/advanced/webui/src/components/audio/RecordingStatus.tsx +++ b/backends/advanced/webui/src/components/audio/RecordingStatus.tsx @@ -8,7 +8,7 @@ interface RecordingStatusProps { export default function RecordingStatus({ recording }: RecordingStatusProps) { const { user } = useAuth() - + const getStatusIcon = () => { switch (recording.connectionStatus) { case 'connected': @@ -51,7 +51,7 @@ export default function RecordingStatus({ recording }: RecordingStatusProps) {

- +

User: {user?.name || user?.email} @@ -66,7 +66,7 @@ export default function RecordingStatus({ recording }: RecordingStatusProps) { {/* Component Status Indicators */}

πŸ“Š Component Status

- +
{/* WebSocket Status */}
@@ -141,4 +141,4 @@ export default function RecordingStatus({ recording }: RecordingStatusProps) {
) -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx b/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx index db17d626..febc90c8 100644 --- a/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx +++ b/backends/advanced/webui/src/components/audio/SimpleDebugPanel.tsx @@ -20,14 +20,14 @@ export default function SimpleDebugPanel({ recording }: SimpleDebugPanelProps) { Recording: {recording.isRecording ? 'Yes' : 'No'}

- +

Audio Chunks

Sent: {recording.debugStats.chunksSent}

- Rate: {recording.debugStats.chunksSent > 0 && recording.debugStats.sessionStartTime ? + Rate: {recording.debugStats.chunksSent > 0 && recording.debugStats.sessionStartTime ? Math.round(recording.debugStats.chunksSent / ((Date.now() - recording.debugStats.sessionStartTime.getTime()) / 1000)) : 0}/s

@@ -45,7 +45,7 @@ export default function SimpleDebugPanel({ recording }: SimpleDebugPanelProps) {

Session

- Duration: {recording.debugStats.sessionStartTime ? + Duration: {recording.debugStats.sessionStartTime ? Math.round((Date.now() - recording.debugStats.sessionStartTime.getTime()) / 1000) + 's' : 'N/A'}

@@ -72,4 +72,4 @@ export default function SimpleDebugPanel({ recording }: SimpleDebugPanelProps) {

) -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/components/audio/StatusDisplay.tsx b/backends/advanced/webui/src/components/audio/StatusDisplay.tsx index d151ef4d..d6bf8850 100644 --- a/backends/advanced/webui/src/components/audio/StatusDisplay.tsx +++ b/backends/advanced/webui/src/components/audio/StatusDisplay.tsx @@ -48,14 +48,14 @@ const getStepStatus = (stepId: RecordingStep, currentStep: RecordingStep, isReco if (stepIndex <= currentStepIndex) return 'error' return 'pending' } - + if (isRecording) { return 'completed' // All steps completed when recording } - + const stepIndex = steps.findIndex(s => s.id === stepId) const currentStepIndex = steps.findIndex(s => s.id === currentStep) - + if (stepIndex < currentStepIndex) return 'completed' if (stepIndex === currentStepIndex) return 'current' return 'pending' @@ -84,18 +84,18 @@ export default function StatusDisplay({ recording }: StatusDisplayProps) { if (recording.currentStep === 'idle' || recording.isRecording) { return null } - + return (

Recording Setup Progress

- +
{steps.map((step, index) => { const status = getStepStatus(step.id, recording.currentStep, recording.isRecording) - + return (
{step.icon}
- + {/* Step Info */}
@@ -118,7 +118,7 @@ export default function StatusDisplay({ recording }: StatusDisplayProps) { {step.description}

- + {/* Step Number */}
{index + 1} @@ -127,7 +127,7 @@ export default function StatusDisplay({ recording }: StatusDisplayProps) { ) })}
- + {/* Overall Status */}
@@ -141,4 +141,4 @@ export default function StatusDisplay({ recording }: StatusDisplayProps) {
) -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/contexts/AuthContext.tsx b/backends/advanced/webui/src/contexts/AuthContext.tsx index d4761358..581118c4 100644 --- a/backends/advanced/webui/src/contexts/AuthContext.tsx +++ b/backends/advanced/webui/src/contexts/AuthContext.tsx @@ -33,7 +33,7 @@ export function AuthProvider({ children }: { children: ReactNode }) { console.log('πŸ” AuthContext: Initializing authentication...') const savedToken = localStorage.getItem(getStorageKey('token')) console.log('πŸ” AuthContext: Saved token exists:', !!savedToken) - + if (savedToken) { try { console.log('πŸ” AuthContext: Verifying token with API call...') @@ -74,11 +74,11 @@ export function AuthProvider({ children }: { children: ReactNode }) { return { success: true } } catch (error: any) { console.error('Login failed:', error) - + // Parse structured error response from backend let errorMessage = 'Login failed. Please try again.' let errorType = 'unknown' - + if (error.response?.data) { const errorData = error.response.data errorMessage = errorData.detail || errorMessage @@ -87,9 +87,9 @@ export function AuthProvider({ children }: { children: ReactNode }) { errorMessage = 'Unable to connect to server. Please check your connection and try again.' errorType = 'connection_failure' } - - return { - success: false, + + return { + success: false, error: errorMessage, errorType: errorType } @@ -115,4 +115,4 @@ export function useAuth() { throw new Error('useAuth must be used within an AuthProvider') } return context -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/hooks/useAudioRecording.ts b/backends/advanced/webui/src/hooks/useAudioRecording.ts index 164fa9d5..66bec149 100644 --- a/backends/advanced/webui/src/hooks/useAudioRecording.ts +++ b/backends/advanced/webui/src/hooks/useAudioRecording.ts @@ -22,30 +22,30 @@ export interface UseAudioRecordingReturn { // Connection state isWebSocketConnected: boolean connectionStatus: 'disconnected' | 'connecting' | 'connected' | 'error' - + // Recording state isRecording: boolean recordingDuration: number audioProcessingStarted: boolean - + // Component states (direct checks, no sync issues) hasValidWebSocket: boolean hasValidMicrophone: boolean hasValidAudioContext: boolean isCurrentlyStreaming: boolean - + // Granular test states hasMicrophoneAccess: boolean hasAudioContext: boolean isStreaming: boolean - + // Error management error: string | null componentErrors: ComponentErrors - + // Debug information debugStats: DebugStats - + // Actions connectWebSocketOnly: () => Promise disconnectWebSocketOnly: () => void @@ -58,7 +58,7 @@ export interface UseAudioRecordingReturn { testFullFlowOnly: () => Promise startRecording: () => Promise stopRecording: () => void - + // Utilities formatDuration: (seconds: number) => string canAccessMicrophone: boolean @@ -72,12 +72,12 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { const [recordingDuration, setRecordingDuration] = useState(0) const [error, setError] = useState(null) const [audioProcessingStarted, setAudioProcessingStarted] = useState(false) - + // Granular testing states const [hasMicrophoneAccess, setHasMicrophoneAccess] = useState(false) const [hasAudioContext, setHasAudioContext] = useState(false) const [isStreaming, setIsStreaming] = useState(false) - + // Error tracking const [componentErrors, setComponentErrors] = useState({ websocket: null, @@ -85,7 +85,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { 
audioContext: null, streaming: null }) - + // Debug stats const [debugStats, setDebugStats] = useState({ chunksSent: 0, @@ -95,7 +95,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { sessionStartTime: null, connectionAttempts: 0 }) - + // Refs for direct access (no state sync issues) const wsRef = useRef(null) const mediaStreamRef = useRef(null) @@ -107,18 +107,18 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { const audioProcessingStartedRef = useRef(false) const chunkCountRef = useRef(0) // Note: Legacy message queue code removed as it was unused - + // Check if we're on localhost or using HTTPS const isLocalhost = window.location.hostname === 'localhost' || window.location.hostname === '127.0.0.1' const isHttps = window.location.protocol === 'https:' const canAccessMicrophone = isLocalhost || isHttps - + // Direct status checks (no state sync issues) const hasValidWebSocket = wsRef.current?.readyState === WebSocket.OPEN const hasValidMicrophone = mediaStreamRef.current !== null const hasValidAudioContext = audioContextRef.current !== null const isCurrentlyStreaming = isStreaming && hasValidWebSocket && hasValidMicrophone - + const connectWebSocket = useCallback(async () => { if (wsRef.current?.readyState === WebSocket.OPEN) { return true @@ -157,16 +157,16 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { console.log('🎀 WebSocket connected for live recording') setConnectionStatus('connected') setIsWebSocketConnected(true) - + // Add stabilization delay before resolving to prevent protocol violations setTimeout(() => { wsRef.current = ws - setDebugStats(prev => ({ - ...prev, + setDebugStats(prev => ({ + ...prev, sessionStartTime: new Date(), connectionAttempts: prev.connectionAttempts + 1 })) - + // Start keepalive ping every 30 seconds keepAliveIntervalRef.current = setInterval(() => { if (ws.readyState === WebSocket.OPEN) { @@ -179,7 +179,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { } } }, 30000) - + console.log('πŸ”Œ WebSocket stabilized and ready for messages') resolve(true) }, 100) // 100ms stabilization delay @@ -190,13 +190,13 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { setConnectionStatus('disconnected') setIsWebSocketConnected(false) wsRef.current = null - + // Clear keepalive interval if (keepAliveIntervalRef.current) { clearInterval(keepAliveIntervalRef.current) keepAliveIntervalRef.current = undefined } - + if (isRecording) { stopRecording() } @@ -210,7 +210,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { setComponentErrors(prev => ({ ...prev, websocket: errorMsg })) reject(error) } - + ws.onmessage = (event) => { // Handle any messages from the server console.log('🎀 Received message from server:', event.data) @@ -322,7 +322,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { const requestMicrophoneOnly = async () => { try { setComponentErrors(prev => ({ ...prev, microphone: null })) - + if (!canAccessMicrophone) { throw new Error('Microphone access requires HTTPS or localhost') } @@ -336,10 +336,10 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { autoGainControl: true } }) - + // Clean up the stream immediately - we just wanted to test permissions stream.getTracks().forEach(track => track.stop()) - + setHasMicrophoneAccess(true) console.log('🎀 Microphone access granted') return true @@ -355,7 +355,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { const createAudioContextOnly = async 
() => { try { setComponentErrors(prev => ({ ...prev, audioContext: null })) - + if (audioContextRef.current) { audioContextRef.current.close() } @@ -363,10 +363,10 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { const audioContext = new AudioContext({ sampleRate: 16000 }) const analyser = audioContext.createAnalyser() analyser.fftSize = 256 - + audioContextRef.current = audioContext analyserRef.current = analyser - + setHasAudioContext(true) console.log('πŸ“Š Audio context created successfully') return true @@ -382,17 +382,17 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { const startStreamingOnly = async () => { try { setComponentErrors(prev => ({ ...prev, streaming: null })) - + // Use direct checks instead of state if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) { throw new Error('WebSocket not connected') } - + // Check if microphone access was previously tested if (!hasMicrophoneAccess) { throw new Error('Microphone access test required first - click "Get Mic" button') } - + // Check if audio context was previously created if (!hasAudioContext) { throw new Error('Audio context test required first - click "Create Context" button') @@ -425,7 +425,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) { return } - + if (!audioProcessingStartedRef.current) { console.log('🚫 Audio processing not started yet, skipping chunk') return @@ -433,7 +433,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { const inputBuffer = event.inputBuffer const inputData = inputBuffer.getChannelData(0) - + // Convert float32 to int16 PCM const pcmBuffer = new Int16Array(inputData.length) for (let i = 0; i < inputData.length; i++) { @@ -460,14 +460,14 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { wsRef.current.send(JSON.stringify(chunkHeader) + '\n') wsRef.current.send(new Uint8Array(pcmBuffer.buffer, pcmBuffer.byteOffset, pcmBuffer.byteLength)) - + // Update debug stats chunkCountRef.current++ setDebugStats(prev => ({ ...prev, chunksSent: chunkCountRef.current })) } catch (error) { console.error('Failed to send audio chunk:', error) - setDebugStats(prev => ({ - ...prev, + setDebugStats(prev => ({ + ...prev, lastError: error instanceof Error ? 
error.message : 'Chunk send failed', lastErrorTime: new Date() })) @@ -516,46 +516,46 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { try { setError(null) console.log('πŸ’Ύ Starting full flow test...') - + // Step 1: Connect WebSocket const connected = await connectWebSocket() if (!connected) { throw new Error('WebSocket connection failed') } - + // Step 2: Get microphone access const micAccess = await requestMicrophoneOnly() if (!micAccess) { throw new Error('Microphone access failed') } - + // Step 3: Create audio context const contextCreated = await createAudioContextOnly() if (!contextCreated) { throw new Error('Audio context creation failed') } - + // Step 4: Send audio-start const startSent = await sendAudioStartOnly() if (!startSent) { throw new Error('Audio-start message failed') } - + // Step 5: Start streaming for 10 seconds const streamingStarted = await startStreamingOnly() if (!streamingStarted) { throw new Error('Audio streaming failed') } - + console.log('πŸ’Ύ Full flow test running for 10 seconds...') - + // Wait 10 seconds setTimeout(() => { stopStreamingOnly() sendAudioStopOnly() console.log('πŸ’Ύ Full flow test completed') }, 10000) - + return true } catch (error) { console.error('Full flow test failed:', error) @@ -596,10 +596,10 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { const audioContext = new AudioContext({ sampleRate: 16000 }) const analyser = audioContext.createAnalyser() const source = audioContext.createMediaStreamSource(stream) - + analyser.fftSize = 256 source.connect(analyser) - + audioContextRef.current = audioContext analyserRef.current = analyser @@ -636,7 +636,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) { return } - + // Don't send audio chunks until audio-start has been sent and processed if (!audioProcessingStartedRef.current) { console.log('🚫 Audio processing not started yet, skipping chunk') @@ -645,7 +645,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { const inputBuffer = event.inputBuffer const inputData = inputBuffer.getChannelData(0) - + // Convert float32 to int16 PCM const pcmBuffer = new Int16Array(inputData.length) for (let i = 0; i < inputData.length; i++) { @@ -675,14 +675,14 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { wsRef.current.send(JSON.stringify(chunkHeader) + '\n') // Send the actual Int16Array buffer, not the underlying ArrayBuffer wsRef.current.send(new Uint8Array(pcmBuffer.buffer, pcmBuffer.byteOffset, pcmBuffer.byteLength)) - + // Update debug stats chunkCountRef.current++ setDebugStats(prev => ({ ...prev, chunksSent: chunkCountRef.current })) } catch (error) { console.error('Failed to send audio chunk:', error) - setDebugStats(prev => ({ - ...prev, + setDebugStats(prev => ({ + ...prev, lastError: error instanceof Error ? 
error.message : 'Chunk send failed', lastErrorTime: new Date() })) @@ -697,7 +697,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { setIsRecording(true) setRecordingDuration(0) - + // Start duration timer durationIntervalRef.current = setInterval(() => { setRecordingDuration(prev => prev + 1) @@ -718,7 +718,7 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { setAudioProcessingStarted(false) audioProcessingStartedRef.current = false console.log('πŸ›‘ Audio processing disabled') - + // Send Wyoming protocol stop message if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) { try { @@ -792,30 +792,30 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { // Connection state isWebSocketConnected, connectionStatus, - + // Recording state isRecording, recordingDuration, audioProcessingStarted, - + // Direct status checks (no state sync issues) hasValidWebSocket, hasValidMicrophone, hasValidAudioContext, isCurrentlyStreaming, - + // Granular test states hasMicrophoneAccess, hasAudioContext, isStreaming, - + // Error management error, componentErrors, - + // Debug information debugStats, - + // Actions connectWebSocketOnly, disconnectWebSocketOnly, @@ -828,12 +828,12 @@ export const useAudioRecording = (): UseAudioRecordingReturn => { testFullFlowOnly, startRecording, stopRecording, - + // Utilities formatDuration, canAccessMicrophone, - + // Internal refs for components that need them analyserRef } as UseAudioRecordingReturn & { analyserRef: React.RefObject } -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/pages/Chat.tsx b/backends/advanced/webui/src/pages/Chat.tsx index f4a8d899..58ff757c 100644 --- a/backends/advanced/webui/src/pages/Chat.tsx +++ b/backends/advanced/webui/src/pages/Chat.tsx @@ -478,7 +478,7 @@ export default function Chat() {
- +
) -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/pages/Conversations.tsx b/backends/advanced/webui/src/pages/Conversations.tsx index 872afa21..6c0ddca5 100644 --- a/backends/advanced/webui/src/pages/Conversations.tsx +++ b/backends/advanced/webui/src/pages/Conversations.tsx @@ -178,7 +178,7 @@ export default function Conversations() { const allSpeakers = useMemo(() => { const speakers = [...enrolledSpeakers] const existingNames = new Set(speakers.map(s => s.name)) - + // Add speakers from all diarization annotations diarizationAnnotations.forEach((annotations) => { annotations.forEach(a => { @@ -1994,4 +1994,4 @@ export default function Conversations() { )}
) -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/pages/LiveRecord.tsx b/backends/advanced/webui/src/pages/LiveRecord.tsx index 2f934d8c..69e7e721 100644 --- a/backends/advanced/webui/src/pages/LiveRecord.tsx +++ b/backends/advanced/webui/src/pages/LiveRecord.tsx @@ -106,7 +106,7 @@ export default function LiveRecord() { {/* Audio Visualizer - Shows waveform when recording */} - @@ -136,4 +136,4 @@ export default function LiveRecord() {
) -} \ No newline at end of file +} diff --git a/backends/advanced/webui/src/pages/Queue.tsx b/backends/advanced/webui/src/pages/Queue.tsx index 3ea93db7..1305c78a 100644 --- a/backends/advanced/webui/src/pages/Queue.tsx +++ b/backends/advanced/webui/src/pages/Queue.tsx @@ -2908,4 +2908,4 @@ const Queue: React.FC = () => { ); }; -export default Queue; \ No newline at end of file +export default Queue; diff --git a/backends/advanced/webui/src/pages/Users.tsx b/backends/advanced/webui/src/pages/Users.tsx index a60675bf..67db5216 100644 --- a/backends/advanced/webui/src/pages/Users.tsx +++ b/backends/advanced/webui/src/pages/Users.tsx @@ -189,7 +189,7 @@ export default function Users() { />
- +
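Taken together, the backends/advanced/upload_files.py hunks above describe a two-step client flow: POST the WAV files to /api/audio/upload with a bearer token, then poll /api/queue/jobs/{job_id} until the RQ job reports "finished" or "failed". A minimal sketch of that flow, assuming only the endpoints and response fields visible in the diff; the helper name upload_and_wait and the 5-second polling interval are illustrative, not part of the patch.

    import time
    from pathlib import Path

    import requests


    def upload_and_wait(files, token, base_url="http://localhost:8000", poll_seconds=5):
        """Illustrative sketch: submit WAVs for async processing and poll the job."""
        headers = {"Authorization": f"Bearer {token}"}

        # Submit all files in one multipart request, as upload_files.py does.
        multipart = [("files", (Path(p).name, open(p, "rb"), "audio/wav")) for p in files]
        try:
            resp = requests.post(
                f"{base_url}/api/audio/upload",
                files=multipart,
                data={"device_name": "file_upload_batch"},
                headers=headers,
                timeout=60,
            )
        finally:
            # Close file handles once the request has been sent.
            for _, (_, fh, _) in multipart:
                fh.close()
        resp.raise_for_status()
        job_id = resp.json()["job_id"]

        # Poll the queue endpoint until the RQ job finishes or fails.
        while True:
            status = requests.get(
                f"{base_url}/api/queue/jobs/{job_id}", headers=headers, timeout=30
            ).json()
            if status.get("status") == "finished":
                return status
            if status.get("status") == "failed":
                raise RuntimeError(status.get("error_message", "Unknown error"))
            time.sleep(poll_seconds)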
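Similarly, the test_memory_entry.py and test_memory_providers.py hunks pin down a small contract for MemoryEntry: Unix-timestamp strings for created_at/updated_at, auto-initialized when missing, updated_at defaulting to created_at, and a to_dict() that mirrors content into both "memory" and "content" and lifts user_id out of metadata. A rough sketch of a class satisfying that contract, assuming only what the tests assert; this is not the project's actual implementation in services/memory/base.py.

    import time
    from dataclasses import dataclass, field
    from typing import Any, Optional


    @dataclass
    class MemoryEntrySketch:
        """Illustrative stand-in for MemoryEntry; fields mirror what the tests assert."""

        id: str
        content: str
        metadata: dict[str, Any] = field(default_factory=dict)
        embedding: Optional[list[float]] = None
        score: Optional[float] = None
        created_at: Optional[str] = None
        updated_at: Optional[str] = None

        def __post_init__(self) -> None:
            # Auto-initialize timestamps as Unix-timestamp strings when missing.
            if self.created_at is None:
                self.created_at = str(int(time.time()))
            if self.updated_at is None:
                self.updated_at = self.created_at

        def to_dict(self) -> dict[str, Any]:
            # Expose both "memory" and "content" for frontend compatibility,
            # and lift user_id out of metadata, as the tests expect.
            return {
                "id": self.id,
                "memory": self.content,
                "content": self.content,
                "created_at": self.created_at,
                "updated_at": self.updated_at,
                "metadata": self.metadata,
                "user_id": self.metadata.get("user_id"),
            }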