diff --git a/.env.template b/.env.template
index 388edbf5..1a638bb5 100644
--- a/.env.template
+++ b/.env.template
@@ -218,4 +218,4 @@ WEBUI_MEMORY_REQUEST=128Mi
SPEAKER_CPU_LIMIT=2000m
SPEAKER_MEMORY_LIMIT=4Gi
SPEAKER_CPU_REQUEST=500m
-SPEAKER_MEMORY_REQUEST=2Gi
\ No newline at end of file
+SPEAKER_MEMORY_REQUEST=2Gi
diff --git a/.github/workflows/advanced-docker-compose-build.yml b/.github/workflows/advanced-docker-compose-build.yml
index 93e72d68..ff406822 100644
--- a/.github/workflows/advanced-docker-compose-build.yml
+++ b/.github/workflows/advanced-docker-compose-build.yml
@@ -140,10 +140,10 @@ jobs:
set -euo pipefail
docker compose version
OWNER_LC=$(echo "$OWNER" | tr '[:upper:]' '[:lower:]')
-
+
# CUDA variants from pyproject.toml
CUDA_VARIANTS=("cpu" "cu121" "cu126" "cu128")
-
+
# Base services (no CUDA variants, no profiles)
base_service_specs=(
"chronicle-backend|advanced-chronicle-backend|docker-compose.yml|."
@@ -151,11 +151,11 @@ jobs:
"webui|advanced-webui|docker-compose.yml|."
"openmemory-mcp|openmemory-mcp|../../extras/openmemory-mcp/docker-compose.yml|../../extras/openmemory-mcp"
)
-
+
# Build and push base services
for spec in "${base_service_specs[@]}"; do
IFS='|' read -r svc svc_repo compose_file project_dir <<< "$spec"
-
+
echo "::group::Building and pushing $svc_repo"
if [ "$compose_file" = "docker-compose.yml" ] && [ "$project_dir" = "." ]; then
docker compose build --pull "$svc"
@@ -173,7 +173,7 @@ jobs:
echo "::endgroup::"
continue
fi
-
+
# Tag and push with version
target_image="$REGISTRY/$OWNER_LC/$svc_repo:$VERSION"
latest_image="$REGISTRY/$OWNER_LC/$svc_repo:latest"
@@ -181,21 +181,21 @@ jobs:
docker tag "$img_id" "$target_image"
echo "Tagging $img_id as $latest_image"
docker tag "$img_id" "$latest_image"
-
+
echo "Pushing $target_image"
docker push "$target_image"
echo "Pushing $latest_image"
docker push "$latest_image"
-
+
# Clean up local tags
docker image rm -f "$target_image" || true
docker image rm -f "$latest_image" || true
echo "::endgroup::"
-
+
# Aggressive cleanup to save space
docker system prune -af || true
done
-
+
# Build and push parakeet-asr with CUDA variants (cu121, cu126, cu128)
echo "::group::Building and pushing parakeet-asr CUDA variants"
cd ../../extras/asr-services
@@ -203,7 +203,7 @@ jobs:
echo "Building parakeet-asr-${cuda_variant}"
export PYTORCH_CUDA_VERSION="${cuda_variant}"
docker compose build parakeet-asr
-
+
img_id=$(docker compose images -q parakeet-asr | head -n1)
if [ -n "${img_id:-}" ]; then
target_image="$REGISTRY/$OWNER_LC/parakeet-asr-${cuda_variant}:$VERSION"
@@ -212,23 +212,23 @@ jobs:
docker tag "$img_id" "$target_image"
echo "Tagging $img_id as $latest_image"
docker tag "$img_id" "$latest_image"
-
+
echo "Pushing $target_image"
docker push "$target_image"
echo "Pushing $latest_image"
docker push "$latest_image"
-
+
# Clean up local tags
docker image rm -f "$target_image" || true
docker image rm -f "$latest_image" || true
fi
-
+
# Aggressive cleanup to save space
docker system prune -af || true
done
cd - > /dev/null
echo "::endgroup::"
-
+
# Build and push speaker-recognition with all CUDA variants (including CPU)
# Note: speaker-service has profiles, but we can build it directly by setting PYTORCH_CUDA_VERSION
echo "::group::Building and pushing speaker-recognition variants"
@@ -238,7 +238,7 @@ jobs:
export PYTORCH_CUDA_VERSION="${cuda_variant}"
# Build speaker-service directly (profiles only affect 'up', not 'build')
docker compose build speaker-service
-
+
img_id=$(docker compose images -q speaker-service | head -n1)
if [ -n "${img_id:-}" ]; then
target_image="$REGISTRY/$OWNER_LC/speaker-recognition-${cuda_variant}:$VERSION"
@@ -247,23 +247,23 @@ jobs:
docker tag "$img_id" "$target_image"
echo "Tagging $img_id as $latest_image"
docker tag "$img_id" "$latest_image"
-
+
echo "Pushing $target_image"
docker push "$target_image"
echo "Pushing $latest_image"
docker push "$latest_image"
-
+
# Clean up local tags
docker image rm -f "$target_image" || true
docker image rm -f "$latest_image" || true
fi
-
+
# Aggressive cleanup to save space
docker system prune -af || true
done
cd - > /dev/null
echo "::endgroup::"
-
+
# Summary
echo "::group::Build Summary"
echo "Built and pushed images with version tag: ${VERSION}"
diff --git a/CLAUDE.md b/CLAUDE.md
index fc3d8818..2fa839a8 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -613,4 +613,4 @@ The uv package manager is used for all python projects. Wherever you'd call `pyt
- Only use `--no-cache` when explicitly needed (e.g., if cached layers are causing issues or when troubleshooting build problems)
- Docker's build cache is efficient and saves significant time during development
-- Remember that whenever there's a python command, you should use uv run python3 instead
\ No newline at end of file
+- Remember that whenever there's a python command, you should use uv run python3 instead
diff --git a/Docs/init-system.md b/Docs/init-system.md
index 895d727d..9d859226 100644
--- a/Docs/init-system.md
+++ b/Docs/init-system.md
@@ -3,7 +3,7 @@
## Quick Links
- **π [Start Here: Quick Start Guide](../quickstart.md)** - Main setup path for new users
-- **π [Full Documentation](../CLAUDE.md)** - Comprehensive reference
+- **π [Full Documentation](../CLAUDE.md)** - Comprehensive reference
- **ποΈ [Architecture Details](overview.md)** - Technical deep dive
---
@@ -59,12 +59,12 @@ Each service can be configured independently:
cd backends/advanced
uv run --with-requirements setup-requirements.txt python init.py
-# Speaker Recognition only
+# Speaker Recognition only
cd extras/speaker-recognition
./setup.sh
# ASR Services only
-cd extras/asr-services
+cd extras/asr-services
./setup.sh
# OpenMemory MCP only
@@ -80,21 +80,21 @@ cd extras/openmemory-mcp
- **Generates**: Complete `.env` file with all required configuration
- **Default ports**: Backend (8000), WebUI (5173)
-### Speaker Recognition
+### Speaker Recognition
- **Prompts for**: Hugging Face token, compute mode (cpu/gpu)
- **Service port**: 8085
- **WebUI port**: 5173
- **Requires**: HF_TOKEN for pyannote models
### ASR Services
-- **Starts**: Parakeet ASR service via Docker Compose
+- **Starts**: Parakeet ASR service via Docker Compose
- **Service port**: 8767
- **Purpose**: Offline speech-to-text processing
- **No configuration required**
### OpenMemory MCP
- **Starts**: External OpenMemory MCP server
-- **Service port**: 8765
+- **Service port**: 8765
- **WebUI**: Available at http://localhost:8765
- **Purpose**: Cross-client memory compatibility
@@ -112,10 +112,10 @@ Note (Linux): If `host.docker.internal` is unavailable, add `extra_hosts: - "hos
## Key Benefits
-✅ **No Unnecessary Building** - Services are only started when you explicitly request them
-✅ **Resource Efficient** - Parakeet ASR won't start if you're using cloud transcription
-✅ **Clean Separation** - Configuration vs service management are separate concerns
-✅ **Unified Control** - Single command to start/stop all services
+✅ **No Unnecessary Building** - Services are only started when you explicitly request them
+✅ **Resource Efficient** - Parakeet ASR won't start if you're using cloud transcription
+✅ **Clean Separation** - Configuration vs service management are separate concerns
+✅ **Unified Control** - Single command to start/stop all services
 ✅ **Selective Starting** - Choose which services to run based on your current needs
## Ports & Access
@@ -215,7 +215,7 @@ You can also manage services individually:
# Advanced Backend
cd backends/advanced && docker compose up --build -d
-# Speaker Recognition
+# Speaker Recognition
cd extras/speaker-recognition && docker compose up --build -d
# ASR Services (only if using offline transcription)
@@ -225,6 +225,59 @@ cd extras/asr-services && docker compose up --build -d
cd extras/openmemory-mcp && docker compose up --build -d
```
+## Startup Flow (Mermaid Diagrams)
+
+Chronicle has two layers:
+- **Setup** (`wizard.sh` / `wizard.py`) writes config (`.env`, `config/config.yml`, optional SSL/nginx config).
+- **Run** (`start.sh` / `services.py`) starts the configured services via `docker compose`.
+
+```mermaid
+flowchart TD
+ A[wizard.sh] --> B[uv run --with-requirements setup-requirements.txt wizard.py]
+ B --> C{Select services}
+ C --> D[backends/advanced/init.py\nwrites backends/advanced/.env + config/config.yml]
+ C --> E[extras/speaker-recognition/init.py\nwrites extras/speaker-recognition/.env\noptionally ssl/* + nginx.conf]
+ C --> F[extras/asr-services/init.py\nwrites extras/asr-services/.env]
+ C --> G[extras/openmemory-mcp/setup.sh]
+
+ A2[start.sh] --> B2[uv run --with-requirements setup-requirements.txt python services.py start ...]
+ B2 --> H{For each service:\n.env exists?}
+ H -->|yes| I[services.py runs docker compose\nin each service directory]
+ H -->|no| J["Skip (not configured)"]
+```
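+
+As a rough shell equivalent of the run layer (illustrative only; the real logic lives in `services.py`), the flow boils down to:
+
+```bash
+# Sketch of what start.sh / services.py effectively do per configured service.
+# The directory list and loop are illustrative, not the actual implementation.
+for dir in backends/advanced extras/speaker-recognition extras/asr-services; do
+  if [ -f "$dir/.env" ]; then
+    # Configured by the wizard: bring it up via docker compose
+    (cd "$dir" && docker compose up --build -d)
+  else
+    # Not configured: skip it
+    echo "Skipping $dir (no .env found)"
+  fi
+done
+# (Profile selection for Speaker Recognition is covered in the next section.)
+```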
+
+### How `services.py` picks Speaker Recognition variants
+
+`services.py` reads `extras/speaker-recognition/.env` and decides:
+- `COMPUTE_MODE=cpu|gpu|strixhalo` → choose compose profile
+- `REACT_UI_HTTPS=true|false` → include `nginx` (HTTPS) vs run only API+UI (HTTP)
+
+```mermaid
+flowchart TD
+ S[start.sh] --> P[services.py]
+ P --> R[Read extras/speaker-recognition/.env]
+ R --> M{COMPUTE_MODE}
+ M -->|cpu| C1[docker compose --profile cpu up ...]
+ M -->|gpu| C2[docker compose --profile gpu up ...]
+ M -->|strixhalo| C3[docker compose --profile strixhalo up ...]
+ R --> H{REACT_UI_HTTPS}
+ H -->|true| N1[Start profile default set:\nAPI + web-ui + nginx]
+ H -->|false| N2["Start only:\nAPI + web-ui (no nginx)"]
+```
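+
+A shell sketch of that decision (the compose service names `speaker-service`, `web-ui`, and `nginx` are taken from the diagram above; check them against `extras/speaker-recognition/docker-compose.yml`):
+
+```bash
+cd extras/speaker-recognition
+# Simplified: services.py parses .env itself rather than sourcing it
+source .env   # provides COMPUTE_MODE and REACT_UI_HTTPS
+
+if [ "${REACT_UI_HTTPS}" = "true" ]; then
+  # HTTPS: start the default service set, including nginx
+  docker compose --profile "${COMPUTE_MODE}" up -d
+else
+  # HTTP only: start the API and UI without nginx
+  docker compose --profile "${COMPUTE_MODE}" up -d speaker-service web-ui
+fi
+```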
+
+### CPU + NVIDIA share the same `Dockerfile` + `pyproject.toml`
+
+Speaker recognition uses a single dependency definition with per-accelerator "extras":
+- `extras/speaker-recognition/pyproject.toml` defines extras like `cpu`, `cu121`, `cu126`, `cu128`, `strixhalo`.
+- `extras/speaker-recognition/Dockerfile` takes `ARG PYTORCH_CUDA_VERSION` and runs:
+ - `uv sync --extra ${PYTORCH_CUDA_VERSION}`
+ - `uv run --extra ${PYTORCH_CUDA_VERSION} ...`
+- `extras/speaker-recognition/docker-compose.yml` sets that build arg per profile:
+ - CPU profile defaults to `PYTORCH_CUDA_VERSION=cpu`
+ - GPU profile defaults to `PYTORCH_CUDA_VERSION=cu126` and reserves NVIDIA GPUs
+
+AMD/ROCm (Strix Halo) uses the same `pyproject.toml` interface (the `strixhalo` extra) but a different build recipe (`extras/speaker-recognition/Dockerfile.strixhalo`) and ROCm device mappings, because the ROCm base image already provides the torch stack.
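+
+As a hedged example, building one variant by hand mirrors what this repo's CI workflow does:
+
+```bash
+cd extras/speaker-recognition
+# Pick one extra defined in pyproject.toml: cpu, cu121, cu126 or cu128
+export PYTORCH_CUDA_VERSION=cu128
+# The compose file forwards this as the PYTORCH_CUDA_VERSION build arg, and the
+# Dockerfile runs `uv sync --extra ${PYTORCH_CUDA_VERSION}` inside the image.
+docker compose build speaker-service
+```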
+
## Configuration Files
### Generated Files
@@ -250,7 +303,7 @@ cd extras/openmemory-mcp && docker compose up --build -d
# Backend health
curl http://localhost:8000/health
-# Speaker Recognition health
+# Speaker Recognition health
curl http://localhost:8085/health
# ASR service health
@@ -267,4 +320,4 @@ cd backends/advanced && docker compose logs chronicle-backend
# Speaker Recognition logs
cd extras/speaker-recognition && docker compose logs speaker-service
-```
\ No newline at end of file
+```
diff --git a/README-K8S.md b/README-K8S.md
index 8bbe22fa..6a36e038 100644
--- a/README-K8S.md
+++ b/README-K8S.md
@@ -47,7 +47,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
2. **Install Ubuntu Server**
- Boot from USB/DVD
- Choose "Install Ubuntu Server"
- - Configure network with static IP (recommended: 192.168.1.42)
+ - Configure network with static IP (recommended: 192.168.1.42)
   - Set hostname (e.g., `k8s-control-plane`)
- Create user account
- Install OpenSSH server
@@ -56,10 +56,10 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
```bash
# Update system
sudo apt update && sudo apt upgrade -y
-
+
# Install essential packages
sudo apt install -y curl wget git vim htop tree
-
+
# Configure firewall
sudo ufw allow ssh
sudo ufw allow 6443 # Kubernetes API
@@ -75,11 +75,11 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
```bash
# Install MicroK8s
sudo snap install microk8s --classic
-
+
# Add user to microk8s group
sudo usermod -a -G microk8s $USER
sudo chown -f -R $USER ~/.kube
-
+
# Log out and back in, or run:
newgrp microk8s
```
@@ -88,10 +88,10 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
```bash
# Start MicroK8s
sudo microk8s start
-
+
# Wait for all services to be ready
sudo microk8s status --wait-ready
-
+
# Generate join token for worker nodes
sudo microk8s add-node
# This will output a command like:
@@ -103,7 +103,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
```bash
# Start MicroK8s
sudo microk8s start
-
+
# Wait for all services to be ready
sudo microk8s status --wait-ready
```
@@ -115,7 +115,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
sudo microk8s enable ingress
sudo microk8s enable storage
sudo microk8s enable metrics-server
-
+
# Wait for add-ons to be ready
sudo microk8s status --wait-ready
```
@@ -125,7 +125,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
# Create kubectl alias
echo 'alias kubectl="microk8s kubectl"' >> ~/.bashrc
source ~/.bashrc
-
+
# Verify installation
kubectl get nodes
kubectl get pods -A
@@ -147,11 +147,11 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
```bash
# Install MicroK8s
sudo snap install microk8s --classic
-
+
# Add user to microk8s group
sudo usermod -a -G microk8s $USER
sudo chown -f -R $USER ~/.kube
-
+
# Log out and back in, or run:
newgrp microk8s
```
@@ -161,7 +161,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
# Use the join command from the control plane
# Replace with your actual join token
sudo microk8s join 192.168.1.42:25000/xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
-
+
# Wait for node to join
sudo microk8s status --wait-ready
```
@@ -170,7 +170,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
```bash
# On the control plane, verify the worker node joined
kubectl get nodes
-
+
# The worker node should show as Ready
# Example output:
# NAME STATUS ROLES AGE VERSION
@@ -182,7 +182,7 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
```bash
# From your build machine, configure the worker node
./configure-insecure-registry-remote.sh 192.168.1.43
-
+
# Repeat for each worker node with their respective IPs
```
@@ -194,10 +194,10 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
```bash
# Enable the built-in MicroK8s registry (not enabled by default)
sudo microk8s enable registry
-
+
# Wait for registry to be ready
sudo microk8s status --wait-ready
-
+
# Verify registry is running
kubectl get pods -n container-registry
```
@@ -213,11 +213,11 @@ This guide walks you through setting up Chronicle from scratch on a fresh Ubuntu
```bash
# From your build machine, configure MicroK8s to trust the insecure registry
chmod +x scripts/configure-insecure-registry-remote.sh
-
+
# Run the configuration script with your node IP address
# Usage: ./scripts/configure-insecure-registry-remote.sh
AsyncIO Events]
end
@@ -99,33 +99,33 @@ graph TB
APP --> WS
WS --> AUTH
AUTH --> WP
-
+
%% Wyoming Protocol Flow
WP --> AS
WP --> AC
WP --> AST
-
+
AS --> CM
AC --> CM
AST --> CM
-
+
%% Client State Management
CM --> CS
CS --> CT
-
+
%% Audio Processing Flow
CS -->|queue audio| AQ
AQ --> AP
AP -->|save| FS
AP -->|queue| TQ
-
+
%% Transcription Flow
TQ --> TP
TP --> DG
TP --> WY
TP -->|save transcript| AC_COL
TP -->|signal completion| TC
-
+
%% Conversation Closure Flow (Critical Path)
AST -->|close_conversation| CS
CS -->|delegate to| CONV_MGR
@@ -133,7 +133,7 @@ graph TB
TC -->|transcript ready| CONV_MGR
CONV_MGR -->|queue memory| MQ
CONV_MGR -->|if enabled| CQ
-
+
%% Memory Processing Flow
MQ --> MP
MP -->|read via| CONV_REPO
@@ -142,13 +142,13 @@ graph TB
MP --> OLL
MP -->|orchestrate by mem0| MEM
MP -->|update via| CONV_REPO
-
+
%% Cropping Flow
CQ --> CP
CP -->|read| FS
CP -->|save cropped| FS
CP -->|update| AC_COL
-
+
%% Management
PM -.-> AQ
PM -.-> TQ
@@ -385,7 +385,7 @@ BackgroundTaskManager(
### Conversation Closure and Memory Processing
-**Recent Architecture Changes**:
+**Recent Architecture Changes**:
1. Memory processing decoupled from transcription to prevent duplicates
2. **Event-driven coordination** eliminates polling/retry race conditions
@@ -410,14 +410,14 @@ The system now uses **TranscriptCoordinator** for proper async coordination:
class TranscriptCoordinator:
async def wait_for_transcript_completion(audio_uuid: str) -> bool:
# Wait for asyncio.Event (no polling!)
-
+
def signal_transcript_ready(audio_uuid: str):
# Signal completion from TranscriptionManager
```
**Benefits:**
 - ✅ **Zero polling** - Uses asyncio events instead of sleep/retry loops
-- ✅ **Immediate processing** - Memory processor starts as soon as transcript is ready
+- ✅ **Immediate processing** - Memory processor starts as soon as transcript is ready
 - ✅ **No race conditions** - Proper async coordination prevents timing issues
 - ✅ **Better performance** - No artificial delays or timeout-based waiting
@@ -450,7 +450,7 @@ The `close_conversation()` method in ClientState now delegates to ConversationMa
- Most reliable trigger
- Immediate processing
-2. **Client Disconnect**
+2. **Client Disconnect**
- WebSocket connection closed
- Cleanup handler ensures closure
- Prevents orphaned conversations
@@ -473,10 +473,10 @@ async def close_current_conversation(self):
# 1. Check if conversation has content
if not self.has_transcripts():
return # No memory processing needed
-
+
# 2. Save conversation to MongoDB
audio_uuid = await self.save_conversation()
-
+
# 3. Queue memory processing (ONLY place this happens)
if audio_uuid and self.has_required_data():
processor_manager.queue_memory_processing({
@@ -484,11 +484,11 @@ async def close_current_conversation(self):
'user_id': self.user_id,
'transcript': self.get_transcript()
})
-
+
# 4. Queue audio cropping if enabled
if self.audio_cropping_enabled:
processor_manager.queue_audio_cropping(audio_path)
-
+
# 5. Reset state for next conversation
self.reset_conversation_state()
```
@@ -591,21 +591,21 @@ graph LR
Mongo[mongo:4.4.18
Primary Database]
Qdrant[qdrant
Vector Store]
end
-
+
subgraph "External Services"
Ollama[ollama
LLM Service]
ASRService[ASR Services
extras/asr-services]
end
-
+
subgraph "Client Access"
WebBrowser[Web Browser
Dashboard]
AudioClient[Audio Client
Mobile/Desktop]
end
-
+
WebBrowser -->|Port 5173 (dev)| WebUI
WebBrowser -->|Port 80 (prod)| Proxy
AudioClient -->|Port 8000| Backend
-
+
Proxy --> Backend
Proxy --> WebUI
Backend --> Mongo
@@ -639,7 +639,7 @@ graph LR
## Detailed Data Flow Architecture
-> π **Reference Documentation**:
+> π **Reference Documentation**:
> - [Authentication Details](./auth.md) - Complete authentication system documentation
### Complete System Data Flow Diagram
@@ -661,13 +661,13 @@ flowchart TB
subgraph "π΅ Audio Processing Pipeline"
WSAuth[WebSocket Auth
π Connection timeout: 30s]
OpusDecoder[Opus/PCM Decoder
Real-time Processing]
-
+
subgraph "β±οΈ Per-Client State Management"
ClientState[Client State
π Conversation timeout: 1.5min]
AudioQueue[Audio Chunk Queue
60s segments]
ConversationTimer[Conversation Timer
π Auto-timeout tracking]
end
-
+
subgraph "ποΈ Transcription Layer"
ASRManager[Transcription Manager
π Init timeout: 60s]
DeepgramWS[Deepgram WebSocket
Nova-3 Model, Smart Format
π Auto-reconnect on disconnect]
@@ -685,7 +685,7 @@ flowchart TB
LLMProcessor[Ollama LLM
π Circuit breaker protection]
VectorStore[Qdrant Vector Store
π Semantic search]
end
-
+
end
@@ -701,7 +701,7 @@ flowchart TB
AuthGW -->|β 401 Unauthorized
β±οΈ Invalid/expired token| Client
 AuthGW -->|✅ Validated| ClientGen
ClientGen -->|π·οΈ Generate client_id| WSAuth
-
+
%% Audio Processing Flow
Client -->|π΅ Opus/PCM Stream
π 30s connection timeout| WSAuth
WSAuth -->|β 1008 Policy Violation
π Auth required| Client
@@ -709,7 +709,7 @@ flowchart TB
OpusDecoder -->|π¦ Audio chunks| ClientState
ClientState -->|β±οΈ 1.5min timeout check| ConversationTimer
ConversationTimer -->|π Timeout exceeded| ClientState
-
+
%% Transcription Flow with Failure Points
ClientState -->|π΅ Audio data| ASRManager
ASRManager -->|π Primary connection| DeepgramWS
@@ -727,7 +727,7 @@ flowchart TB
LLMProcessor -->|β Empty response
π Fallback memory| MemoryService
 LLMProcessor -->|✅ Memory extracted| VectorStore
MemoryService -->|π Track processing| QueueTracker
-
+
%% Disconnect and Cleanup Flow
Client -->|π Disconnect| ClientState
@@ -765,7 +765,7 @@ flowchart TB
#### π **Disconnection Scenarios**
1. **Client Disconnect**: Graceful cleanup with conversation finalization
-2. **Network Interruption**: Auto-reconnection with exponential backoff
+2. **Network Interruption**: Auto-reconnection with exponential backoff
3. **Service Failure**: Circuit breaker protection and alternative routing
4. **Authentication Expiry**: Forced re-authentication with clear error codes
@@ -792,7 +792,7 @@ flowchart TB
**Solution**: A 2-second delay is added before calling `close_client_audio()` to ensure the transcription manager is created by the background processor. Without this delay, the flush call fails silently and transcription never completes.
**File Upload Flow**:
-1. Audio chunks queued to `transcription_queue`
+1. Audio chunks queued to `transcription_queue`
2. Background transcription processor creates `TranscriptionManager` on first chunk
3. 2-second delay ensures manager exists before flush
4. Client audio closure triggers transcript completion
diff --git a/backends/advanced/Docs/auth.md b/backends/advanced/Docs/auth.md
index b1b9c273..fd019c90 100644
--- a/backends/advanced/Docs/auth.md
+++ b/backends/advanced/Docs/auth.md
@@ -16,11 +16,11 @@ class User(BeanieBaseUser, Document):
is_active: bool = True
is_superuser: bool = False
is_verified: bool = False
-
+
# Custom fields
display_name: Optional[str] = None
registered_clients: dict[str, dict] = Field(default_factory=dict)
-
+
@property
def user_id(self) -> str:
"""Return string representation of MongoDB ObjectId for backward compatibility."""
@@ -41,7 +41,7 @@ class UserManager(BaseUserManager[User, PydanticObjectId]):
"""Authenticate with email+password"""
username = credentials.get("username")
# Email-based authentication only
-
+
async def get_by_email(self, email: str) -> Optional[User]:
"""Get user by email address"""
```
@@ -287,7 +287,7 @@ curl -H "Authorization: Bearer $TOKEN" http://localhost:8000/users/me
```bash
# Old
AUTH_USERNAME=abc123 # Custom user_id (deprecated)
-
+
# New
AUTH_USERNAME=user@example.com # Email address only
```
@@ -296,7 +296,7 @@ curl -H "Authorization: Bearer $TOKEN" http://localhost:8000/users/me
```python
# Old
username = AUTH_USERNAME # Could be email or user_id
-
+
# New
username = AUTH_USERNAME # Email address only
```
@@ -336,4 +336,4 @@ async def get_all_data(user: User = Depends(current_superuser)):
# Unified user dashboard
```
-This authentication system provides enterprise-grade security with developer-friendly APIs, supporting email/password authentication and modern OAuth flows while maintaining proper data isolation and user management capabilities using MongoDB's robust ObjectId system.
\ No newline at end of file
+This authentication system provides enterprise-grade security with developer-friendly APIs, supporting email/password authentication and modern OAuth flows while maintaining proper data isolation and user management capabilities using MongoDB's robust ObjectId system.
diff --git a/backends/advanced/Docs/memories.md b/backends/advanced/Docs/memories.md
index 08ae393e..3da4e188 100644
--- a/backends/advanced/Docs/memories.md
+++ b/backends/advanced/Docs/memories.md
@@ -4,7 +4,7 @@
This document explains how to configure and customize the memory service in the chronicle backend.
-**Code References**:
+**Code References**:
- **Main Implementation**: `src/memory/memory_service.py`
- **Event Coordination**: `src/advanced_omi_backend/transcript_coordinator.py` (zero-polling async events)
- **Repository Layer**: `src/advanced_omi_backend/conversation_repository.py` (clean data access)
@@ -16,7 +16,7 @@ This document explains how to configure and customize the memory service in the
The memory service uses [Mem0](https://mem0.ai/) to store, retrieve, and search conversation memories. It integrates with Ollama for embeddings and LLM processing, and Qdrant for vector storage.
-**Key Architecture Changes**:
+**Key Architecture Changes**:
1. **Event-Driven Processing**: Memories use asyncio events instead of polling/retry mechanisms
2. **Repository Pattern**: Clean data access through ConversationRepository
3. **User-Centric Storage**: All memories keyed by user_id instead of client_id
@@ -47,7 +47,7 @@ The memory service uses [Mem0](https://mem0.ai/) to store, retrieve, and search
**Key Flow:**
1. **Audio** β **TranscriptCoordinator** signals completion
-2. **ConversationManager** waits for event, queues memory processing
+2. **ConversationManager** waits for event, queues memory processing
3. **MemoryProcessor** uses **ConversationRepository** for data access
4. **Mem0 + Ollama** extract and store memories in **Qdrant**
@@ -88,7 +88,7 @@ MEM0_CONFIG = {
},
},
"embedder": {
- "provider": "ollama",
+ "provider": "ollama",
"config": {
"model": "nomic-embed-text:latest",
"embedding_dims": 768,
@@ -311,10 +311,10 @@ MEM0_CONFIG["vector_store"]["config"].update({
```python
def search_memories_with_filters(self, query: str, user_id: str, topic: str = None):
filters = {}
-
+
if topic:
filters["metadata.topics"] = {"$in": [topic]}
-
+
return self.memory.search(
query=query,
user_id=user_id,
@@ -329,7 +329,7 @@ def search_memories_with_filters(self, query: str, user_id: str, topic: str = No
def get_important_memories(self, user_id: str):
"""Get memories sorted by importance/frequency"""
memories = self.memory.get_all(user_id=user_id)
-
+
# Custom scoring logic
for memory in memories:
score = 0
@@ -338,7 +338,7 @@ def get_important_memories(self, user_id: str):
if "deadline" in memory.get('memory', '').lower():
score += 3
memory['importance_score'] = score
-
+
return sorted(memories, key=lambda x: x.get('importance_score', 0), reverse=True)
```
@@ -409,13 +409,13 @@ Create a custom processing function:
def custom_memory_processor(transcript: str, client_id: str, audio_uuid: str, user_id: str, user_email: str):
# Extract entities
entities = extract_named_entities(transcript)
-
+
# Classify conversation type
conv_type = classify_conversation(transcript)
-
+
# Generate custom summary
summary = generate_custom_summary(transcript, conv_type)
-
+
# Store with enriched metadata
process_memory.add(
summary,
@@ -440,12 +440,12 @@ def init_specialized_memory_services():
# Personal memories
personal_config = MEM0_CONFIG.copy()
personal_config["vector_store"]["config"]["collection_name"] = "personal_memories"
-
- # Work memories
+
+ # Work memories
work_config = MEM0_CONFIG.copy()
work_config["vector_store"]["config"]["collection_name"] = "work_memories"
work_config["custom_prompt"] = "Focus on work-related tasks, meetings, and projects"
-
+
return {
"personal": Memory.from_config(personal_config),
"work": Memory.from_config(work_config)
@@ -460,7 +460,7 @@ Implement automatic memory cleanup:
def cleanup_old_memories(self, user_id: str, days_old: int = 365):
"""Remove memories older than specified days"""
cutoff_timestamp = int(time.time()) - (days_old * 24 * 60 * 60)
-
+
memories = self.get_all_memories(user_id)
for memory in memories:
if memory.get('metadata', {}).get('timestamp', 0) < cutoff_timestamp:
@@ -533,14 +533,14 @@ async def batch_add_memories(self, transcripts_data: List[Dict]):
tasks = []
for data in transcripts_data:
task = self.add_memory(
- data['transcript'],
- data['client_id'],
+ data['transcript'],
+ data['client_id'],
data['audio_uuid'],
data['user_id'], # Database user_id
data['user_email'] # User email
)
tasks.append(task)
-
+
results = await asyncio.gather(*tasks, return_exceptions=True)
return results
```
@@ -553,14 +553,14 @@ Implement memory consolidation:
def consolidate_memories(self, user_id: str, time_window_hours: int = 24):
"""Consolidate related memories from the same time period"""
recent_memories = self.get_recent_memories(user_id, time_window_hours)
-
+
if len(recent_memories) > 5: # If many memories in short time
consolidated = self.summarize_memories(recent_memories)
-
+
# Delete individual memories and store consolidated version
for memory in recent_memories:
self.delete_memory(memory['id'])
-
+
return self.add_consolidated_memory(consolidated, user_id)
```
@@ -569,7 +569,7 @@ def consolidate_memories(self, user_id: str, time_window_hours: int = 24):
The memory service exposes these endpoints with enhanced search capabilities:
- `GET /api/memories` - Get user memories with total count support (keyed by database user_id)
-- `GET /api/memories/search?query={query}&limit={limit}` - **Semantic memory search** with relevance scoring (user-scoped)
+- `GET /api/memories/search?query={query}&limit={limit}` - **Semantic memory search** with relevance scoring (user-scoped)
- `GET /api/memories/unfiltered` - User's memories without filtering for debugging
- `DELETE /api/memories/{memory_id}` - Delete specific memory (requires authentication)
- `GET /api/memories/admin` - Admin view of all memories across all users (superuser only)
@@ -578,11 +578,11 @@ The memory service exposes these endpoints with enhanced search capabilities:
**Semantic Search (`/api/memories/search`)**:
- **Relevance Scoring**: Returns similarity scores from vector database (0.0-1.0 range)
-- **Configurable Limits**: Supports `limit` parameter for result count control
+- **Configurable Limits**: Supports `limit` parameter for result count control
- **User Scoped**: Results automatically filtered by authenticated user
- **Vector-based**: Uses embeddings for contextual understanding beyond keyword matching
-**Memory Count API**:
+**Memory Count API**:
- **Chronicle Provider**: Native Qdrant count API provides accurate total counts
- **OpenMemory MCP Provider**: Count support varies by OpenMemory implementation
- **Response Format**: `{"memories": [...], "total_count": 42}` when supported
@@ -604,7 +604,7 @@ Returns all memories across all users in a clean, searchable format:
"user_id": "abc123",
"created_at": "2025-07-10T14:30:00Z",
"owner_user_id": "abc123",
- "owner_email": "user@example.com",
+ "owner_email": "user@example.com",
"owner_display_name": "John Doe",
"metadata": {
"client_id": "abc123-laptop",
diff --git a/backends/advanced/README.md b/backends/advanced/README.md
index 104137b3..e85636a7 100644
--- a/backends/advanced/README.md
+++ b/backends/advanced/README.md
@@ -38,7 +38,7 @@ Modern React-based web dashboard located in `./webui/` with:
- **Optional Services**: Speaker Recognition, network configuration
- **API Keys**: Prompts for all required keys with helpful links
-#### 2. Start Services
+#### 2. Start Services
**HTTP Mode (Default - No SSL required):**
```bash
diff --git a/backends/advanced/scripts/create_plugin.py b/backends/advanced/scripts/create_plugin.py
index f24427ad..9668c14d 100755
--- a/backends/advanced/scripts/create_plugin.py
+++ b/backends/advanced/scripts/create_plugin.py
@@ -7,6 +7,7 @@
Usage:
uv run python scripts/create_plugin.py my_awesome_plugin
"""
+
import argparse
import os
import shutil
@@ -16,7 +17,7 @@
def snake_to_pascal(snake_str: str) -> str:
"""Convert snake_case to PascalCase."""
- return ''.join(word.capitalize() for word in snake_str.split('_'))
+ return "".join(word.capitalize() for word in snake_str.split("_"))
def create_plugin(plugin_name: str, force: bool = False):
@@ -28,19 +29,19 @@ def create_plugin(plugin_name: str, force: bool = False):
force: Overwrite existing plugin if True
"""
# Validate plugin name
- if not plugin_name.replace('_', '').isalnum():
+ if not plugin_name.replace("_", "").isalnum():
print(f"β Error: Plugin name must be alphanumeric with underscores")
print(f" Got: {plugin_name}")
print(f" Example: my_awesome_plugin")
sys.exit(1)
# Convert to class name
- class_name = snake_to_pascal(plugin_name) + 'Plugin'
+ class_name = snake_to_pascal(plugin_name) + "Plugin"
# Get plugins directory (repo root plugins/)
script_dir = Path(__file__).parent
backend_dir = script_dir.parent
- plugins_dir = backend_dir.parent.parent / 'plugins'
+ plugins_dir = backend_dir.parent.parent / "plugins"
plugin_dir = plugins_dir / plugin_name
# Check if plugin already exists
@@ -70,7 +71,7 @@ def create_plugin(plugin_name: str, force: bool = False):
__all__ = ['{class_name}']
'''
- init_file = plugin_dir / '__init__.py'
+ init_file = plugin_dir / "__init__.py"
print(f"π Creating {init_file}")
init_file.write_text(init_content, encoding="utf-8")
@@ -271,12 +272,12 @@ async def _my_helper_method(self, data: Any) -> Any:
pass
'''
- plugin_file = plugin_dir / 'plugin.py'
+ plugin_file = plugin_dir / "plugin.py"
print(f"π Creating {plugin_file}")
- plugin_file.write_text(plugin_content,encoding="utf-8")
+ plugin_file.write_text(plugin_content, encoding="utf-8")
# Create README.md
- readme_content = f'''# {class_name}
+ readme_content = f"""# {class_name}
[Brief description of what your plugin does]
@@ -363,9 +364,9 @@ async def _my_helper_method(self, data: Any) -> Any:
## License
MIT License - see project LICENSE file for details.
-'''
+"""
- readme_file = plugin_dir / 'README.md'
+ readme_file = plugin_dir / "README.md"
print(f"π Creating {readme_file}")
readme_file.write_text(readme_content, encoding="utf-8")
@@ -402,23 +403,23 @@ async def _my_helper_method(self, data: Any) -> Any:
def main():
parser = argparse.ArgumentParser(
- description='Create a new Chronicle plugin with boilerplate structure',
+ description="Create a new Chronicle plugin with boilerplate structure",
formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog='''
+ epilog="""
Examples:
uv run python scripts/create_plugin.py my_awesome_plugin
uv run python scripts/create_plugin.py slack_notifier
uv run python scripts/create_plugin.py todo_extractor --force
- '''
+ """,
)
parser.add_argument(
- 'plugin_name',
- help='Plugin name in snake_case (e.g., my_awesome_plugin)'
+ "plugin_name", help="Plugin name in snake_case (e.g., my_awesome_plugin)"
)
parser.add_argument(
- '--force', '-f',
- action='store_true',
- help='Overwrite existing plugin if it exists'
+ "--force",
+ "-f",
+ action="store_true",
+ help="Overwrite existing plugin if it exists",
)
args = parser.parse_args()
@@ -433,5 +434,5 @@ def main():
sys.exit(1)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/backends/advanced/scripts/delete_all_conversations_api.py b/backends/advanced/scripts/delete_all_conversations_api.py
index 2c47f263..25cd2b3b 100755
--- a/backends/advanced/scripts/delete_all_conversations_api.py
+++ b/backends/advanced/scripts/delete_all_conversations_api.py
@@ -4,11 +4,12 @@
This uses the proper API authentication and endpoints.
"""
+import argparse
import asyncio
import os
import sys
-import argparse
from pathlib import Path
+
import aiohttp
from dotenv import load_dotenv
@@ -20,104 +21,100 @@ async def get_auth_token():
"""Get admin authentication token."""
admin_email = os.getenv("ADMIN_EMAIL")
admin_password = os.getenv("ADMIN_PASSWORD")
-
+
if not admin_email or not admin_password:
print("Error: ADMIN_EMAIL and ADMIN_PASSWORD must be set in .env file")
sys.exit(1)
-
+
base_url = "http://localhost:8000"
-
+
async with aiohttp.ClientSession() as session:
# Login to get token
- login_data = {
- "username": admin_email,
- "password": admin_password
- }
-
+ login_data = {"username": admin_email, "password": admin_password}
+
async with session.post(
- f"{base_url}/auth/jwt/login",
- data=login_data
+ f"{base_url}/auth/jwt/login", data=login_data
) as response:
if response.status != 200:
print(f"Failed to login: {response.status}")
text = await response.text()
print(f"Response: {text}")
sys.exit(1)
-
+
result = await response.json()
return result["access_token"]
async def delete_all_conversations(skip_prompt=False):
"""Delete all conversations using the API."""
-
+
base_url = "http://localhost:8000"
-
+
# Get auth token
print("Getting admin authentication token...")
token = await get_auth_token()
-
- headers = {
- "Authorization": f"Bearer {token}"
- }
-
+
+ headers = {"Authorization": f"Bearer {token}"}
+
async with aiohttp.ClientSession() as session:
# First, get all conversations
print("Fetching all conversations...")
async with session.get(
- f"{base_url}/api/conversations",
- headers=headers
+ f"{base_url}/api/conversations", headers=headers
) as response:
if response.status != 200:
print(f"Failed to fetch conversations: {response.status}")
text = await response.text()
print(f"Response: {text}")
return
-
+
data = await response.json()
-
+
# Extract conversations from nested structure
conversations_dict = data.get("conversations", {})
conversations = []
for client_id, client_conversations in conversations_dict.items():
conversations.extend(client_conversations)
-
+
print(f"Found {len(conversations)} conversations")
-
+
if len(conversations) == 0:
print("No conversations to delete")
return
-
+
# Confirm deletion unless --yes flag is used
if not skip_prompt:
- response = input(f"Are you sure you want to delete ALL {len(conversations)} conversations? (yes/no): ")
+ response = input(
+ f"Are you sure you want to delete ALL {len(conversations)} conversations? (yes/no): "
+ )
if response.lower() != "yes":
print("Deletion cancelled")
return
-
+
# Delete each conversation
deleted_count = 0
failed_count = 0
-
+
for conv in conversations:
audio_uuid = conv.get("audio_uuid")
if not audio_uuid:
print(f"Skipping conversation without audio_uuid: {conv.get('_id')}")
continue
-
+
# Delete the conversation
async with session.delete(
- f"{base_url}/api/conversations/{audio_uuid}",
- headers=headers
+ f"{base_url}/api/conversations/{audio_uuid}", headers=headers
) as response:
if response.status == 200:
deleted_count += 1
- print(f"Deleted conversation {audio_uuid} ({deleted_count}/{len(conversations)})")
+ print(
+ f"Deleted conversation {audio_uuid} ({deleted_count}/{len(conversations)})"
+ )
else:
failed_count += 1
text = await response.text()
print(f"Failed to delete {audio_uuid}: {response.status} - {text}")
-
+
print(f"\nDeletion complete:")
print(f" Successfully deleted: {deleted_count}")
print(f" Failed: {failed_count}")
@@ -125,8 +122,10 @@ async def delete_all_conversations(skip_prompt=False):
if __name__ == "__main__":
# Parse command line arguments
- parser = argparse.ArgumentParser(description='Delete all conversations')
- parser.add_argument('--yes', '-y', action='store_true', help='Skip confirmation prompt')
+ parser = argparse.ArgumentParser(description="Delete all conversations")
+ parser.add_argument(
+ "--yes", "-y", action="store_true", help="Skip confirmation prompt"
+ )
args = parser.parse_args()
-
- asyncio.run(delete_all_conversations(skip_prompt=args.yes))
\ No newline at end of file
+
+ asyncio.run(delete_all_conversations(skip_prompt=args.yes))
diff --git a/backends/advanced/scripts/laptop_client.py b/backends/advanced/scripts/laptop_client.py
index a0047f3b..a848469f 100644
--- a/backends/advanced/scripts/laptop_client.py
+++ b/backends/advanced/scripts/laptop_client.py
@@ -21,7 +21,11 @@
def build_websocket_uri(
- host: str, port: int, endpoint: str, token: str | None = None, device_name: str = "laptop"
+ host: str,
+ port: int,
+ endpoint: str,
+ token: str | None = None,
+ device_name: str = "laptop",
) -> str:
"""Build WebSocket URI with JWT token authentication."""
base_uri = f"ws://{host}:{port}{endpoint}"
@@ -36,7 +40,9 @@ def build_websocket_uri(
return base_uri
-async def authenticate_with_credentials(host: str, port: int, username: str, password: str) -> str:
+async def authenticate_with_credentials(
+ host: str, port: int, username: str, password: str
+) -> str:
"""Authenticate with username/password and return JWT token."""
auth_url = f"http://{host}:{port}/auth/jwt/login"
@@ -58,7 +64,9 @@ async def authenticate_with_credentials(host: str, port: int, username: str, pas
raise Exception("No access token received from server")
elif response.status == 400:
error_detail = await response.text()
- raise Exception(f"Authentication failed: Invalid credentials - {error_detail}")
+ raise Exception(
+ f"Authentication failed: Invalid credentials - {error_detail}"
+ )
else:
error_detail = await response.text()
raise Exception(
@@ -96,29 +104,31 @@ def validate_auth_args(args):
async def send_wyoming_event(websocket, wyoming_event):
"""Send a Wyoming protocol event over WebSocket.
-
+
Based on how the backend processes Wyoming events, they expect:
1. JSON header line ending with \n
2. Optional binary payload if payload_length > 0
-
+
This replicates Wyoming's async_write_event behavior for WebSocket transport.
"""
# Get the event data from Wyoming event
event_data = wyoming_event.event()
-
+
# Build event dict like Wyoming's async_write_event does
event_dict = event_data.to_dict()
event_dict["version"] = "1.0.0" # Wyoming adds version
-
+
# Add payload_length if payload exists (critical for audio chunks!)
if event_data.payload:
event_dict["payload_length"] = len(event_data.payload)
-
+
# Send JSON header
- json_header = json.dumps(event_dict) + '\n'
+ json_header = json.dumps(event_dict) + "\n"
await websocket.send(json_header)
- logger.debug(f"Sent Wyoming event: {event_data.type} (payload_length: {event_dict.get('payload_length', 0)})")
-
+ logger.debug(
+ f"Sent Wyoming event: {event_data.type} (payload_length: {event_dict.get('payload_length', 0)})"
+ )
+
# Send binary payload if exists
if event_data.payload:
await websocket.send(event_data.payload)
@@ -131,11 +141,17 @@ async def main():
description="Laptop audio client for OMI backend with dual authentication modes"
)
parser.add_argument("--host", default=DEFAULT_HOST, help="WebSocket server host")
- parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="WebSocket server port")
- parser.add_argument("--endpoint", default=DEFAULT_ENDPOINT, help="WebSocket endpoint")
+ parser.add_argument(
+ "--port", type=int, default=DEFAULT_PORT, help="WebSocket server port"
+ )
+ parser.add_argument(
+ "--endpoint", default=DEFAULT_ENDPOINT, help="WebSocket endpoint"
+ )
# Authentication options (mutually exclusive)
- auth_group = parser.add_argument_group("authentication", "Choose one authentication method")
+ auth_group = parser.add_argument_group(
+ "authentication", "Choose one authentication method"
+ )
auth_group.add_argument("--token", help="JWT authentication token")
auth_group.add_argument("--username", help="Username for login authentication")
auth_group.add_argument("--password", help="Password for login authentication")
@@ -174,7 +190,9 @@ async def main():
return
# Build WebSocket URI
- ws_uri = build_websocket_uri(args.host, args.port, args.endpoint, token, args.device_name)
+ ws_uri = build_websocket_uri(
+ args.host, args.port, args.endpoint, token, args.device_name
+ )
print(f"Connecting to {ws_uri}")
print(f"Using device name: {args.device_name}")
@@ -190,10 +208,12 @@ async def send_audio():
audio_start = AudioStart(
rate=stream.sample_rate,
width=stream.sample_width,
- channels=stream.channels
+ channels=stream.channels,
)
await send_wyoming_event(websocket, audio_start)
- logger.info(f"Sent audio-start event (rate={stream.sample_rate}, width={stream.sample_width}, channels={stream.channels})")
+ logger.info(
+ f"Sent audio-start event (rate={stream.sample_rate}, width={stream.sample_width}, channels={stream.channels})"
+ )
while True:
try:
data = await stream.read()
@@ -203,18 +223,24 @@ async def send_audio():
audio=data.audio,
rate=stream.sample_rate,
width=stream.sample_width,
- channels=stream.channels
+ channels=stream.channels,
)
await send_wyoming_event(websocket, audio_chunk)
- logger.debug(f"Sent audio chunk: {len(data.audio)} bytes")
- await asyncio.sleep(0.01) # Small delay to prevent overwhelming
+ logger.debug(
+ f"Sent audio chunk: {len(data.audio)} bytes"
+ )
+ await asyncio.sleep(
+ 0.01
+ ) # Small delay to prevent overwhelming
except websockets.exceptions.ConnectionClosed:
- logger.info("WebSocket connection closed during audio sending")
+ logger.info(
+ "WebSocket connection closed during audio sending"
+ )
break
except Exception as e:
logger.error(f"Error sending audio: {e}")
break
-
+
except Exception as e:
logger.error(f"Error in audio session: {e}")
finally:
diff --git a/backends/advanced/src/advanced_omi_backend/app_config.py b/backends/advanced/src/advanced_omi_backend/app_config.py
index 5ed50618..d2f50d8e 100644
--- a/backends/advanced/src/advanced_omi_backend/app_config.py
+++ b/backends/advanced/src/advanced_omi_backend/app_config.py
@@ -57,7 +57,9 @@ def __init__(self):
f"β
Using {self.transcription_provider.name} transcription provider ({self.transcription_provider.mode})"
)
else:
-        logger.warning("⚠️ No transcription provider configured - speech-to-text will not be available")
+        logger.warning(
+            "⚠️ No transcription provider configured - speech-to-text will not be available"
+ )
# External Services Configuration
self.qdrant_base_url = os.getenv("QDRANT_BASE_URL", "qdrant")
@@ -73,7 +75,9 @@ def __init__(self):
# CORS Configuration
default_origins = "http://localhost:3000,http://localhost:3001,http://127.0.0.1:3000,http://127.0.0.1:3002"
self.cors_origins = os.getenv("CORS_ORIGINS", default_origins)
- self.allowed_origins = [origin.strip() for origin in self.cors_origins.split(",") if origin.strip()]
+ self.allowed_origins = [
+ origin.strip() for origin in self.cors_origins.split(",") if origin.strip()
+ ]
# Tailscale support
self.tailscale_regex = r"http://100\.\d{1,3}\.\d{1,3}\.\d{1,3}:3000"
@@ -105,15 +109,11 @@ def get_audio_chunk_dir() -> Path:
def get_mongo_collections():
"""Get MongoDB collections."""
return {
- 'users': app_config.users_col,
- 'speakers': app_config.speakers_col,
+ "users": app_config.users_col,
+ "speakers": app_config.speakers_col,
}
def get_redis_config():
"""Get Redis configuration."""
- return {
- 'url': app_config.redis_url,
- 'encoding': "utf-8",
- 'decode_responses': False
- }
+ return {"url": app_config.redis_url, "encoding": "utf-8", "decode_responses": False}
diff --git a/backends/advanced/src/advanced_omi_backend/auth.py b/backends/advanced/src/advanced_omi_backend/auth.py
index c0d0a7b5..3bd6ea88 100644
--- a/backends/advanced/src/advanced_omi_backend/auth.py
+++ b/backends/advanced/src/advanced_omi_backend/auth.py
@@ -53,12 +53,13 @@ def _verify_configured(var_name: str, *, optional: bool = False) -> Optional[str
# Accepted token issuers - comma-separated list of services whose tokens we accept
# Default: "chronicle,ushadow" (accept tokens from both chronicle and ushadow)
ACCEPTED_ISSUERS = [
- iss.strip()
- for iss in os.getenv("ACCEPTED_TOKEN_ISSUERS", "chronicle,ushadow").split(",")
+ iss.strip()
+ for iss in os.getenv("ACCEPTED_TOKEN_ISSUERS", "chronicle,ushadow").split(",")
if iss.strip()
]
logger.info(f"Accepting tokens from issuers: {ACCEPTED_ISSUERS}")
+
class UserManager(BaseUserManager[User, PydanticObjectId]):
"""User manager with minimal customization for fastapi-users."""
@@ -108,8 +109,9 @@ async def get_user_manager(user_db=Depends(get_user_db)):
def get_jwt_strategy() -> JWTStrategy:
"""Get JWT strategy for token generation and validation."""
return JWTStrategy(
- secret=SECRET_KEY, lifetime_seconds=JWT_LIFETIME_SECONDS,
- token_audience=["fastapi-users:auth"] + ACCEPTED_ISSUERS
+ secret=SECRET_KEY,
+ lifetime_seconds=JWT_LIFETIME_SECONDS,
+ token_audience=["fastapi-users:auth"] + ACCEPTED_ISSUERS,
)
@@ -220,7 +222,9 @@ async def create_admin_user_if_needed():
existing_admin = await user_db.get_by_email(ADMIN_EMAIL)
if existing_admin:
- logger.debug(f"existing_admin.id = {existing_admin.id}, type = {type(existing_admin.id)}")
+ logger.debug(
+ f"existing_admin.id = {existing_admin.id}, type = {type(existing_admin.id)}"
+ )
logger.debug(f"str(existing_admin.id) = {str(existing_admin.id)}")
logger.debug(f"existing_admin.user_id = {existing_admin.user_id}")
logger.info(
@@ -258,25 +262,39 @@ async def websocket_auth(websocket, token: Optional[str] = None) -> Optional[Use
# Try JWT token from query parameter first
if token:
- logger.info(f"Attempting WebSocket auth with query token (first 20 chars): {token[:20]}...")
+ logger.info(
+ f"Attempting WebSocket auth with query token (first 20 chars): {token[:20]}..."
+ )
try:
user_db_gen = get_user_db()
user_db = await user_db_gen.__anext__()
user_manager = UserManager(user_db)
user = await strategy.read_token(token, user_manager)
if user and user.is_active:
- logger.info(f"WebSocket auth successful for user {user.user_id} using query token.")
+ logger.info(
+ f"WebSocket auth successful for user {user.user_id} using query token."
+ )
return user
else:
- logger.warning(f"Token validated but user inactive or not found: user={user}")
+ logger.warning(
+ f"Token validated but user inactive or not found: user={user}"
+ )
except Exception as e:
- logger.error(f"WebSocket auth with query token failed: {type(e).__name__}: {e}", exc_info=True)
+ logger.error(
+ f"WebSocket auth with query token failed: {type(e).__name__}: {e}",
+ exc_info=True,
+ )
# Try cookie authentication
logger.debug("Attempting WebSocket auth with cookie.")
try:
cookie_header = next(
- (v.decode() for k, v in websocket.headers.items() if k.lower() == b"cookie"), None
+ (
+ v.decode()
+ for k, v in websocket.headers.items()
+ if k.lower() == b"cookie"
+ ),
+ None,
)
if cookie_header:
match = re.search(r"fastapiusersauth=([^;]+)", cookie_header)
@@ -286,7 +304,9 @@ async def websocket_auth(websocket, token: Optional[str] = None) -> Optional[Use
user_manager = UserManager(user_db)
user = await strategy.read_token(match.group(1), user_manager)
if user and user.is_active:
- logger.info(f"WebSocket auth successful for user {user.user_id} using cookie.")
+ logger.info(
+ f"WebSocket auth successful for user {user.user_id} using cookie."
+ )
return user
except Exception as e:
logger.warning(f"WebSocket auth with cookie failed: {e}")
diff --git a/backends/advanced/src/advanced_omi_backend/chat_service.py b/backends/advanced/src/advanced_omi_backend/chat_service.py
index 46b734a9..d887156e 100644
--- a/backends/advanced/src/advanced_omi_backend/chat_service.py
+++ b/backends/advanced/src/advanced_omi_backend/chat_service.py
@@ -40,7 +40,7 @@
class ChatMessage:
"""Represents a chat message."""
-
+
def __init__(
self,
message_id: str,
@@ -91,7 +91,7 @@ def from_dict(cls, data: Dict) -> "ChatMessage":
class ChatSession:
"""Represents a chat session."""
-
+
def __init__(
self,
session_id: str,
@@ -152,11 +152,13 @@ async def _get_system_prompt(self) -> str:
"""
try:
reg = get_models_registry()
- if reg and hasattr(reg, 'chat'):
+ if reg and hasattr(reg, "chat"):
chat_config = reg.chat
- prompt = chat_config.get('system_prompt')
+ prompt = chat_config.get("system_prompt")
if prompt:
- logger.info(f"β
Loaded chat system prompt from config (length: {len(prompt)} chars)")
+ logger.info(
+ f"β
Loaded chat system prompt from config (length: {len(prompt)} chars)"
+ )
logger.debug(f"System prompt: {prompt[:100]}...")
return prompt
except Exception as e:
@@ -193,9 +195,15 @@ async def initialize(self):
self.messages_collection = self.db["chat_messages"]
# Create indexes for better performance
- await self.sessions_collection.create_index([("user_id", 1), ("updated_at", -1)])
- await self.messages_collection.create_index([("session_id", 1), ("timestamp", 1)])
- await self.messages_collection.create_index([("user_id", 1), ("timestamp", -1)])
+ await self.sessions_collection.create_index(
+ [("user_id", 1), ("updated_at", -1)]
+ )
+ await self.messages_collection.create_index(
+ [("session_id", 1), ("timestamp", 1)]
+ )
+ await self.messages_collection.create_index(
+ [("user_id", 1), ("timestamp", -1)]
+ )
# Initialize LLM client and memory service
self.llm_client = get_llm_client()
@@ -214,23 +222,25 @@ async def create_session(self, user_id: str, title: str = None) -> ChatSession:
await self.initialize()
session = ChatSession(
- session_id=str(uuid4()),
- user_id=user_id,
- title=title or "New Chat"
+ session_id=str(uuid4()), user_id=user_id, title=title or "New Chat"
)
await self.sessions_collection.insert_one(session.to_dict())
logger.info(f"Created new chat session {session.session_id} for user {user_id}")
return session
- async def get_user_sessions(self, user_id: str, limit: int = 50) -> List[ChatSession]:
+ async def get_user_sessions(
+ self, user_id: str, limit: int = 50
+ ) -> List[ChatSession]:
"""Get all chat sessions for a user."""
if not self._initialized:
await self.initialize()
- cursor = self.sessions_collection.find(
- {"user_id": user_id}
- ).sort("updated_at", -1).limit(limit)
+ cursor = (
+ self.sessions_collection.find({"user_id": user_id})
+ .sort("updated_at", -1)
+ .limit(limit)
+ )
sessions = []
async for doc in cursor:
@@ -243,10 +253,9 @@ async def get_session(self, session_id: str, user_id: str) -> Optional[ChatSessi
if not self._initialized:
await self.initialize()
- doc = await self.sessions_collection.find_one({
- "session_id": session_id,
- "user_id": user_id
- })
+ doc = await self.sessions_collection.find_one(
+ {"session_id": session_id, "user_id": user_id}
+ )
if doc:
return ChatSession.from_dict(doc)
@@ -258,16 +267,14 @@ async def delete_session(self, session_id: str, user_id: str) -> bool:
await self.initialize()
# Delete all messages in the session
- await self.messages_collection.delete_many({
- "session_id": session_id,
- "user_id": user_id
- })
+ await self.messages_collection.delete_many(
+ {"session_id": session_id, "user_id": user_id}
+ )
# Delete the session
- result = await self.sessions_collection.delete_one({
- "session_id": session_id,
- "user_id": user_id
- })
+ result = await self.sessions_collection.delete_one(
+ {"session_id": session_id, "user_id": user_id}
+ )
success = result.deleted_count > 0
if success:
@@ -281,10 +288,13 @@ async def get_session_messages(
if not self._initialized:
await self.initialize()
- cursor = self.messages_collection.find({
- "session_id": session_id,
- "user_id": user_id
- }).sort("timestamp", 1).limit(limit)
+ cursor = (
+ self.messages_collection.find(
+ {"session_id": session_id, "user_id": user_id}
+ )
+ .sort("timestamp", 1)
+ .limit(limit)
+ )
messages = []
async for doc in cursor:
@@ -299,10 +309,10 @@ async def add_message(self, message: ChatMessage) -> bool:
try:
await self.messages_collection.insert_one(message.to_dict())
-
+
# Update session timestamp and title if needed
update_data = {"updated_at": message.timestamp}
-
+
# Auto-generate title from first user message if session has default title
if message.role == "user":
session = await self.get_session(message.session_id, message.user_id)
@@ -315,35 +325,43 @@ async def add_message(self, message: ChatMessage) -> bool:
await self.sessions_collection.update_one(
{"session_id": message.session_id, "user_id": message.user_id},
- {"$set": update_data}
+ {"$set": update_data},
)
-
+
return True
except Exception as e:
logger.error(f"Failed to add message to session {message.session_id}: {e}")
return False
- async def get_relevant_memories(self, query: str, user_id: str) -> List[MemoryEntry]:
+ async def get_relevant_memories(
+ self, query: str, user_id: str
+ ) -> List[MemoryEntry]:
"""Get relevant memories for the user's query."""
try:
memories = await self.memory_service.search_memories(
- query=query,
- user_id=user_id,
- limit=MAX_MEMORY_CONTEXT
+ query=query, user_id=user_id, limit=MAX_MEMORY_CONTEXT
+ )
+ logger.info(
+ f"Retrieved {len(memories)} relevant memories for query: {query[:50]}..."
)
- logger.info(f"Retrieved {len(memories)} relevant memories for query: {query[:50]}...")
return memories
except Exception as e:
logger.error(f"Failed to retrieve memories for user {user_id}: {e}")
return []
async def format_conversation_context(
- self, session_id: str, user_id: str, current_message: str, include_obsidian_memory: bool = False
+ self,
+ session_id: str,
+ user_id: str,
+ current_message: str,
+ include_obsidian_memory: bool = False,
) -> Tuple[str, List[str]]:
"""Format conversation context with memory integration."""
# Get recent conversation history
- messages = await self.get_session_messages(session_id, user_id, MAX_CONVERSATION_HISTORY)
-
+ messages = await self.get_session_messages(
+ session_id, user_id, MAX_CONVERSATION_HISTORY
+ )
+
# Get relevant memories
memories = await self.get_relevant_memories(current_message, user_id)
memory_ids = [memory.id for memory in memories if memory.id]
@@ -364,14 +382,18 @@ async def format_conversation_context(
if include_obsidian_memory:
try:
obsidian_service = get_obsidian_service()
- obsidian_result = await obsidian_service.search_obsidian(current_message)
+ obsidian_result = await obsidian_service.search_obsidian(
+ current_message
+ )
obsidian_context = obsidian_result["results"]
if obsidian_context:
context_parts.append("# Relevant Obsidian Notes:")
for entry in obsidian_context:
context_parts.append(entry)
context_parts.append("")
- logger.info(f"Added {len(obsidian_context)} Obsidian notes to context")
+ logger.info(
+ f"Added {len(obsidian_context)} Obsidian notes to context"
+ )
except ObsidianSearchError as exc:
logger.error(
"Failed to get Obsidian context (%s stage): %s",
@@ -399,7 +421,11 @@ async def format_conversation_context(
return context, memory_ids
async def generate_response_stream(
- self, session_id: str, user_id: str, message_content: str, include_obsidian_memory: bool = False
+ self,
+ session_id: str,
+ user_id: str,
+ message_content: str,
+ include_obsidian_memory: bool = False,
) -> AsyncGenerator[Dict, None]:
"""Generate streaming response with memory context."""
if not self._initialized:
@@ -412,23 +438,23 @@ async def generate_response_stream(
session_id=session_id,
user_id=user_id,
role="user",
- content=message_content
+ content=message_content,
)
await self.add_message(user_message)
# Format context with memories
context, memory_ids = await self.format_conversation_context(
- session_id, user_id, message_content, include_obsidian_memory=include_obsidian_memory
+ session_id,
+ user_id,
+ message_content,
+ include_obsidian_memory=include_obsidian_memory,
)
# Send memory context used
yield {
"type": "memory_context",
- "data": {
- "memory_ids": memory_ids,
- "memory_count": len(memory_ids)
- },
- "timestamp": time.time()
+ "data": {"memory_ids": memory_ids, "memory_count": len(memory_ids)},
+ "timestamp": time.time(),
}
# Get system prompt from config
@@ -438,8 +464,10 @@ async def generate_response_stream(
full_prompt = f"{system_prompt}\n\n{context}"
# Generate streaming response
- logger.info(f"Generating response for session {session_id} with {len(memory_ids)} memories")
-
+ logger.info(
+ f"Generating response for session {session_id} with {len(memory_ids)} memories"
+ )
+
# Resolve chat operation temperature from config
chat_temp = None
registry = get_models_registry()
@@ -457,16 +485,16 @@ async def generate_response_stream(
# Simulate streaming by yielding chunks
words = response_content.split()
current_text = ""
-
+
for i, word in enumerate(words):
current_text += word + " "
-
+
# Yield every few words to simulate streaming
if i % 3 == 0 or i == len(words) - 1:
yield {
"type": "token",
"data": current_text.strip(),
- "timestamp": time.time()
+ "timestamp": time.time(),
}
await asyncio.sleep(0.05) # Small delay for realistic streaming
@@ -477,7 +505,7 @@ async def generate_response_stream(
user_id=user_id,
role="assistant",
content=response_content.strip(),
- memories_used=memory_ids
+ memories_used=memory_ids,
)
await self.add_message(assistant_message)
@@ -486,20 +514,18 @@ async def generate_response_stream(
"type": "complete",
"data": {
"message_id": assistant_message.message_id,
- "memories_used": memory_ids
+ "memories_used": memory_ids,
},
- "timestamp": time.time()
+ "timestamp": time.time(),
}
except Exception as e:
logger.error(f"Error generating response for session {session_id}: {e}")
- yield {
- "type": "error",
- "data": {"error": str(e)},
- "timestamp": time.time()
- }
+ yield {"type": "error", "data": {"error": str(e)}, "timestamp": time.time()}
- async def update_session_title(self, session_id: str, user_id: str, title: str) -> bool:
+ async def update_session_title(
+ self, session_id: str, user_id: str, title: str
+ ) -> bool:
"""Update a session's title."""
if not self._initialized:
await self.initialize()
@@ -507,7 +533,7 @@ async def update_session_title(self, session_id: str, user_id: str, title: str)
try:
result = await self.sessions_collection.update_one(
{"session_id": session_id, "user_id": user_id},
- {"$set": {"title": title, "updated_at": datetime.utcnow()}}
+ {"$set": {"title": title, "updated_at": datetime.utcnow()}},
)
return result.modified_count > 0
except Exception as e:
@@ -521,33 +547,38 @@ async def get_chat_statistics(self, user_id: str) -> Dict:
try:
# Count sessions
- session_count = await self.sessions_collection.count_documents({"user_id": user_id})
-
+ session_count = await self.sessions_collection.count_documents(
+ {"user_id": user_id}
+ )
+
# Count messages
- message_count = await self.messages_collection.count_documents({"user_id": user_id})
-
+ message_count = await self.messages_collection.count_documents(
+ {"user_id": user_id}
+ )
+
# Get most recent session
latest_session = await self.sessions_collection.find_one(
- {"user_id": user_id},
- sort=[("updated_at", -1)]
+ {"user_id": user_id}, sort=[("updated_at", -1)]
)
-
+
return {
"total_sessions": session_count,
"total_messages": message_count,
- "last_chat": latest_session["updated_at"] if latest_session else None
+ "last_chat": latest_session["updated_at"] if latest_session else None,
}
except Exception as e:
logger.error(f"Failed to get chat statistics for user {user_id}: {e}")
return {"total_sessions": 0, "total_messages": 0, "last_chat": None}
- async def extract_memories_from_session(self, session_id: str, user_id: str) -> Tuple[bool, List[str], int]:
+ async def extract_memories_from_session(
+ self, session_id: str, user_id: str
+ ) -> Tuple[bool, List[str], int]:
"""Extract and store memories from a chat session.
-
+
Args:
session_id: ID of the chat session to extract memories from
user_id: User ID for authorization and memory scoping
-
+
Returns:
Tuple of (success: bool, memory_ids: List[str], memory_count: int)
"""
@@ -556,20 +587,23 @@ async def extract_memories_from_session(self, session_id: str, user_id: str) ->
try:
# Verify session belongs to user
- session = await self.sessions_collection.find_one({
- "session_id": session_id,
- "user_id": user_id
- })
-
+ session = await self.sessions_collection.find_one(
+ {"session_id": session_id, "user_id": user_id}
+ )
+
if not session:
logger.error(f"Session {session_id} not found for user {user_id}")
return False, [], 0
# Get all messages from the session
messages = await self.get_session_messages(session_id, user_id)
-
- if not messages or len(messages) < 2: # Need at least user + assistant message
- logger.info(f"Not enough messages in session {session_id} for memory extraction")
+
+ if (
+ not messages or len(messages) < 2
+ ): # Need at least user + assistant message
+ logger.info(
+ f"Not enough messages in session {session_id} for memory extraction"
+ )
return True, [], 0
# Format messages as a transcript
@@ -577,12 +611,12 @@ async def extract_memories_from_session(self, session_id: str, user_id: str) ->
for message in messages:
role = "User" if message.role == "user" else "Assistant"
transcript_parts.append(f"{role}: {message.content}")
-
+
transcript = "\n".join(transcript_parts)
-
+
# Get user email for memory service
user_email = session.get("user_email", f"user_{user_id}")
-
+
# Extract memories using the memory service
success, memory_ids = await self.memory_service.add_memory(
transcript=transcript,
@@ -590,16 +624,20 @@ async def extract_memories_from_session(self, session_id: str, user_id: str) ->
source_id=f"chat_{session_id}",
user_id=user_id,
user_email=user_email,
- allow_update=True # Allow deduplication and updates
+ allow_update=True, # Allow deduplication and updates
)
-
+
if success:
-                logger.info(f"✅ Extracted {len(memory_ids)} memories from chat session {session_id}")
+                logger.info(
+                    f"✅ Extracted {len(memory_ids)} memories from chat session {session_id}"
+                )
return True, memory_ids, len(memory_ids)
else:
-                logger.error(f"❌ Failed to extract memories from chat session {session_id}")
+                logger.error(
+                    f"❌ Failed to extract memories from chat session {session_id}"
+                )
+ )
return False, [], 0
-
+
except Exception as e:
logger.error(f"Failed to extract memories from session {session_id}: {e}")
return False, [], 0
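For reference, a minimal caller sketch for the reshaped extract_memories_from_session signature above, assuming an already-initialized chat service instance (the variable names are illustrative, not taken from this patch):

async def report_extraction(chat_service, session_id: str, user_id: str) -> str:
    # Unpacks the (success, memory_ids, count) tuple described in the docstring above.
    success, memory_ids, count = await chat_service.extract_memories_from_session(
        session_id, user_id
    )
    if not success:
        return "memory extraction failed"
    return f"extracted {count} memories ({', '.join(memory_ids) or 'no new ids'})"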
diff --git a/backends/advanced/src/advanced_omi_backend/client.py b/backends/advanced/src/advanced_omi_backend/client.py
index 79ee2957..30eed2bc 100644
--- a/backends/advanced/src/advanced_omi_backend/client.py
+++ b/backends/advanced/src/advanced_omi_backend/client.py
@@ -18,7 +18,9 @@
audio_logger = logging.getLogger("audio_processing")
# Configuration constants
-NEW_CONVERSATION_TIMEOUT_MINUTES = float(os.getenv("NEW_CONVERSATION_TIMEOUT_MINUTES", "1.5"))
+NEW_CONVERSATION_TIMEOUT_MINUTES = float(
+ os.getenv("NEW_CONVERSATION_TIMEOUT_MINUTES", "1.5")
+)
class ClientState:
@@ -99,7 +101,9 @@ def record_speech_end(self, audio_uuid: str, timestamp: float):
f"(duration: {duration:.3f}s)"
)
else:
- audio_logger.warning(f"Speech end recorded for {audio_uuid} but no start time found")
+ audio_logger.warning(
+ f"Speech end recorded for {audio_uuid} but no start time found"
+ )
def update_transcript_received(self):
"""Update timestamp when transcript is received (for timeout detection)."""
diff --git a/backends/advanced/src/advanced_omi_backend/client_manager.py b/backends/advanced/src/advanced_omi_backend/client_manager.py
index 68fd6ef8..2e90a4de 100644
--- a/backends/advanced/src/advanced_omi_backend/client_manager.py
+++ b/backends/advanced/src/advanced_omi_backend/client_manager.py
@@ -98,7 +98,9 @@ def get_client_count(self) -> int:
"""
return len(self._active_clients)
- def create_client(self, client_id: str, chunk_dir, user_id: str, user_email: Optional[str] = None) -> "ClientState":
+ def create_client(
+ self, client_id: str, chunk_dir, user_id: str, user_email: Optional[str] = None
+ ) -> "ClientState":
"""
Atomically create and register a new client.
@@ -327,12 +329,16 @@ def unregister_client_user_mapping(client_id: str):
if client_id in _client_to_user_mapping:
user_id = _client_to_user_mapping.pop(client_id)
        logger.info(f"✅ Unregistered active client {client_id} from user {user_id}")
- logger.info(f"π Active client mappings: {len(_client_to_user_mapping)} remaining")
+ logger.info(
+ f"π Active client mappings: {len(_client_to_user_mapping)} remaining"
+ )
else:
        logger.warning(f"⚠️ Attempted to unregister non-existent client {client_id}")
-async def track_client_user_relationship_async(client_id: str, user_id: str, ttl: int = 86400):
+async def track_client_user_relationship_async(
+ client_id: str, user_id: str, ttl: int = 86400
+):
"""
Track that a client belongs to a user (async, writes to Redis for cross-container support).
@@ -346,11 +352,15 @@ async def track_client_user_relationship_async(client_id: str, user_id: str, ttl
if _redis_client:
try:
await _redis_client.setex(f"client:owner:{client_id}", ttl, user_id)
-            logger.debug(f"✅ Tracked client {client_id} → user {user_id} in Redis (TTL: {ttl}s)")
+            logger.debug(
+                f"✅ Tracked client {client_id} → user {user_id} in Redis (TTL: {ttl}s)"
+            )
except Exception as e:
logger.warning(f"Failed to track client in Redis: {e}")
else:
- logger.debug(f"Tracked client {client_id} relationship to user {user_id} (in-memory only)")
+ logger.debug(
+ f"Tracked client {client_id} relationship to user {user_id} (in-memory only)"
+ )
def track_client_user_relationship(client_id: str, user_id: str):
@@ -424,7 +434,9 @@ def get_user_clients_active(user_id: str) -> list[str]:
if mapped_user_id == user_id
]
- logger.debug(f"π Found {len(user_clients)} active clients for user {user_id}: {user_clients}")
+ logger.debug(
+ f"π Found {len(user_clients)} active clients for user {user_id}: {user_clients}"
+ )
return user_clients
@@ -515,7 +527,9 @@ def generate_client_id(user: "User", device_name: Optional[str] = None) -> str:
if device_name:
# Sanitize device name: lowercase, alphanumeric + hyphens only, max 10 chars
- sanitized_device = "".join(c for c in device_name.lower() if c.isalnum() or c == "-")[:10]
+ sanitized_device = "".join(
+ c for c in device_name.lower() if c.isalnum() or c == "-"
+ )[:10]
base_client_id = f"{user_id_suffix}-{sanitized_device}"
# Check for existing client IDs in database
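As a quick illustration of the device-name sanitisation reformatted above (the user-id suffix value is a made-up example, not taken from this patch):

device_name = "Pixel_8 Pro!"
sanitized_device = "".join(
    c for c in device_name.lower() if c.isalnum() or c == "-"
)[:10]
assert sanitized_device == "pixel8pro"  # underscores, spaces and punctuation dropped, capped at 10 chars
base_client_id = f"user1234-{sanitized_device}"  # hypothetical user-id suffix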
diff --git a/backends/advanced/src/advanced_omi_backend/clients/gdrive_audio_client.py b/backends/advanced/src/advanced_omi_backend/clients/gdrive_audio_client.py
index 9d93d884..39b63f87 100644
--- a/backends/advanced/src/advanced_omi_backend/clients/gdrive_audio_client.py
+++ b/backends/advanced/src/advanced_omi_backend/clients/gdrive_audio_client.py
@@ -7,6 +7,7 @@
_drive_client_cache = None
+
def get_google_drive_client():
"""Singleton Google Drive client."""
global _drive_client_cache
@@ -22,8 +23,7 @@ def get_google_drive_client():
)
creds = Credentials.from_service_account_file(
- config.gdrive_credentials_path,
- scopes=config.gdrive_scopes
+ config.gdrive_credentials_path, scopes=config.gdrive_scopes
)
_drive_client_cache = build("drive", "v3", credentials=creds)
diff --git a/backends/advanced/src/advanced_omi_backend/config.py b/backends/advanced/src/advanced_omi_backend/config.py
index 4286492a..4d96eb6f 100644
--- a/backends/advanced/src/advanced_omi_backend/config.py
+++ b/backends/advanced/src/advanced_omi_backend/config.py
@@ -19,9 +19,7 @@
load_config,
)
from advanced_omi_backend.config_loader import reload_config as reload_omegaconf_config
-from advanced_omi_backend.config_loader import (
- save_config_section,
-)
+from advanced_omi_backend.config_loader import save_config_section
logger = logging.getLogger(__name__)
@@ -34,6 +32,7 @@
# Configuration Functions (OmegaConf-based)
# ============================================================================
+
def get_config_yml_path() -> Path:
"""
Get path to config.yml file.
@@ -43,6 +42,7 @@ def get_config_yml_path() -> Path:
"""
return get_config_dir() / "config.yml"
+
def get_config(force_reload: bool = False) -> dict:
"""
Get merged configuration using OmegaConf.
@@ -68,6 +68,7 @@ def reload_config():
# Diarization Settings (OmegaConf-based)
# ============================================================================
+
def get_diarization_settings() -> dict:
"""
Get diarization settings using OmegaConf.
@@ -75,7 +76,7 @@ def get_diarization_settings() -> dict:
Returns:
Dict with diarization configuration (resolved from YAML + env vars)
"""
- cfg = get_backend_config('diarization')
+ cfg = get_backend_config("diarization")
return OmegaConf.to_container(cfg, resolve=True)
@@ -89,16 +90,18 @@ def save_diarization_settings(settings: dict) -> bool:
Returns:
True if saved successfully, False otherwise
"""
- return save_config_section('backend.diarization', settings)
+ return save_config_section("backend.diarization", settings)
# ============================================================================
# Cleanup Settings (OmegaConf-based)
# ============================================================================
+
@dataclass
class CleanupSettings:
"""Cleanup configuration for soft-deleted conversations."""
+
auto_cleanup_enabled: bool = False
retention_days: int = 30
@@ -110,7 +113,7 @@ def get_cleanup_settings() -> dict:
Returns:
Dict with auto_cleanup_enabled and retention_days
"""
- cfg = get_backend_config('cleanup')
+ cfg = get_backend_config("cleanup")
return OmegaConf.to_container(cfg, resolve=True)
@@ -125,13 +128,15 @@ def save_cleanup_settings(settings: CleanupSettings) -> bool:
True if saved successfully, False otherwise
"""
from dataclasses import asdict
- return save_config_section('backend.cleanup', asdict(settings))
+
+ return save_config_section("backend.cleanup", asdict(settings))
# ============================================================================
# Speech Detection Settings (OmegaConf-based)
# ============================================================================
+
def get_speech_detection_settings() -> dict:
"""
Get speech detection settings using OmegaConf.
@@ -139,7 +144,7 @@ def get_speech_detection_settings() -> dict:
Returns:
Dict with min_words, min_confidence, min_duration
"""
- cfg = get_backend_config('speech_detection')
+ cfg = get_backend_config("speech_detection")
return OmegaConf.to_container(cfg, resolve=True)
@@ -147,6 +152,7 @@ def get_speech_detection_settings() -> dict:
# Conversation Stop Settings (OmegaConf-based)
# ============================================================================
+
def get_conversation_stop_settings() -> dict:
"""
Get conversation stop settings using OmegaConf.
@@ -154,12 +160,14 @@ def get_conversation_stop_settings() -> dict:
Returns:
Dict with transcription_buffer_seconds, speech_inactivity_threshold
"""
- cfg = get_backend_config('conversation_stop')
+ cfg = get_backend_config("conversation_stop")
settings = OmegaConf.to_container(cfg, resolve=True)
# Add min_word_confidence from speech_detection for backward compatibility
- speech_cfg = get_backend_config('speech_detection')
- settings['min_word_confidence'] = OmegaConf.to_container(speech_cfg, resolve=True).get('min_confidence', 0.7)
+ speech_cfg = get_backend_config("speech_detection")
+ settings["min_word_confidence"] = OmegaConf.to_container(
+ speech_cfg, resolve=True
+ ).get("min_confidence", 0.7)
return settings
@@ -168,6 +176,7 @@ def get_conversation_stop_settings() -> dict:
# Audio Storage Settings (OmegaConf-based)
# ============================================================================
+
def get_audio_storage_settings() -> dict:
"""
Get audio storage settings using OmegaConf.
@@ -175,7 +184,7 @@ def get_audio_storage_settings() -> dict:
Returns:
Dict with audio_base_path, audio_chunks_path
"""
- cfg = get_backend_config('audio_storage')
+ cfg = get_backend_config("audio_storage")
return OmegaConf.to_container(cfg, resolve=True)
@@ -183,6 +192,7 @@ def get_audio_storage_settings() -> dict:
# Transcription Job Timeout (OmegaConf-based)
# ============================================================================
+
def get_transcription_job_timeout() -> int:
"""
Get transcription job timeout in seconds from config.
@@ -190,15 +200,16 @@ def get_transcription_job_timeout() -> int:
Returns:
Job timeout in seconds (default 900 = 15 minutes)
"""
- cfg = get_backend_config('transcription')
+ cfg = get_backend_config("transcription")
settings = OmegaConf.to_container(cfg, resolve=True) if cfg else {}
- return int(settings.get('job_timeout_seconds', 900))
+ return int(settings.get("job_timeout_seconds", 900))
# ============================================================================
# Miscellaneous Settings (OmegaConf-based)
# ============================================================================
+
def get_misc_settings() -> dict:
"""
Get miscellaneous configuration settings using OmegaConf.
@@ -207,23 +218,37 @@ def get_misc_settings() -> dict:
Dict with always_persist_enabled and use_provider_segments
"""
# Get audio settings for always_persist_enabled
- audio_cfg = get_backend_config('audio')
- audio_settings = OmegaConf.to_container(audio_cfg, resolve=True) if audio_cfg else {}
+ audio_cfg = get_backend_config("audio")
+ audio_settings = (
+ OmegaConf.to_container(audio_cfg, resolve=True) if audio_cfg else {}
+ )
# Get transcription settings for use_provider_segments
- transcription_cfg = get_backend_config('transcription')
- transcription_settings = OmegaConf.to_container(transcription_cfg, resolve=True) if transcription_cfg else {}
+ transcription_cfg = get_backend_config("transcription")
+ transcription_settings = (
+ OmegaConf.to_container(transcription_cfg, resolve=True)
+ if transcription_cfg
+ else {}
+ )
# Get speaker recognition settings for per_segment_speaker_id
- speaker_cfg = get_backend_config('speaker_recognition')
- speaker_settings = OmegaConf.to_container(speaker_cfg, resolve=True) if speaker_cfg else {}
+ speaker_cfg = get_backend_config("speaker_recognition")
+ speaker_settings = (
+ OmegaConf.to_container(speaker_cfg, resolve=True) if speaker_cfg else {}
+ )
return {
- 'always_persist_enabled': audio_settings.get('always_persist_enabled', False),
- 'use_provider_segments': transcription_settings.get('use_provider_segments', False),
- 'per_segment_speaker_id': speaker_settings.get('per_segment_speaker_id', False),
- 'transcription_job_timeout_seconds': int(transcription_settings.get('job_timeout_seconds', 900)),
- 'always_batch_retranscribe': transcription_settings.get('always_batch_retranscribe', False),
+ "always_persist_enabled": audio_settings.get("always_persist_enabled", False),
+ "use_provider_segments": transcription_settings.get(
+ "use_provider_segments", False
+ ),
+ "per_segment_speaker_id": speaker_settings.get("per_segment_speaker_id", False),
+ "transcription_job_timeout_seconds": int(
+ transcription_settings.get("job_timeout_seconds", 900)
+ ),
+ "always_batch_retranscribe": transcription_settings.get(
+ "always_batch_retranscribe", False
+ ),
}
@@ -240,33 +265,41 @@ def save_misc_settings(settings: dict) -> bool:
success = True
# Save audio settings if always_persist_enabled is provided
- if 'always_persist_enabled' in settings:
- audio_settings = {'always_persist_enabled': settings['always_persist_enabled']}
- if not save_config_section('backend.audio', audio_settings):
+ if "always_persist_enabled" in settings:
+ audio_settings = {"always_persist_enabled": settings["always_persist_enabled"]}
+ if not save_config_section("backend.audio", audio_settings):
success = False
# Save transcription settings if use_provider_segments is provided
- if 'use_provider_segments' in settings:
- transcription_settings = {'use_provider_segments': settings['use_provider_segments']}
- if not save_config_section('backend.transcription', transcription_settings):
+ if "use_provider_segments" in settings:
+ transcription_settings = {
+ "use_provider_segments": settings["use_provider_segments"]
+ }
+ if not save_config_section("backend.transcription", transcription_settings):
success = False
# Save speaker recognition settings if per_segment_speaker_id is provided
- if 'per_segment_speaker_id' in settings:
- speaker_settings = {'per_segment_speaker_id': settings['per_segment_speaker_id']}
- if not save_config_section('backend.speaker_recognition', speaker_settings):
+ if "per_segment_speaker_id" in settings:
+ speaker_settings = {
+ "per_segment_speaker_id": settings["per_segment_speaker_id"]
+ }
+ if not save_config_section("backend.speaker_recognition", speaker_settings):
success = False
# Save transcription job timeout if provided
- if 'transcription_job_timeout_seconds' in settings:
- timeout_settings = {'job_timeout_seconds': settings['transcription_job_timeout_seconds']}
- if not save_config_section('backend.transcription', timeout_settings):
+ if "transcription_job_timeout_seconds" in settings:
+ timeout_settings = {
+ "job_timeout_seconds": settings["transcription_job_timeout_seconds"]
+ }
+ if not save_config_section("backend.transcription", timeout_settings):
success = False
# Save always_batch_retranscribe if provided
- if 'always_batch_retranscribe' in settings:
- batch_settings = {'always_batch_retranscribe': settings['always_batch_retranscribe']}
- if not save_config_section('backend.transcription', batch_settings):
+ if "always_batch_retranscribe" in settings:
+ batch_settings = {
+ "always_batch_retranscribe": settings["always_batch_retranscribe"]
+ }
+ if not save_config_section("backend.transcription", batch_settings):
success = False
- return success
\ No newline at end of file
+ return success
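For context, a small sketch of how the OmegaConf-backed helpers above fit together; it only touches keys that appear in this file and assumes the backend config has already been loaded:

from advanced_omi_backend.config import get_misc_settings, save_misc_settings

settings = get_misc_settings()
print(settings["use_provider_segments"], settings["transcription_job_timeout_seconds"])

# Only the keys present in the dict are written back; each maps to its own
# backend.* section via save_config_section, as shown in save_misc_settings above.
save_misc_settings({"always_persist_enabled": True})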
diff --git a/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py
index ba434229..f94ff320 100644
--- a/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py
+++ b/backends/advanced/src/advanced_omi_backend/controllers/audio_controller.py
@@ -12,6 +12,9 @@
import time
import uuid
+from fastapi import UploadFile
+from fastapi.responses import JSONResponse
+
from advanced_omi_backend.config import get_transcription_job_timeout
from advanced_omi_backend.controllers.queue_controller import (
JOB_RESULT_TTL,
@@ -29,11 +32,7 @@
convert_any_to_wav,
validate_and_prepare_audio,
)
-from advanced_omi_backend.workers.transcription_jobs import (
- transcribe_full_audio_job,
-)
-from fastapi import UploadFile
-from fastapi.responses import JSONResponse
+from advanced_omi_backend.workers.transcription_jobs import transcribe_full_audio_job
logger = logging.getLogger(__name__)
audio_logger = logging.getLogger("audio_processing")
@@ -50,7 +49,7 @@ async def upload_and_process_audio_files(
user: User,
files: list[UploadFile],
device_name: str = "upload",
- source: str = "upload"
+ source: str = "upload",
) -> dict:
"""
Upload audio files and process them directly.
@@ -80,11 +79,13 @@ async def upload_and_process_audio_files(
_, ext = os.path.splitext(filename.lower())
if not ext or ext not in SUPPORTED_AUDIO_EXTENSIONS:
supported = ", ".join(sorted(SUPPORTED_AUDIO_EXTENSIONS))
- processed_files.append({
- "filename": filename,
- "status": "error",
- "error": f"Unsupported format '{ext}'. Supported: {supported}",
- })
+ processed_files.append(
+ {
+ "filename": filename,
+ "status": "error",
+ "error": f"Unsupported format '{ext}'. Supported: {supported}",
+ }
+ )
continue
is_video_source = ext in VIDEO_EXTENSIONS
@@ -101,37 +102,47 @@ async def upload_and_process_audio_files(
try:
content = await convert_any_to_wav(content, ext)
except AudioValidationError as e:
- processed_files.append({
- "filename": filename,
- "status": "error",
- "error": str(e),
- })
+ processed_files.append(
+ {
+ "filename": filename,
+ "status": "error",
+ "error": str(e),
+ }
+ )
continue
# Track external source for deduplication (Google Drive, etc.)
external_source_id = None
external_source_type = None
if source == "gdrive":
- external_source_id = getattr(file, "file_id", None) or getattr(file, "audio_uuid", None)
+ external_source_id = getattr(file, "file_id", None) or getattr(
+ file, "audio_uuid", None
+ )
external_source_type = "gdrive"
if not external_source_id:
- audio_logger.warning(f"Missing file_id for gdrive file: {filename}")
+ audio_logger.warning(
+ f"Missing file_id for gdrive file: {filename}"
+ )
timestamp = int(time.time() * 1000)
# Validate and prepare audio (read format from WAV file)
try:
- audio_data, sample_rate, sample_width, channels, duration = await validate_and_prepare_audio(
- audio_data=content,
- expected_sample_rate=16000, # Expecting 16kHz
- convert_to_mono=True, # Convert stereo to mono
- auto_resample=True # Auto-resample if sample rate doesn't match
+ audio_data, sample_rate, sample_width, channels, duration = (
+ await validate_and_prepare_audio(
+ audio_data=content,
+ expected_sample_rate=16000, # Expecting 16kHz
+ convert_to_mono=True, # Convert stereo to mono
+ auto_resample=True, # Auto-resample if sample rate doesn't match
+ )
)
except AudioValidationError as e:
- processed_files.append({
- "filename": filename,
- "status": "error",
- "error": str(e),
- })
+ processed_files.append(
+ {
+ "filename": filename,
+ "status": "error",
+ "error": str(e),
+ }
+ )
continue
audio_logger.info(
@@ -139,7 +150,11 @@ async def upload_and_process_audio_files(
)
# Generate title from filename
- title = filename.rsplit('.', 1)[0][:50] if filename != "unknown" else "Uploaded Audio"
+ title = (
+ filename.rsplit(".", 1)[0][:50]
+ if filename != "unknown"
+ else "Uploaded Audio"
+ )
conversation = create_conversation(
user_id=user.user_id,
@@ -150,9 +165,13 @@ async def upload_and_process_audio_files(
external_source_type=external_source_type,
)
await conversation.insert()
- conversation_id = conversation.conversation_id # Get the auto-generated ID
+ conversation_id = (
+ conversation.conversation_id
+ ) # Get the auto-generated ID
- audio_logger.info(f"π Created conversation {conversation_id} for uploaded file")
+ audio_logger.info(
+ f"π Created conversation {conversation_id} for uploaded file"
+ )
# Convert audio directly to MongoDB chunks
try:
@@ -170,24 +189,28 @@ async def upload_and_process_audio_files(
except ValueError as val_error:
# Handle validation errors (e.g., file too long)
audio_logger.error(f"Audio validation failed: {val_error}")
- processed_files.append({
- "filename": filename,
- "status": "error",
- "error": str(val_error),
- })
+ processed_files.append(
+ {
+ "filename": filename,
+ "status": "error",
+ "error": str(val_error),
+ }
+ )
# Delete the conversation since it won't have audio chunks
await conversation.delete()
continue
except Exception as chunk_error:
audio_logger.error(
f"Failed to convert uploaded file to chunks: {chunk_error}",
- exc_info=True
+ exc_info=True,
+ )
+ processed_files.append(
+ {
+ "filename": filename,
+ "status": "error",
+ "error": f"Audio conversion failed: {str(chunk_error)}",
+ }
)
- processed_files.append({
- "filename": filename,
- "status": "error",
- "error": f"Audio conversion failed: {str(chunk_error)}",
- })
# Delete the conversation since it won't have audio chunks
await conversation.delete()
continue
@@ -208,9 +231,14 @@ async def upload_and_process_audio_files(
result_ttl=JOB_RESULT_TTL,
job_id=transcribe_job_id,
description=f"Transcribe uploaded file {conversation_id[:8]}",
- meta={'conversation_id': conversation_id, 'client_id': client_id}
+ meta={
+ "conversation_id": conversation_id,
+ "client_id": client_id,
+ },
+ )
+ audio_logger.info(
+ f"π₯ Enqueued transcription job {transcription_job.id} for uploaded file"
)
- audio_logger.info(f"π₯ Enqueued transcription job {transcription_job.id} for uploaded file")
else:
audio_logger.warning(
f"β οΈ Skipping transcription for conversation {conversation_id}: "
@@ -223,16 +251,18 @@ async def upload_and_process_audio_files(
user_id=user.user_id,
transcript_version_id=version_id, # Pass the version_id from transcription job
depends_on_job=transcription_job, # Wait for transcription to complete (or None)
- client_id=client_id # Pass client_id for UI tracking
+ client_id=client_id, # Pass client_id for UI tracking
)
file_result = {
"filename": filename,
"status": "started", # RQ standard: job has been enqueued
"conversation_id": conversation_id,
- "transcript_job_id": transcription_job.id if transcription_job else None,
- "speaker_job_id": job_ids['speaker_recognition'],
- "memory_job_id": job_ids['memory'],
+ "transcript_job_id": (
+ transcription_job.id if transcription_job else None
+ ),
+ "speaker_job_id": job_ids["speaker_recognition"],
+ "memory_job_id": job_ids["memory"],
"duration_seconds": round(duration, 2),
}
if is_video_source:
@@ -243,10 +273,10 @@ async def upload_and_process_audio_files(
job_chain = []
if transcription_job:
job_chain.append(transcription_job.id)
- if job_ids['speaker_recognition']:
- job_chain.append(job_ids['speaker_recognition'])
- if job_ids['memory']:
- job_chain.append(job_ids['memory'])
+ if job_ids["speaker_recognition"]:
+ job_chain.append(job_ids["speaker_recognition"])
+ if job_ids["memory"]:
+ job_chain.append(job_ids["memory"])
audio_logger.info(
f"β
Processed {filename} β conversation {conversation_id}, "
@@ -256,19 +286,23 @@ async def upload_and_process_audio_files(
except (OSError, IOError) as e:
# File I/O errors during audio processing
audio_logger.exception(f"File I/O error processing {filename}")
- processed_files.append({
- "filename": filename,
- "status": "error",
- "error": str(e),
- })
+ processed_files.append(
+ {
+ "filename": filename,
+ "status": "error",
+ "error": str(e),
+ }
+ )
except Exception as e:
# Unexpected errors during file processing
audio_logger.exception(f"Unexpected error processing file {filename}")
- processed_files.append({
- "filename": filename,
- "status": "error",
- "error": str(e),
- })
+ processed_files.append(
+ {
+ "filename": filename,
+ "status": "error",
+ "error": str(e),
+ }
+ )
successful_files = [f for f in processed_files if f.get("status") == "started"]
failed_files = [f for f in processed_files if f.get("status") == "error"]
@@ -291,7 +325,9 @@ async def upload_and_process_audio_files(
return JSONResponse(status_code=400, content=response_body)
elif len(failed_files) > 0:
# SOME files failed (partial success) - return 207 Multi-Status
- audio_logger.warning(f"Partial upload: {len(successful_files)} succeeded, {len(failed_files)} failed")
+ audio_logger.warning(
+ f"Partial upload: {len(successful_files)} succeeded, {len(failed_files)} failed"
+ )
return JSONResponse(status_code=207, content=response_body)
else:
# All files succeeded - return 200 OK
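The response-code policy above (400 when, as the surrounding comments suggest, every file failed; 207 for partial success; 200 otherwise) can be restated as a tiny helper for readability; this is a paraphrase, not code from the patch:

def pick_status_code(successful_count: int, failed_count: int) -> int:
    if failed_count and not successful_count:
        return 400  # all files failed
    if failed_count:
        return 207  # Multi-Status: partial success
    return 200  # everything enqueued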
diff --git a/backends/advanced/src/advanced_omi_backend/controllers/client_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/client_controller.py
index b400d3ed..d603e7ea 100644
--- a/backends/advanced/src/advanced_omi_backend/controllers/client_controller.py
+++ b/backends/advanced/src/advanced_omi_backend/controllers/client_controller.py
@@ -6,10 +6,7 @@
from fastapi.responses import JSONResponse
-from advanced_omi_backend.client_manager import (
- ClientManager,
- get_user_clients_active,
-)
+from advanced_omi_backend.client_manager import ClientManager, get_user_clients_active
from advanced_omi_backend.users import User
logger = logging.getLogger(__name__)
@@ -37,7 +34,9 @@ async def get_active_clients(user: User, client_manager: ClientManager):
# Filter to only the user's clients
user_clients = [
- client for client in all_clients if client["client_id"] in user_active_clients
+ client
+ for client in all_clients
+ if client["client_id"] in user_active_clients
]
return {
diff --git a/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py
index d2cfc7df..316c3299 100644
--- a/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py
+++ b/backends/advanced/src/advanced_omi_backend/controllers/queue_controller.py
@@ -73,6 +73,7 @@ def get_job_status_from_rq(job: Job) -> str:
return status_str
+
# Queue name constants
TRANSCRIPTION_QUEUE = "transcription"
MEMORY_QUEUE = "memory"
@@ -86,9 +87,13 @@ def get_job_status_from_rq(job: Job) -> str:
JOB_RESULT_TTL = int(os.getenv("RQ_RESULT_TTL", 86400)) # 24 hour default
# Create queues with custom result TTL
-transcription_queue = Queue(TRANSCRIPTION_QUEUE, connection=redis_conn, default_timeout=86400) # 24 hours for streaming jobs
+transcription_queue = Queue(
+ TRANSCRIPTION_QUEUE, connection=redis_conn, default_timeout=86400
+) # 24 hours for streaming jobs
memory_queue = Queue(MEMORY_QUEUE, connection=redis_conn, default_timeout=300)
-audio_queue = Queue(AUDIO_QUEUE, connection=redis_conn, default_timeout=86400) # 24 hours for all-day sessions
+audio_queue = Queue(
+ AUDIO_QUEUE, connection=redis_conn, default_timeout=86400
+) # 24 hours for all-day sessions
default_queue = Queue(DEFAULT_QUEUE, connection=redis_conn, default_timeout=300)
@@ -123,7 +128,14 @@ def get_job_stats() -> Dict[str, Any]:
canceled_jobs += len(queue.canceled_job_registry)
deferred_jobs += len(queue.deferred_job_registry)
- total_jobs = queued_jobs + started_jobs + finished_jobs + failed_jobs + canceled_jobs + deferred_jobs
+ total_jobs = (
+ queued_jobs
+ + started_jobs
+ + finished_jobs
+ + failed_jobs
+ + canceled_jobs
+ + deferred_jobs
+ )
return {
"total_jobs": total_jobs,
@@ -133,7 +145,7 @@ def get_job_stats() -> Dict[str, Any]:
"failed_jobs": failed_jobs,
"canceled_jobs": canceled_jobs,
"deferred_jobs": deferred_jobs,
- "timestamp": datetime.utcnow().isoformat()
+ "timestamp": datetime.utcnow().isoformat(),
}
@@ -142,7 +154,7 @@ def get_jobs(
offset: int = 0,
queue_name: str = None,
job_type: str = None,
- client_id: str = None
+ client_id: str = None,
) -> Dict[str, Any]:
"""
Get jobs from a specific queue or all queues with optional filtering.
@@ -157,9 +169,13 @@ def get_jobs(
Returns:
Dict with jobs list and pagination metadata matching frontend expectations
"""
- logger.info(f"π DEBUG get_jobs: Filtering - queue_name={queue_name}, job_type={job_type}, client_id={client_id}")
+ logger.info(
+ f"π DEBUG get_jobs: Filtering - queue_name={queue_name}, job_type={job_type}, client_id={client_id}"
+ )
all_jobs = []
- seen_job_ids = set() # Track which job IDs we've already processed to avoid duplicates
+ seen_job_ids = (
+ set()
+ ) # Track which job IDs we've already processed to avoid duplicates
queues_to_check = [queue_name] if queue_name else QUEUE_NAMES
logger.info(f"π DEBUG get_jobs: Checking queues: {queues_to_check}")
@@ -170,10 +186,19 @@ def get_jobs(
# Collect jobs from all registries (using RQ standard status names)
registries = [
(queue.job_ids, "queued"),
- (queue.started_job_registry.get_job_ids(), "started"), # RQ standard, not "processing"
- (queue.finished_job_registry.get_job_ids(), "finished"), # RQ standard, not "completed"
+ (
+ queue.started_job_registry.get_job_ids(),
+ "started",
+ ), # RQ standard, not "processing"
+ (
+ queue.finished_job_registry.get_job_ids(),
+ "finished",
+ ), # RQ standard, not "completed"
(queue.failed_job_registry.get_job_ids(), "failed"),
- (queue.deferred_job_registry.get_job_ids(), "deferred"), # Jobs waiting for dependencies
+ (
+ queue.deferred_job_registry.get_job_ids(),
+ "deferred",
+ ), # Jobs waiting for dependencies
]
for job_ids, status in registries:
@@ -190,46 +215,76 @@ def get_jobs(
user_id = job.kwargs.get("user_id", "") if job.kwargs else ""
# Extract just the function name (e.g., "listen_for_speech_job" from "module.listen_for_speech_job")
- func_name = job.func_name.split('.')[-1] if job.func_name else "unknown"
+ func_name = (
+ job.func_name.split(".")[-1] if job.func_name else "unknown"
+ )
# Debug: Log job details before filtering
- logger.debug(f"π DEBUG get_jobs: Job {job_id} - func_name={func_name}, full_func_name={job.func_name}, meta_client_id={job.meta.get('client_id', '') if job.meta else ''}, status={status}")
+ logger.debug(
+ f"π DEBUG get_jobs: Job {job_id} - func_name={func_name}, full_func_name={job.func_name}, meta_client_id={job.meta.get('client_id', '') if job.meta else ''}, status={status}"
+ )
# Apply job_type filter
if job_type and job_type not in func_name:
- logger.debug(f"π DEBUG get_jobs: Filtered out {job_id} - job_type '{job_type}' not in func_name '{func_name}'")
+ logger.debug(
+ f"π DEBUG get_jobs: Filtered out {job_id} - job_type '{job_type}' not in func_name '{func_name}'"
+ )
continue
# Apply client_id filter (partial match in meta)
if client_id:
- job_client_id = job.meta.get("client_id", "") if job.meta else ""
+ job_client_id = (
+ job.meta.get("client_id", "") if job.meta else ""
+ )
if client_id not in job_client_id:
- logger.debug(f"π DEBUG get_jobs: Filtered out {job_id} - client_id '{client_id}' not in job_client_id '{job_client_id}'")
+ logger.debug(
+ f"π DEBUG get_jobs: Filtered out {job_id} - client_id '{client_id}' not in job_client_id '{job_client_id}'"
+ )
continue
- logger.debug(f"π DEBUG get_jobs: Including job {job_id} in results")
-
- all_jobs.append({
- "job_id": job.id,
- "job_type": func_name,
- "user_id": user_id,
- "status": status,
- "priority": "normal", # RQ doesn't track priority in metadata
- "data": {
- "description": job.description or "",
- "queue": qname,
- },
- "result": job.result if hasattr(job, 'result') else None,
- "meta": job.meta if job.meta else {}, # Include job metadata
- "error_message": str(job.exc_info) if job.exc_info else None,
- "created_at": job.created_at.isoformat() if job.created_at else None,
- "started_at": job.started_at.isoformat() if job.started_at else None,
- "completed_at": job.ended_at.isoformat() if job.ended_at else None,
- "retry_count": job.retries_left if hasattr(job, 'retries_left') else 0,
- "max_retries": 3, # Default max retries
- "progress_percent": (job.meta or {}).get("batch_progress", {}).get("percent", 0),
- "progress_message": (job.meta or {}).get("batch_progress", {}).get("message", ""),
- })
+ logger.debug(
+ f"π DEBUG get_jobs: Including job {job_id} in results"
+ )
+
+ all_jobs.append(
+ {
+ "job_id": job.id,
+ "job_type": func_name,
+ "user_id": user_id,
+ "status": status,
+ "priority": "normal", # RQ doesn't track priority in metadata
+ "data": {
+ "description": job.description or "",
+ "queue": qname,
+ },
+ "result": job.result if hasattr(job, "result") else None,
+ "meta": (
+ job.meta if job.meta else {}
+ ), # Include job metadata
+ "error_message": (
+ str(job.exc_info) if job.exc_info else None
+ ),
+ "created_at": (
+ job.created_at.isoformat() if job.created_at else None
+ ),
+ "started_at": (
+ job.started_at.isoformat() if job.started_at else None
+ ),
+ "completed_at": (
+ job.ended_at.isoformat() if job.ended_at else None
+ ),
+ "retry_count": (
+ job.retries_left if hasattr(job, "retries_left") else 0
+ ),
+ "max_retries": 3, # Default max retries
+ "progress_percent": (job.meta or {})
+ .get("batch_progress", {})
+ .get("percent", 0),
+ "progress_message": (job.meta or {})
+ .get("batch_progress", {})
+ .get("message", ""),
+ }
+ )
except Exception as e:
logger.error(f"Error fetching job {job_id}: {e}")
@@ -238,10 +293,12 @@ def get_jobs(
# Paginate
total_jobs = len(all_jobs)
- paginated_jobs = all_jobs[offset:offset + limit]
+ paginated_jobs = all_jobs[offset : offset + limit]
has_more = (offset + limit) < total_jobs
- logger.info(f"π DEBUG get_jobs: Found {total_jobs} matching jobs (returning {len(paginated_jobs)} after pagination)")
+ logger.info(
+ f"π DEBUG get_jobs: Found {total_jobs} matching jobs (returning {len(paginated_jobs)} after pagination)"
+ )
return {
"jobs": paginated_jobs,
@@ -250,7 +307,7 @@ def get_jobs(
"limit": limit,
"offset": offset,
"has_more": has_more,
- }
+ },
}
@@ -281,7 +338,7 @@ def is_job_complete(job):
return False
# Check dependent jobs
- for dep_id in (job.dependent_ids or []):
+ for dep_id in job.dependent_ids or []:
try:
dep_job = Job.fetch(dep_id, connection=redis_conn)
if not is_job_complete(dep_job):
@@ -310,7 +367,7 @@ def is_job_complete(job):
job = Job.fetch(job_id, connection=redis_conn)
# Only check jobs with client_id in meta
- if job.meta and job.meta.get('client_id') == client_id:
+ if job.meta and job.meta.get("client_id") == client_id:
if not is_job_complete(job):
return False
except Exception as e:
@@ -320,9 +377,7 @@ def is_job_complete(job):
def start_streaming_jobs(
- session_id: str,
- user_id: str,
- client_id: str
+ session_id: str, user_id: str, client_id: str
) -> Dict[str, str]:
"""
Enqueue jobs for streaming audio session (initial session setup).
@@ -351,7 +406,7 @@ def start_streaming_jobs(
# Read always_persist from global config NOW (backend process has fresh config)
misc_settings = get_misc_settings()
- always_persist = misc_settings.get('always_persist_enabled', False)
+ always_persist = misc_settings.get("always_persist_enabled", False)
# Enqueue speech detection job
speech_job = transcription_queue.enqueue(
@@ -365,7 +420,7 @@ def start_streaming_jobs(
failure_ttl=86400, # Cleanup failed jobs after 24h
job_id=f"speech-detect_{session_id[:12]}",
description=f"Listening for speech...",
- meta={'client_id': client_id, 'session_level': True}
+ meta={"client_id": client_id, "session_level": True},
)
# Log job enqueue with TTL information for debugging
actual_ttl = redis_conn.ttl(f"rq:job:{speech_job.id}")
@@ -379,7 +434,9 @@ def start_streaming_jobs(
# Store job ID for cleanup (keyed by client_id for easy WebSocket cleanup)
try:
- redis_conn.set(f"speech_detection_job:{client_id}", speech_job.id, ex=86400) # 24 hour TTL
+ redis_conn.set(
+ f"speech_detection_job:{client_id}", speech_job.id, ex=86400
+ ) # 24 hour TTL
logger.info(f"π Stored speech detection job ID for client {client_id}")
except Exception as e:
        logger.warning(f"⚠️ Failed to store job ID for {client_id}: {e}")
@@ -399,7 +456,10 @@ def start_streaming_jobs(
failure_ttl=86400, # Cleanup failed jobs after 24h
job_id=f"audio-persist_{session_id[:12]}",
description=f"Audio persistence for session {session_id[:12]}",
- meta={'client_id': client_id, 'session_level': True} # Mark as session-level job
+ meta={
+ "client_id": client_id,
+ "session_level": True,
+ }, # Mark as session-level job
)
# Log job enqueue with TTL information for debugging
actual_ttl = redis_conn.ttl(f"rq:job:{audio_job.id}")
@@ -411,19 +471,16 @@ def start_streaming_jobs(
f"queue_length={audio_queue.count}, client_id={client_id}"
)
- return {
- 'speech_detection': speech_job.id,
- 'audio_persistence': audio_job.id
- }
+ return {"speech_detection": speech_job.id, "audio_persistence": audio_job.id}
def start_post_conversation_jobs(
conversation_id: str,
user_id: str,
transcript_version_id: Optional[str] = None,
- depends_on_job = None,
+ depends_on_job=None,
client_id: Optional[str] = None,
- end_reason: str = "file_upload"
+ end_reason: str = "file_upload",
) -> Dict[str, str]:
"""
Start post-conversation processing jobs after conversation is created.
@@ -458,21 +515,27 @@ def start_post_conversation_jobs(
version_id = transcript_version_id or str(uuid.uuid4())
# Build job metadata (include client_id if provided for UI tracking)
- job_meta = {'conversation_id': conversation_id}
+ job_meta = {"conversation_id": conversation_id}
if client_id:
- job_meta['client_id'] = client_id
+ job_meta["client_id"] = client_id
# Check if speaker recognition is enabled
- speaker_config = get_service_config('speaker_recognition')
- speaker_enabled = speaker_config.get('enabled', True) # Default to True for backward compatibility
+ speaker_config = get_service_config("speaker_recognition")
+ speaker_enabled = speaker_config.get(
+ "enabled", True
+ ) # Default to True for backward compatibility
# Step 1: Speaker recognition job (conditional - only if enabled)
- speaker_dependency = depends_on_job # Start with upstream dependency (transcription if file upload)
+ speaker_dependency = (
+ depends_on_job # Start with upstream dependency (transcription if file upload)
+ )
speaker_job = None
if speaker_enabled:
speaker_job_id = f"speaker_{conversation_id[:12]}"
- logger.info(f"π DEBUG: Creating speaker job with job_id={speaker_job_id}, conversation_id={conversation_id[:12]}")
+ logger.info(
+ f"π DEBUG: Creating speaker job with job_id={speaker_job_id}, conversation_id={conversation_id[:12]}"
+ )
speaker_job = transcription_queue.enqueue(
recognise_speakers_job,
@@ -483,26 +546,36 @@ def start_post_conversation_jobs(
depends_on=speaker_dependency,
job_id=speaker_job_id,
description=f"Speaker recognition for conversation {conversation_id[:8]}",
- meta=job_meta
+ meta=job_meta,
)
speaker_dependency = speaker_job # Chain for next jobs
if depends_on_job:
- logger.info(f"π₯ RQ: Enqueued speaker recognition job {speaker_job.id}, meta={speaker_job.meta} (depends on {depends_on_job.id})")
+ logger.info(
+ f"π₯ RQ: Enqueued speaker recognition job {speaker_job.id}, meta={speaker_job.meta} (depends on {depends_on_job.id})"
+ )
else:
- logger.info(f"π₯ RQ: Enqueued speaker recognition job {speaker_job.id}, meta={speaker_job.meta} (no dependencies, starts immediately)")
+ logger.info(
+ f"π₯ RQ: Enqueued speaker recognition job {speaker_job.id}, meta={speaker_job.meta} (no dependencies, starts immediately)"
+ )
else:
- logger.info(f"βοΈ Speaker recognition disabled, skipping speaker job for conversation {conversation_id[:8]}")
+ logger.info(
+ f"βοΈ Speaker recognition disabled, skipping speaker job for conversation {conversation_id[:8]}"
+ )
# Step 2: Memory extraction job (conditional - only if enabled)
# Check if memory extraction is enabled
- memory_config = get_service_config('memory.extraction')
- memory_enabled = memory_config.get('enabled', True) # Default to True for backward compatibility
+ memory_config = get_service_config("memory.extraction")
+ memory_enabled = memory_config.get(
+ "enabled", True
+ ) # Default to True for backward compatibility
memory_job = None
if memory_enabled:
# Depends on speaker job if it was created, otherwise depends on upstream (transcription or nothing)
memory_job_id = f"memory_{conversation_id[:12]}"
- logger.info(f"π DEBUG: Creating memory job with job_id={memory_job_id}, conversation_id={conversation_id[:12]}")
+ logger.info(
+ f"π DEBUG: Creating memory job with job_id={memory_job_id}, conversation_id={conversation_id[:12]}"
+ )
memory_job = memory_queue.enqueue(
process_memory_job,
@@ -512,23 +585,33 @@ def start_post_conversation_jobs(
depends_on=speaker_dependency, # Either speaker_job or upstream dependency
job_id=memory_job_id,
description=f"Memory extraction for conversation {conversation_id[:8]}",
- meta=job_meta
+ meta=job_meta,
)
if speaker_job:
- logger.info(f"π₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (depends on speaker job {speaker_job.id})")
+ logger.info(
+ f"π₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (depends on speaker job {speaker_job.id})"
+ )
elif depends_on_job:
- logger.info(f"π₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (depends on {depends_on_job.id})")
+ logger.info(
+ f"π₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (depends on {depends_on_job.id})"
+ )
else:
- logger.info(f"π₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (no dependencies, starts immediately)")
+ logger.info(
+ f"π₯ RQ: Enqueued memory extraction job {memory_job.id}, meta={memory_job.meta} (no dependencies, starts immediately)"
+ )
else:
- logger.info(f"βοΈ Memory extraction disabled, skipping memory job for conversation {conversation_id[:8]}")
+ logger.info(
+ f"βοΈ Memory extraction disabled, skipping memory job for conversation {conversation_id[:8]}"
+ )
# Step 3: Title/summary generation job
# Depends on memory job to avoid race condition (both jobs save the conversation document)
# and to ensure fresh memories are available for context-enriched summaries
title_dependency = memory_job if memory_job else speaker_dependency
title_job_id = f"title_summary_{conversation_id[:12]}"
- logger.info(f"π DEBUG: Creating title/summary job with job_id={title_job_id}, conversation_id={conversation_id[:12]}")
+ logger.info(
+ f"π DEBUG: Creating title/summary job with job_id={title_job_id}, conversation_id={conversation_id[:12]}"
+ )
title_summary_job = default_queue.enqueue(
generate_title_summary_job,
@@ -538,21 +621,31 @@ def start_post_conversation_jobs(
depends_on=title_dependency,
job_id=title_job_id,
description=f"Generate title and summary for conversation {conversation_id[:8]}",
- meta=job_meta
+ meta=job_meta,
)
if memory_job:
- logger.info(f"π₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on memory job {memory_job.id})")
+ logger.info(
+ f"π₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on memory job {memory_job.id})"
+ )
elif speaker_job:
- logger.info(f"π₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on speaker job {speaker_job.id})")
+ logger.info(
+ f"π₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on speaker job {speaker_job.id})"
+ )
elif depends_on_job:
- logger.info(f"π₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on {depends_on_job.id})")
+ logger.info(
+ f"π₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (depends on {depends_on_job.id})"
+ )
else:
- logger.info(f"π₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (no dependencies, starts immediately)")
+ logger.info(
+ f"π₯ RQ: Enqueued title/summary job {title_summary_job.id}, meta={title_summary_job.meta} (no dependencies, starts immediately)"
+ )
# Step 5: Dispatch conversation.complete event (runs after both memory and title/summary complete)
# This ensures plugins receive the event after all processing is done
event_job_id = f"event_complete_{conversation_id[:12]}"
- logger.info(f"π DEBUG: Creating conversation complete event job with job_id={event_job_id}, conversation_id={conversation_id[:12]}")
+ logger.info(
+ f"π DEBUG: Creating conversation complete event job with job_id={event_job_id}, conversation_id={conversation_id[:12]}"
+ )
# Event job depends on memory and title/summary jobs that were actually enqueued
# Build dependency list excluding None values
@@ -571,29 +664,33 @@ def start_post_conversation_jobs(
end_reason, # Use the end_reason parameter (defaults to 'file_upload' for backward compatibility)
job_timeout=120, # 2 minutes
result_ttl=JOB_RESULT_TTL,
- depends_on=event_dependencies if event_dependencies else None, # Wait for jobs that were enqueued
+ depends_on=(
+ event_dependencies if event_dependencies else None
+ ), # Wait for jobs that were enqueued
job_id=event_job_id,
description=f"Dispatch conversation complete event ({end_reason}) for {conversation_id[:8]}",
- meta=job_meta
+ meta=job_meta,
)
# Log event dispatch dependencies
if event_dependencies:
dep_ids = [job.id for job in event_dependencies]
- logger.info(f"π₯ RQ: Enqueued conversation complete event job {event_dispatch_job.id}, meta={event_dispatch_job.meta} (depends on {', '.join(dep_ids)})")
+ logger.info(
+ f"π₯ RQ: Enqueued conversation complete event job {event_dispatch_job.id}, meta={event_dispatch_job.meta} (depends on {', '.join(dep_ids)})"
+ )
else:
- logger.info(f"π₯ RQ: Enqueued conversation complete event job {event_dispatch_job.id}, meta={event_dispatch_job.meta} (no dependencies, starts immediately)")
+ logger.info(
+ f"π₯ RQ: Enqueued conversation complete event job {event_dispatch_job.id}, meta={event_dispatch_job.meta} (no dependencies, starts immediately)"
+ )
return {
- 'speaker_recognition': speaker_job.id if speaker_job else None,
- 'memory': memory_job.id if memory_job else None,
- 'title_summary': title_summary_job.id,
- 'event_dispatch': event_dispatch_job.id
+ "speaker_recognition": speaker_job.id if speaker_job else None,
+ "memory": memory_job.id if memory_job else None,
+ "title_summary": title_summary_job.id,
+ "event_dispatch": event_dispatch_job.id,
}
-
-
def get_queue_health() -> Dict[str, Any]:
"""Get health status of all queues and workers."""
health = {
@@ -637,15 +734,18 @@ def get_queue_health() -> Dict[str, Any]:
else:
health["idle_workers"] += 1
- health["workers"].append({
- "name": worker.name,
- "state": state,
- "queues": [q.name for q in worker.queues],
- "current_job": current_job,
- })
+ health["workers"].append(
+ {
+ "name": worker.name,
+ "state": state,
+ "queues": [q.name for q in worker.queues],
+ "current_job": current_job,
+ }
+ )
return health
+
# needs tidying but works for now
async def cleanup_stuck_stream_workers(request):
"""Clean up stuck Redis Stream consumers and pending messages from all active streams."""
@@ -660,7 +760,7 @@ async def cleanup_stuck_stream_workers(request):
if not redis_client:
return JSONResponse(
status_code=503,
- content={"error": "Redis client for audio streaming not initialized"}
+ content={"error": "Redis client for audio streaming not initialized"},
)
cleanup_results = {}
@@ -673,17 +773,25 @@ async def cleanup_stuck_stream_workers(request):
stream_keys = await redis_client.keys("audio:stream:*")
for stream_key in stream_keys:
- stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key
+ stream_name = (
+ stream_key.decode() if isinstance(stream_key, bytes) else stream_key
+ )
try:
# First check stream age - delete old streams (>1 hour) immediately
- stream_info = await redis_client.execute_command('XINFO', 'STREAM', stream_name)
+ stream_info = await redis_client.execute_command(
+ "XINFO", "STREAM", stream_name
+ )
# Parse stream info
info_dict = {}
for i in range(0, len(stream_info), 2):
- key_name = stream_info[i].decode() if isinstance(stream_info[i], bytes) else str(stream_info[i])
- info_dict[key_name] = stream_info[i+1]
+ key_name = (
+ stream_info[i].decode()
+ if isinstance(stream_info[i], bytes)
+ else str(stream_info[i])
+ )
+ info_dict[key_name] = stream_info[i + 1]
stream_length = int(info_dict.get("length", 0))
last_entry = info_dict.get("last-entry")
@@ -695,12 +803,14 @@ async def cleanup_stuck_stream_workers(request):
if stream_length == 0:
should_delete_stream = True
stream_age = 0
- elif last_entry and isinstance(last_entry, list) and len(last_entry) > 0:
+ elif (
+ last_entry and isinstance(last_entry, list) and len(last_entry) > 0
+ ):
try:
last_id = last_entry[0]
if isinstance(last_id, bytes):
last_id = last_id.decode()
- last_timestamp_ms = int(last_id.split('-')[0])
+ last_timestamp_ms = int(last_id.split("-")[0])
last_timestamp_s = last_timestamp_ms / 1000
stream_age = current_time - last_timestamp_s
@@ -718,23 +828,33 @@ async def cleanup_stuck_stream_workers(request):
"cleaned": 0,
"deleted_consumers": 0,
"deleted_stream": True,
- "stream_age": stream_age
+ "stream_age": stream_age,
}
continue
# Get consumer groups
- groups = await redis_client.execute_command('XINFO', 'GROUPS', stream_name)
+ groups = await redis_client.execute_command(
+ "XINFO", "GROUPS", stream_name
+ )
if not groups:
- cleanup_results[stream_name] = {"message": "No consumer groups found", "cleaned": 0, "deleted_stream": False}
+ cleanup_results[stream_name] = {
+ "message": "No consumer groups found",
+ "cleaned": 0,
+ "deleted_stream": False,
+ }
continue
# Parse first group
group_dict = {}
group = groups[0]
for i in range(0, len(group), 2):
- key = group[i].decode() if isinstance(group[i], bytes) else str(group[i])
- value = group[i+1]
+ key = (
+ group[i].decode()
+ if isinstance(group[i], bytes)
+ else str(group[i])
+ )
+ value = group[i + 1]
if isinstance(value, bytes):
try:
value = value.decode()
@@ -749,7 +869,9 @@ async def cleanup_stuck_stream_workers(request):
pending_count = int(group_dict.get("pending", 0))
# Get consumers for this group to check per-consumer pending
- consumers = await redis_client.execute_command('XINFO', 'CONSUMERS', stream_name, group_name)
+ consumers = await redis_client.execute_command(
+ "XINFO", "CONSUMERS", stream_name, group_name
+ )
cleaned_count = 0
total_consumer_pending = 0
@@ -759,8 +881,12 @@ async def cleanup_stuck_stream_workers(request):
for consumer in consumers:
consumer_dict = {}
for i in range(0, len(consumer), 2):
- key = consumer[i].decode() if isinstance(consumer[i], bytes) else str(consumer[i])
- value = consumer[i+1]
+ key = (
+ consumer[i].decode()
+ if isinstance(consumer[i], bytes)
+ else str(consumer[i])
+ )
+ value = consumer[i + 1]
if isinstance(value, bytes):
try:
value = value.decode()
@@ -780,12 +906,20 @@ async def cleanup_stuck_stream_workers(request):
is_dead = consumer_idle_ms > 300000
if consumer_pending > 0:
- logger.info(f"Found {consumer_pending} pending messages for consumer {consumer_name} (idle: {consumer_idle_ms}ms)")
+ logger.info(
+ f"Found {consumer_pending} pending messages for consumer {consumer_name} (idle: {consumer_idle_ms}ms)"
+ )
# Get pending messages for this specific consumer
try:
pending_messages = await redis_client.execute_command(
- 'XPENDING', stream_name, group_name, '-', '+', str(consumer_pending), consumer_name
+ "XPENDING",
+ stream_name,
+ group_name,
+ "-",
+ "+",
+ str(consumer_pending),
+ consumer_name,
)
# XPENDING returns flat list: [msg_id, consumer, idle_ms, delivery_count, msg_id, ...]
@@ -799,31 +933,55 @@ async def cleanup_stuck_stream_workers(request):
# Claim the message to a cleanup worker
try:
await redis_client.execute_command(
- 'XCLAIM', stream_name, group_name, 'cleanup-worker', '0', msg_id
+ "XCLAIM",
+ stream_name,
+ group_name,
+ "cleanup-worker",
+ "0",
+ msg_id,
)
# Acknowledge it immediately
- await redis_client.xack(stream_name, group_name, msg_id)
+ await redis_client.xack(
+ stream_name, group_name, msg_id
+ )
cleaned_count += 1
except Exception as claim_error:
- logger.warning(f"Failed to claim/ack message {msg_id}: {claim_error}")
+ logger.warning(
+ f"Failed to claim/ack message {msg_id}: {claim_error}"
+ )
except Exception as consumer_error:
- logger.error(f"Error processing consumer {consumer_name}: {consumer_error}")
+ logger.error(
+ f"Error processing consumer {consumer_name}: {consumer_error}"
+ )
# Delete dead consumers (idle > 5 minutes with no pending messages)
if is_dead and consumer_pending == 0:
try:
await redis_client.execute_command(
- 'XGROUP', 'DELCONSUMER', stream_name, group_name, consumer_name
+ "XGROUP",
+ "DELCONSUMER",
+ stream_name,
+ group_name,
+ consumer_name,
)
deleted_consumers += 1
-                            logger.info(f"🧹 Deleted dead consumer {consumer_name} (idle: {consumer_idle_ms}ms)")
+ logger.info(
+                                f"🧹 Deleted dead consumer {consumer_name} (idle: {consumer_idle_ms}ms)"
+ )
except Exception as delete_error:
- logger.warning(f"Failed to delete consumer {consumer_name}: {delete_error}")
+ logger.warning(
+ f"Failed to delete consumer {consumer_name}: {delete_error}"
+ )
if total_consumer_pending == 0 and deleted_consumers == 0:
- cleanup_results[stream_name] = {"message": "No pending messages or dead consumers", "cleaned": 0, "deleted_consumers": 0, "deleted_stream": False}
+ cleanup_results[stream_name] = {
+ "message": "No pending messages or dead consumers",
+ "cleaned": 0,
+ "deleted_consumers": 0,
+ "deleted_stream": False,
+ }
continue
total_cleaned += cleaned_count
@@ -833,14 +991,11 @@ async def cleanup_stuck_stream_workers(request):
"cleaned": cleaned_count,
"deleted_consumers": deleted_consumers,
"deleted_stream": False,
- "original_pending": pending_count
+ "original_pending": pending_count,
}
except Exception as e:
- cleanup_results[stream_name] = {
- "error": str(e),
- "cleaned": 0
- }
+ cleanup_results[stream_name] = {"error": str(e), "cleaned": 0}
return {
"success": True,
@@ -849,11 +1004,12 @@ async def cleanup_stuck_stream_workers(request):
"total_deleted_streams": total_deleted_streams,
"streams": cleanup_results, # New key for per-stream results
"providers": cleanup_results, # Keep for backward compatibility with frontend
- "timestamp": time.time()
+ "timestamp": time.time(),
}
except Exception as e:
logger.error(f"Error cleaning up stuck workers: {e}", exc_info=True)
return JSONResponse(
- status_code=500, content={"error": f"Failed to cleanup stuck workers: {str(e)}"}
+ status_code=500,
+ content={"error": f"Failed to cleanup stuck workers: {str(e)}"},
)
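
The handlers above repeatedly fold Redis's flat XINFO replies (key, value, key, value, ...) into dicts before using them. Purely as an illustration of that idiom (not part of the patch), a minimal standalone sketch, assuming a local Redis and redis-py's asyncio client; the stream name is invented for the example:

import asyncio

import redis.asyncio as aioredis


def pairs_to_dict(flat_reply):
    """Fold a flat [key, value, key, value, ...] Redis reply into a dict,
    decoding bytes keys and simple bytes values along the way."""
    result = {}
    for i in range(0, len(flat_reply), 2):
        key = flat_reply[i]
        key = key.decode() if isinstance(key, bytes) else str(key)
        value = flat_reply[i + 1]
        if isinstance(value, bytes):
            try:
                value = value.decode()
            except UnicodeDecodeError:
                pass  # keep binary payloads as raw bytes
        result[key] = value
    return result


async def main():
    client = aioredis.Redis()  # assumes Redis on localhost:6379
    stream_name = "audio:stream:example"  # hypothetical stream

    # Write one throwaway entry so XINFO has something to report on.
    await client.xadd(stream_name, {"chunk": "example"})

    info = pairs_to_dict(await client.execute_command("XINFO", "STREAM", stream_name))

    # Stream entry IDs are "<milliseconds>-<sequence>"; the millisecond prefix
    # of the last entry is what the cleanup handler uses to judge stream age.
    last_entry = info.get("last-entry")
    if last_entry:
        entry_id = last_entry[0]
        if isinstance(entry_id, bytes):
            entry_id = entry_id.decode()
        print("length:", info.get("length"), "last entry at ms:", int(entry_id.split("-")[0]))

    await client.delete(stream_name)
    await client.close()


asyncio.run(main())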
diff --git a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py
index 9b3a2de9..f30401ec 100644
--- a/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py
+++ b/backends/advanced/src/advanced_omi_backend/controllers/session_controller.py
@@ -24,7 +24,7 @@ async def mark_session_complete(
"user_stopped",
"inactivity_timeout",
"max_duration",
- "all_jobs_complete"
+ "all_jobs_complete",
],
) -> None:
"""
@@ -57,12 +57,17 @@ async def mark_session_complete(
"""
session_key = f"audio:session:{session_id}"
mark_time = time.time()
- await redis_client.hset(session_key, mapping={
- "status": "finished",
- "completed_at": str(mark_time),
- "completion_reason": reason
- })
-    logger.info(f"✅ Session {session_id[:12]} marked finished: {reason} [TIME: {mark_time:.3f}]")
+ await redis_client.hset(
+ session_key,
+ mapping={
+ "status": "finished",
+ "completed_at": str(mark_time),
+ "completion_reason": reason,
+ },
+ )
+ logger.info(
+ f"β
Session {session_id[:12]} marked finished: {reason} [TIME: {mark_time:.3f}]"
+ )
async def request_conversation_close(
@@ -92,7 +97,9 @@ async def request_conversation_close(
if not await redis_client.exists(session_key):
return False
await redis_client.hset(session_key, "conversation_close_requested", reason)
- logger.info(f"π Conversation close requested for session {session_id[:12]}: {reason}")
+ logger.info(
+ f"π Conversation close requested for session {session_id[:12]}: {reason}"
+ )
return True
@@ -117,7 +124,9 @@ async def get_session_info(redis_client, session_id: str) -> Optional[Dict]:
# Get conversation count for this session
conversation_count_key = f"session:conversation_count:{session_id}"
conversation_count_bytes = await redis_client.get(conversation_count_key)
- conversation_count = int(conversation_count_bytes.decode()) if conversation_count_bytes else 0
+ conversation_count = (
+ int(conversation_count_bytes.decode()) if conversation_count_bytes else 0
+ )
started_at = float(session_data.get(b"started_at", b"0"))
last_chunk_at = float(session_data.get(b"last_chunk_at", b"0"))
@@ -138,8 +147,12 @@ async def get_session_info(redis_client, session_id: str) -> Optional[Dict]:
# Speech detection events
"last_event": session_data.get(b"last_event", b"").decode(),
"speech_detected_at": session_data.get(b"speech_detected_at", b"").decode(),
- "speaker_check_status": session_data.get(b"speaker_check_status", b"").decode(),
- "identified_speakers": session_data.get(b"identified_speakers", b"").decode()
+ "speaker_check_status": session_data.get(
+ b"speaker_check_status", b""
+ ).decode(),
+ "identified_speakers": session_data.get(
+ b"identified_speakers", b""
+ ).decode(),
}
except Exception as e:
@@ -166,7 +179,7 @@ async def get_all_sessions(redis_client, limit: int = 100) -> List[Dict]:
cursor, keys = await redis_client.scan(
cursor, match="audio:session:*", count=limit
)
- session_keys.extend(keys[:limit - len(session_keys)])
+ session_keys.extend(keys[: limit - len(session_keys)])
# Get info for each session
sessions = []
@@ -221,7 +234,9 @@ async def increment_session_conversation_count(redis_client, session_id: str) ->
logger.info(f"π Conversation count for session {session_id}: {count}")
return count
except Exception as e:
- logger.error(f"Error incrementing conversation count for session {session_id}: {e}")
+ logger.error(
+ f"Error incrementing conversation count for session {session_id}: {e}"
+ )
return 0
@@ -241,7 +256,7 @@ async def get_streaming_status(request):
if not redis_client:
return JSONResponse(
status_code=503,
- content={"error": "Redis client for audio streaming not initialized"}
+ content={"error": "Redis client for audio streaming not initialized"},
)
# Get all sessions (both active and completed)
@@ -273,23 +288,48 @@ async def get_streaming_status(request):
# All jobs finished - this is truly a finished session
# Update Redis status if it wasn't already marked finished
if status != "finished":
- await mark_session_complete(redis_client, session_id, "all_jobs_complete")
+ await mark_session_complete(
+ redis_client, session_id, "all_jobs_complete"
+ )
# Get additional session data for completed sessions
session_key = f"audio:session:{session_id}"
session_data = await redis_client.hgetall(session_key)
- completed_sessions_from_redis.append({
- "session_id": session_id,
- "client_id": session_obj.get("client_id", ""),
- "conversation_id": session_data.get(b"conversation_id", b"").decode() if session_data and b"conversation_id" in session_data else None,
- "has_conversation": bool(session_data and session_data.get(b"conversation_id", b"")),
- "action": session_data.get(b"action", b"finished").decode() if session_data and b"action" in session_data else "finished",
- "reason": session_data.get(b"reason", b"").decode() if session_data and b"reason" in session_data else "",
- "completed_at": session_obj.get("last_chunk_at", 0),
- "audio_file": session_data.get(b"audio_file", b"").decode() if session_data and b"audio_file" in session_data else "",
- "conversation_count": session_obj.get("conversation_count", 0)
- })
+ completed_sessions_from_redis.append(
+ {
+ "session_id": session_id,
+ "client_id": session_obj.get("client_id", ""),
+ "conversation_id": (
+ session_data.get(b"conversation_id", b"").decode()
+ if session_data and b"conversation_id" in session_data
+ else None
+ ),
+ "has_conversation": bool(
+ session_data
+ and session_data.get(b"conversation_id", b"")
+ ),
+ "action": (
+ session_data.get(b"action", b"finished").decode()
+ if session_data and b"action" in session_data
+ else "finished"
+ ),
+ "reason": (
+ session_data.get(b"reason", b"").decode()
+ if session_data and b"reason" in session_data
+ else ""
+ ),
+ "completed_at": session_obj.get("last_chunk_at", 0),
+ "audio_file": (
+ session_data.get(b"audio_file", b"").decode()
+ if session_data and b"audio_file" in session_data
+ else ""
+ ),
+ "conversation_count": session_obj.get(
+ "conversation_count", 0
+ ),
+ }
+ )
else:
# Status says complete but jobs still processing - keep in active
active_sessions.append(session_obj)
@@ -314,16 +354,24 @@ async def get_streaming_status(request):
current_time = time.time()
for stream_key in stream_keys:
- stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key
+ stream_name = (
+ stream_key.decode() if isinstance(stream_key, bytes) else stream_key
+ )
try:
# Check if stream exists
- stream_info = await redis_client.execute_command('XINFO', 'STREAM', stream_name)
+ stream_info = await redis_client.execute_command(
+ "XINFO", "STREAM", stream_name
+ )
# Parse stream info (returns flat list of key-value pairs)
info_dict = {}
for i in range(0, len(stream_info), 2):
- key = stream_info[i].decode() if isinstance(stream_info[i], bytes) else str(stream_info[i])
- value = stream_info[i+1]
+ key = (
+ stream_info[i].decode()
+ if isinstance(stream_info[i], bytes)
+ else str(stream_info[i])
+ )
+ value = stream_info[i + 1]
# Skip complex binary structures like first-entry and last-entry
# which contain message data that can't be JSON serialized
@@ -351,7 +399,7 @@ async def get_streaming_status(request):
if last_entry_id:
try:
# Redis Stream IDs format: "milliseconds-sequence"
- last_timestamp_ms = int(last_entry_id.split('-')[0])
+ last_timestamp_ms = int(last_entry_id.split("-")[0])
last_timestamp_s = last_timestamp_ms / 1000
stream_age_seconds = current_time - last_timestamp_s
except (ValueError, IndexError, AttributeError):
@@ -369,7 +417,9 @@ async def get_streaming_status(request):
session_idle_seconds = session_data.get("idle_seconds", 0)
# Get consumer groups
- groups = await redis_client.execute_command('XINFO', 'GROUPS', stream_name)
+ groups = await redis_client.execute_command(
+ "XINFO", "GROUPS", stream_name
+ )
stream_data = {
"stream_length": info_dict.get("length", 0),
@@ -378,19 +428,23 @@ async def get_streaming_status(request):
"session_age_seconds": session_age_seconds, # Age since session started
"session_idle_seconds": session_idle_seconds, # Time since last audio chunk
"client_id": client_id, # Include client_id for reference
- "consumer_groups": []
+ "consumer_groups": [],
}
# Track if stream has any active consumers
has_active_consumer = False
- min_consumer_idle_ms = float('inf')
+ min_consumer_idle_ms = float("inf")
# Parse consumer groups
for group in groups:
group_dict = {}
for i in range(0, len(group), 2):
- key = group[i].decode() if isinstance(group[i], bytes) else str(group[i])
- value = group[i+1]
+ key = (
+ group[i].decode()
+ if isinstance(group[i], bytes)
+ else str(group[i])
+ )
+ value = group[i + 1]
if isinstance(value, bytes):
try:
value = value.decode()
@@ -403,15 +457,21 @@ async def get_streaming_status(request):
group_name = group_name.decode()
# Get consumers for this group
- consumers = await redis_client.execute_command('XINFO', 'CONSUMERS', stream_name, group_name)
+ consumers = await redis_client.execute_command(
+ "XINFO", "CONSUMERS", stream_name, group_name
+ )
consumer_list = []
consumer_pending_total = 0
for consumer in consumers:
consumer_dict = {}
for i in range(0, len(consumer), 2):
- key = consumer[i].decode() if isinstance(consumer[i], bytes) else str(consumer[i])
- value = consumer[i+1]
+ key = (
+ consumer[i].decode()
+ if isinstance(consumer[i], bytes)
+ else str(consumer[i])
+ )
+ value = consumer[i + 1]
if isinstance(value, bytes):
try:
value = value.decode()
@@ -428,17 +488,21 @@ async def get_streaming_status(request):
consumer_pending_total += consumer_pending
# Track minimum idle time
- min_consumer_idle_ms = min(min_consumer_idle_ms, consumer_idle_ms)
+ min_consumer_idle_ms = min(
+ min_consumer_idle_ms, consumer_idle_ms
+ )
# Consumer is active if idle < 5 minutes (300000ms)
if consumer_idle_ms < 300000:
has_active_consumer = True
- consumer_list.append({
- "name": consumer_name,
- "pending": consumer_pending,
- "idle_ms": consumer_idle_ms
- })
+ consumer_list.append(
+ {
+ "name": consumer_name,
+ "pending": consumer_pending,
+ "idle_ms": consumer_idle_ms,
+ }
+ )
# Get group-level pending count (may be 0 even if consumers have pending)
try:
@@ -451,20 +515,24 @@ async def get_streaming_status(request):
# (Sometimes group pending is 0 but consumers still have pending messages)
effective_pending = max(group_pending_count, consumer_pending_total)
- stream_data["consumer_groups"].append({
- "name": str(group_name),
- "consumers": consumer_list,
- "pending": int(effective_pending)
- })
+ stream_data["consumer_groups"].append(
+ {
+ "name": str(group_name),
+ "consumers": consumer_list,
+ "pending": int(effective_pending),
+ }
+ )
# Determine if stream is active or completed
# Active: has active consumers OR pending messages OR recent activity (< 5 min)
# Completed: no active consumers and idle > 5 minutes but < 1 hour
- total_pending = sum(group["pending"] for group in stream_data["consumer_groups"])
+ total_pending = sum(
+ group["pending"] for group in stream_data["consumer_groups"]
+ )
is_active = (
- has_active_consumer or
- total_pending > 0 or
- stream_age_seconds < 300 # Less than 5 minutes old
+ has_active_consumer
+ or total_pending > 0
+ or stream_age_seconds < 300 # Less than 5 minutes old
)
if is_active:
@@ -487,7 +555,7 @@ async def get_streaming_status(request):
"finished": len(transcription_queue.finished_job_registry),
"failed": len(transcription_queue.failed_job_registry),
"canceled": len(transcription_queue.canceled_job_registry),
- "deferred": len(transcription_queue.deferred_job_registry)
+ "deferred": len(transcription_queue.deferred_job_registry),
},
"memory_queue": {
"queued": memory_queue.count,
@@ -495,7 +563,7 @@ async def get_streaming_status(request):
"finished": len(memory_queue.finished_job_registry),
"failed": len(memory_queue.failed_job_registry),
"canceled": len(memory_queue.canceled_job_registry),
- "deferred": len(memory_queue.deferred_job_registry)
+ "deferred": len(memory_queue.deferred_job_registry),
},
"default_queue": {
"queued": default_queue.count,
@@ -503,8 +571,8 @@ async def get_streaming_status(request):
"finished": len(default_queue.finished_job_registry),
"failed": len(default_queue.failed_job_registry),
"canceled": len(default_queue.canceled_job_registry),
- "deferred": len(default_queue.deferred_job_registry)
- }
+ "deferred": len(default_queue.deferred_job_registry),
+ },
}
return {
@@ -514,14 +582,14 @@ async def get_streaming_status(request):
"completed_streams": completed_streams,
"stream_health": active_streams, # Backward compatibility - use active_streams
"rq_queues": rq_stats,
- "timestamp": time.time()
+ "timestamp": time.time(),
}
except Exception as e:
logger.error(f"Error getting streaming status: {e}", exc_info=True)
return JSONResponse(
status_code=500,
- content={"error": f"Failed to get streaming status: {str(e)}"}
+ content={"error": f"Failed to get streaming status: {str(e)}"},
)
@@ -538,7 +606,7 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600):
if not redis_client:
return JSONResponse(
status_code=503,
- content={"error": "Redis client for audio streaming not initialized"}
+ content={"error": "Redis client for audio streaming not initialized"},
)
# Get all session keys
@@ -560,17 +628,18 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600):
age_seconds = current_time - started_at
# Clean up sessions older than max_age or stuck in "finalizing"
- should_clean = (
- age_seconds > max_age_seconds or
- (status == "finalizing" and age_seconds > 300) # Finalizing for more than 5 minutes
- )
+ should_clean = age_seconds > max_age_seconds or (
+ status == "finalizing" and age_seconds > 300
+ ) # Finalizing for more than 5 minutes
if should_clean:
- old_sessions.append({
- "session_id": session_id,
- "age_seconds": age_seconds,
- "status": status
- })
+ old_sessions.append(
+ {
+ "session_id": session_id,
+ "age_seconds": age_seconds,
+ "status": status,
+ }
+ )
await redis_client.delete(key)
cleaned_sessions += 1
@@ -580,17 +649,25 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600):
old_streams = []
for stream_key in stream_keys:
- stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key
+ stream_name = (
+ stream_key.decode() if isinstance(stream_key, bytes) else stream_key
+ )
try:
# Check stream info to get last activity
- stream_info = await redis_client.execute_command('XINFO', 'STREAM', stream_name)
+ stream_info = await redis_client.execute_command(
+ "XINFO", "STREAM", stream_name
+ )
# Parse stream info
info_dict = {}
for i in range(0, len(stream_info), 2):
- key_name = stream_info[i].decode() if isinstance(stream_info[i], bytes) else str(stream_info[i])
- info_dict[key_name] = stream_info[i+1]
+ key_name = (
+ stream_info[i].decode()
+ if isinstance(stream_info[i], bytes)
+ else str(stream_info[i])
+ )
+ info_dict[key_name] = stream_info[i + 1]
stream_length = int(info_dict.get("length", 0))
last_entry = info_dict.get("last-entry")
@@ -603,7 +680,9 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600):
# Empty stream - safe to delete
should_delete = True
reason = "empty"
- elif last_entry and isinstance(last_entry, list) and len(last_entry) > 0:
+ elif (
+ last_entry and isinstance(last_entry, list) and len(last_entry) > 0
+ ):
# Extract timestamp from last entry ID
last_id = last_entry[0]
if isinstance(last_id, bytes):
@@ -611,7 +690,7 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600):
# Redis Stream IDs format: "milliseconds-sequence"
try:
- last_timestamp_ms = int(last_id.split('-')[0])
+ last_timestamp_ms = int(last_id.split("-")[0])
last_timestamp_s = last_timestamp_ms / 1000
age_seconds = current_time - last_timestamp_s
@@ -622,12 +701,16 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600):
except (ValueError, IndexError):
# If we can't parse timestamp, check if first entry is old
first_entry = info_dict.get("first-entry")
- if first_entry and isinstance(first_entry, list) and len(first_entry) > 0:
+ if (
+ first_entry
+ and isinstance(first_entry, list)
+ and len(first_entry) > 0
+ ):
try:
first_id = first_entry[0]
if isinstance(first_id, bytes):
first_id = first_id.decode()
- first_timestamp_ms = int(first_id.split('-')[0])
+ first_timestamp_ms = int(first_id.split("-")[0])
first_timestamp_s = first_timestamp_ms / 1000
age_seconds = current_time - first_timestamp_s
@@ -640,12 +723,14 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600):
if should_delete:
await redis_client.delete(stream_name)
cleaned_streams += 1
- old_streams.append({
- "stream_name": stream_name,
- "reason": reason,
- "age_seconds": age_seconds,
- "length": stream_length
- })
+ old_streams.append(
+ {
+ "stream_name": stream_name,
+ "reason": reason,
+ "age_seconds": age_seconds,
+ "length": stream_length,
+ }
+ )
except Exception as e:
logger.debug(f"Error checking stream {stream_name}: {e}")
@@ -657,11 +742,12 @@ async def cleanup_old_sessions(request, max_age_seconds: int = 3600):
"cleaned_streams": cleaned_streams,
"cleaned_session_details": old_sessions,
"cleaned_stream_details": old_streams,
- "timestamp": time.time()
+ "timestamp": time.time(),
}
except Exception as e:
logger.error(f"Error cleaning up old sessions: {e}", exc_info=True)
return JSONResponse(
- status_code=500, content={"error": f"Failed to cleanup old sessions: {str(e)}"}
+ status_code=500,
+ content={"error": f"Failed to cleanup old sessions: {str(e)}"},
)
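
The session_controller.py hunks above all hang off one per-session Redis hash under "audio:session:<session_id>". As an illustrative aside (not part of the patch), a minimal sketch of that convention, writing the fields mark_session_complete() writes and reading them back the way get_session_info() does; it assumes a local Redis, and the session id is invented:

import asyncio
import time

import redis.asyncio as aioredis


async def main():
    client = aioredis.Redis()  # assumes Redis on localhost:6379
    session_id = "example-session-000000000000"  # hypothetical id
    session_key = f"audio:session:{session_id}"

    # The fields mark_session_complete() sets when a session ends.
    await client.hset(
        session_key,
        mapping={
            "status": "finished",
            "completed_at": str(time.time()),
            "completion_reason": "user_stopped",
        },
    )

    # Reading it back the way get_session_info() does: hgetall returns bytes
    # keys/values unless decode_responses=True is set on the client.
    data = await client.hgetall(session_key)
    status = data.get(b"status", b"").decode()
    reason = data.get(b"completion_reason", b"").decode()
    print(f"session {session_id[:12]} -> status={status}, reason={reason}")

    await client.delete(session_key)
    await client.close()


asyncio.run(main())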
diff --git a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py
index bf3ce1b1..4a6f83c2 100644
--- a/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py
+++ b/backends/advanced/src/advanced_omi_backend/controllers/system_controller.py
@@ -7,30 +7,29 @@
import logging
import os
import re
-import signal
import shutil
+import signal
import time
import warnings
from datetime import UTC, datetime
+from io import StringIO
from pathlib import Path
from typing import Optional
-from io import StringIO
-
-from ruamel.yaml import YAML
from fastapi import HTTPException
+from ruamel.yaml import YAML
from advanced_omi_backend.config import (
get_diarization_settings as load_diarization_settings,
)
from advanced_omi_backend.config import get_misc_settings as load_misc_settings
-from advanced_omi_backend.config import (
- save_diarization_settings,
- save_misc_settings,
+from advanced_omi_backend.config import save_diarization_settings, save_misc_settings
+from advanced_omi_backend.config_loader import get_plugins_yml_path, save_config_section
+from advanced_omi_backend.model_registry import (
+ _find_config_path,
+ get_models_registry,
+ load_models_config,
)
-from advanced_omi_backend.config_loader import get_plugins_yml_path
-from advanced_omi_backend.config_loader import save_config_section
-from advanced_omi_backend.model_registry import _find_config_path, get_models_registry, load_models_config
from advanced_omi_backend.models.user import User
logger = logging.getLogger(__name__)
@@ -43,7 +42,7 @@
async def get_config_diagnostics():
"""
Get comprehensive configuration diagnostics.
-
+
Returns warnings, errors, and status for all configuration components.
"""
diagnostics = {
@@ -52,9 +51,9 @@ async def get_config_diagnostics():
"issues": [],
"warnings": [],
"info": [],
- "components": {}
+ "components": {},
}
-
+
# Test OmegaConf configuration loading
try:
from advanced_omi_backend.config_loader import load_config
@@ -63,7 +62,7 @@ async def get_config_diagnostics():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
config = load_config(force_reload=True)
-
+
# Check for OmegaConf warnings
for warning in w:
warning_msg = str(warning.message)
@@ -71,148 +70,174 @@ async def get_config_diagnostics():
# Extract the variable name from warning
if "variable '" in warning_msg.lower():
var_name = warning_msg.split("'")[1]
- diagnostics["warnings"].append({
- "component": "OmegaConf",
- "severity": "warning",
- "message": f"Environment variable '{var_name}' not set (using empty default)",
- "resolution": f"Set {var_name} in .env file if needed"
- })
-
+ diagnostics["warnings"].append(
+ {
+ "component": "OmegaConf",
+ "severity": "warning",
+ "message": f"Environment variable '{var_name}' not set (using empty default)",
+ "resolution": f"Set {var_name} in .env file if needed",
+ }
+ )
+
diagnostics["components"]["omegaconf"] = {
"status": "healthy",
- "message": "Configuration loaded successfully"
+ "message": "Configuration loaded successfully",
}
except Exception as e:
diagnostics["overall_status"] = "unhealthy"
- diagnostics["issues"].append({
- "component": "OmegaConf",
- "severity": "error",
- "message": f"Failed to load configuration: {str(e)}",
- "resolution": "Check config/defaults.yml and config/config.yml syntax"
- })
+ diagnostics["issues"].append(
+ {
+ "component": "OmegaConf",
+ "severity": "error",
+ "message": f"Failed to load configuration: {str(e)}",
+ "resolution": "Check config/defaults.yml and config/config.yml syntax",
+ }
+ )
diagnostics["components"]["omegaconf"] = {
"status": "unhealthy",
- "message": str(e)
+ "message": str(e),
}
-
+
# Test model registry
try:
from advanced_omi_backend.model_registry import get_models_registry
-
+
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
registry = get_models_registry()
-
+
# Capture model loading warnings
for warning in w:
warning_msg = str(warning.message)
- diagnostics["warnings"].append({
- "component": "Model Registry",
- "severity": "warning",
- "message": warning_msg,
- "resolution": "Check model definitions in config/defaults.yml"
- })
-
+ diagnostics["warnings"].append(
+ {
+ "component": "Model Registry",
+ "severity": "warning",
+ "message": warning_msg,
+ "resolution": "Check model definitions in config/defaults.yml",
+ }
+ )
+
if registry:
diagnostics["components"]["model_registry"] = {
"status": "healthy",
"message": f"Loaded {len(registry.models)} models",
"details": {
"total_models": len(registry.models),
- "defaults": dict(registry.defaults) if registry.defaults else {}
- }
+ "defaults": dict(registry.defaults) if registry.defaults else {},
+ },
}
-
+
# Check critical models
stt = registry.get_default("stt")
stt_stream = registry.get_default("stt_stream")
llm = registry.get_default("llm")
-
+
# STT check
if stt:
if stt.api_key:
- diagnostics["info"].append({
- "component": "STT (Batch)",
- "message": f"Configured: {stt.name} ({stt.model_provider}) - API key present"
- })
+ diagnostics["info"].append(
+ {
+ "component": "STT (Batch)",
+ "message": f"Configured: {stt.name} ({stt.model_provider}) - API key present",
+ }
+ )
else:
- diagnostics["warnings"].append({
- "component": "STT (Batch)",
- "severity": "warning",
- "message": f"{stt.name} ({stt.model_provider}) - No API key configured",
- "resolution": "Transcription can fail without API key"
- })
+ diagnostics["warnings"].append(
+ {
+ "component": "STT (Batch)",
+ "severity": "warning",
+ "message": f"{stt.name} ({stt.model_provider}) - No API key configured",
+ "resolution": "Transcription can fail without API key",
+ }
+ )
else:
- diagnostics["issues"].append({
- "component": "STT (Batch)",
- "severity": "error",
- "message": "No batch STT model configured",
- "resolution": "Set defaults.stt in config.yml"
- })
+ diagnostics["issues"].append(
+ {
+ "component": "STT (Batch)",
+ "severity": "error",
+ "message": "No batch STT model configured",
+ "resolution": "Set defaults.stt in config.yml",
+ }
+ )
diagnostics["overall_status"] = "partial"
-
+
# Streaming STT check
if stt_stream:
if stt_stream.api_key:
- diagnostics["info"].append({
- "component": "STT (Streaming)",
- "message": f"Configured: {stt_stream.name} ({stt_stream.model_provider}) - API key present"
- })
+ diagnostics["info"].append(
+ {
+ "component": "STT (Streaming)",
+ "message": f"Configured: {stt_stream.name} ({stt_stream.model_provider}) - API key present",
+ }
+ )
else:
- diagnostics["warnings"].append({
+ diagnostics["warnings"].append(
+ {
+ "component": "STT (Streaming)",
+ "severity": "warning",
+ "message": f"{stt_stream.name} ({stt_stream.model_provider}) - No API key configured",
+ "resolution": "Real-time transcription can fail without API key",
+ }
+ )
+ else:
+ diagnostics["warnings"].append(
+ {
"component": "STT (Streaming)",
"severity": "warning",
- "message": f"{stt_stream.name} ({stt_stream.model_provider}) - No API key configured",
- "resolution": "Real-time transcription can fail without API key"
- })
- else:
- diagnostics["warnings"].append({
- "component": "STT (Streaming)",
- "severity": "warning",
- "message": "No streaming STT model configured - streaming worker disabled",
- "resolution": "Set defaults.stt_stream in config.yml for WebSocket transcription"
- })
-
+ "message": "No streaming STT model configured - streaming worker disabled",
+ "resolution": "Set defaults.stt_stream in config.yml for WebSocket transcription",
+ }
+ )
+
# LLM check
if llm:
if llm.api_key:
- diagnostics["info"].append({
- "component": "LLM",
- "message": f"Configured: {llm.name} ({llm.model_provider}) - API key present"
- })
+ diagnostics["info"].append(
+ {
+ "component": "LLM",
+ "message": f"Configured: {llm.name} ({llm.model_provider}) - API key present",
+ }
+ )
else:
- diagnostics["warnings"].append({
- "component": "LLM",
- "severity": "warning",
- "message": f"{llm.name} ({llm.model_provider}) - No API key configured",
- "resolution": "Memory extraction can fail without API key"
- })
-
+ diagnostics["warnings"].append(
+ {
+ "component": "LLM",
+ "severity": "warning",
+ "message": f"{llm.name} ({llm.model_provider}) - No API key configured",
+ "resolution": "Memory extraction can fail without API key",
+ }
+ )
+
else:
diagnostics["overall_status"] = "unhealthy"
- diagnostics["issues"].append({
- "component": "Model Registry",
- "severity": "error",
- "message": "Failed to load model registry",
- "resolution": "Check config/defaults.yml for syntax errors"
- })
+ diagnostics["issues"].append(
+ {
+ "component": "Model Registry",
+ "severity": "error",
+ "message": "Failed to load model registry",
+ "resolution": "Check config/defaults.yml for syntax errors",
+ }
+ )
diagnostics["components"]["model_registry"] = {
"status": "unhealthy",
- "message": "Registry failed to load"
+ "message": "Registry failed to load",
}
except Exception as e:
diagnostics["overall_status"] = "partial"
- diagnostics["issues"].append({
- "component": "Model Registry",
- "severity": "error",
- "message": f"Error loading registry: {str(e)}",
- "resolution": "Check logs for detailed error information"
- })
+ diagnostics["issues"].append(
+ {
+ "component": "Model Registry",
+ "severity": "error",
+ "message": f"Error loading registry: {str(e)}",
+ "resolution": "Check logs for detailed error information",
+ }
+ )
diagnostics["components"]["model_registry"] = {
"status": "unhealthy",
- "message": str(e)
+ "message": str(e),
}
-
+
# Check environment variables (only warn about keys relevant to configured providers)
env_checks = [
("AUTH_SECRET_KEY", "Required for authentication"),
@@ -224,7 +249,9 @@ async def get_config_diagnostics():
# Add LLM API key check based on active provider
llm_model = registry.get_default("llm")
if llm_model and llm_model.model_provider == "openai":
- env_checks.append(("OPENAI_API_KEY", "Required for OpenAI LLM and embeddings"))
+ env_checks.append(
+ ("OPENAI_API_KEY", "Required for OpenAI LLM and embeddings")
+ )
elif llm_model and llm_model.model_provider == "groq":
env_checks.append(("GROQ_API_KEY", "Required for Groq LLM"))
@@ -233,20 +260,26 @@ async def get_config_diagnostics():
if stt_model:
provider = stt_model.model_provider
if provider == "deepgram":
- env_checks.append(("DEEPGRAM_API_KEY", "Required for Deepgram transcription"))
+ env_checks.append(
+ ("DEEPGRAM_API_KEY", "Required for Deepgram transcription")
+ )
elif provider == "smallest":
- env_checks.append(("SMALLEST_API_KEY", "Required for Smallest.ai Pulse transcription"))
-
+ env_checks.append(
+ ("SMALLEST_API_KEY", "Required for Smallest.ai Pulse transcription")
+ )
+
for env_var, description in env_checks:
value = os.getenv(env_var)
if not value or value == "":
- diagnostics["warnings"].append({
- "component": "Environment Variables",
- "severity": "warning",
- "message": f"{env_var} not set - {description}",
- "resolution": f"Set {env_var} in .env file"
- })
-
+ diagnostics["warnings"].append(
+ {
+ "component": "Environment Variables",
+ "severity": "warning",
+ "message": f"{env_var} not set - {description}",
+ "resolution": f"Set {env_var} in .env file",
+ }
+ )
+
return diagnostics
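
Nothing in this patch consumes the diagnostics payload; for orientation only, a hypothetical client-side helper that flattens the entry shape visible above (overall_status plus issues/warnings/info entries carrying component, message, and optionally resolution). The helper and the sample payload are illustrative assumptions, not project code:

def summarize_diagnostics(diagnostics: dict) -> str:
    """Render a config-diagnostics payload as one line per finding."""
    lines = [f"overall status: {diagnostics.get('overall_status', 'unknown')}"]
    for bucket, label in (("issues", "ERROR"), ("warnings", "WARN"), ("info", "INFO")):
        for entry in diagnostics.get(bucket, []):
            line = f"[{label}] {entry.get('component', '?')}: {entry.get('message', '')}"
            if entry.get("resolution"):
                line += f" (fix: {entry['resolution']})"
            lines.append(line)
    return "\n".join(lines)


if __name__ == "__main__":
    sample = {
        "overall_status": "partial",
        "issues": [],
        "warnings": [
            {
                "component": "Environment Variables",
                "severity": "warning",
                "message": "DEEPGRAM_API_KEY not set - Required for Deepgram transcription",
                "resolution": "Set DEEPGRAM_API_KEY in .env file",
            }
        ],
        "info": [],
    }
    print(summarize_diagnostics(sample))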
@@ -297,7 +330,9 @@ async def get_observability_config():
from advanced_omi_backend.config_loader import load_config
cfg = load_config()
- public_url = cfg.get("observability", {}).get("langfuse", {}).get("public_url", "")
+ public_url = (
+ cfg.get("observability", {}).get("langfuse", {}).get("public_url", "")
+ )
if public_url:
# Strip trailing slash and build session URL
session_base_url = f"{public_url.rstrip('/')}/project/chronicle/sessions"
@@ -321,10 +356,7 @@ async def get_diarization_settings():
try:
# Get settings using OmegaConf
settings = load_diarization_settings()
- return {
- "settings": settings,
- "status": "success"
- }
+ return {"settings": settings, "status": "success"}
except Exception as e:
logger.exception("Error getting diarization settings")
raise e
@@ -335,8 +367,13 @@ async def save_diarization_settings_controller(settings: dict):
try:
# Validate settings
valid_keys = {
- "diarization_source", "similarity_threshold", "min_duration", "collar",
- "min_duration_off", "min_speakers", "max_speakers"
+ "diarization_source",
+ "similarity_threshold",
+ "min_duration",
+ "collar",
+ "min_duration_off",
+ "min_speakers",
+ "max_speakers",
}
        # Filter to only valid keys (allow round-trip GET→POST)
@@ -348,19 +385,30 @@ async def save_diarization_settings_controller(settings: dict):
# Type validation for known keys only
if key in ["min_speakers", "max_speakers"]:
if not isinstance(value, int) or value < 1 or value > 20:
- raise HTTPException(status_code=400, detail=f"Invalid value for {key}: must be integer 1-20")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid value for {key}: must be integer 1-20",
+ )
elif key == "diarization_source":
if not isinstance(value, str) or value not in ["pyannote", "deepgram"]:
- raise HTTPException(status_code=400, detail=f"Invalid value for {key}: must be 'pyannote' or 'deepgram'")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid value for {key}: must be 'pyannote' or 'deepgram'",
+ )
else:
if not isinstance(value, (int, float)) or value < 0:
- raise HTTPException(status_code=400, detail=f"Invalid value for {key}: must be positive number")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid value for {key}: must be positive number",
+ )
filtered_settings[key] = value
# Reject if NO valid keys provided (completely invalid request)
if not filtered_settings:
- raise HTTPException(status_code=400, detail="No valid diarization settings provided")
+ raise HTTPException(
+ status_code=400, detail="No valid diarization settings provided"
+ )
# Get current settings and merge with new values
current_settings = load_diarization_settings()
@@ -373,14 +421,14 @@ async def save_diarization_settings_controller(settings: dict):
return {
"message": "Diarization settings saved successfully",
"settings": current_settings,
- "status": "success"
+ "status": "success",
}
else:
logger.warning("Settings save failed")
return {
"message": "Settings save failed",
"settings": current_settings,
- "status": "error"
+ "status": "error",
}
except Exception as e:
@@ -393,10 +441,7 @@ async def get_misc_settings():
try:
# Get settings using OmegaConf
settings = load_misc_settings()
- return {
- "settings": settings,
- "status": "success"
- }
+ return {"settings": settings, "status": "success"}
except Exception as e:
logger.exception("Error getting misc settings")
raise e
@@ -406,7 +451,12 @@ async def save_misc_settings_controller(settings: dict):
"""Save miscellaneous settings."""
try:
# Validate settings
- boolean_keys = {"always_persist_enabled", "use_provider_segments", "per_segment_speaker_id", "always_batch_retranscribe"}
+ boolean_keys = {
+ "always_persist_enabled",
+ "use_provider_segments",
+ "per_segment_speaker_id",
+ "always_batch_retranscribe",
+ }
integer_keys = {"transcription_job_timeout_seconds"}
valid_keys = boolean_keys | integer_keys
@@ -419,16 +469,24 @@ async def save_misc_settings_controller(settings: dict):
# Type validation
if key in boolean_keys:
if not isinstance(value, bool):
- raise HTTPException(status_code=400, detail=f"Invalid value for {key}: must be boolean")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid value for {key}: must be boolean",
+ )
elif key == "transcription_job_timeout_seconds":
if not isinstance(value, int) or value < 60 or value > 7200:
- raise HTTPException(status_code=400, detail=f"Invalid value for {key}: must be integer between 60 and 7200")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid value for {key}: must be integer between 60 and 7200",
+ )
filtered_settings[key] = value
# Reject if NO valid keys provided
if not filtered_settings:
- raise HTTPException(status_code=400, detail="No valid misc settings provided")
+ raise HTTPException(
+ status_code=400, detail="No valid misc settings provided"
+ )
# Save using OmegaConf
if save_misc_settings(filtered_settings):
@@ -439,14 +497,14 @@ async def save_misc_settings_controller(settings: dict):
return {
"message": "Miscellaneous settings saved successfully",
"settings": updated_settings,
- "status": "success"
+ "status": "success",
}
else:
logger.warning("Settings save failed")
return {
"message": "Settings save failed",
"settings": load_misc_settings(),
- "status": "error"
+ "status": "error",
}
except HTTPException:
@@ -472,9 +530,7 @@ async def get_cleanup_settings_controller(user: User) -> dict:
async def save_cleanup_settings_controller(
- auto_cleanup_enabled: bool,
- retention_days: int,
- user: User
+ auto_cleanup_enabled: bool, retention_days: int, user: User
) -> dict:
"""
Save cleanup settings (admin only).
@@ -504,19 +560,20 @@ async def save_cleanup_settings_controller(
# Create settings object
settings = CleanupSettings(
- auto_cleanup_enabled=auto_cleanup_enabled,
- retention_days=retention_days
+ auto_cleanup_enabled=auto_cleanup_enabled, retention_days=retention_days
)
# Save using OmegaConf
save_cleanup_settings(settings)
- logger.info(f"Admin {user.email} updated cleanup settings: auto_cleanup={auto_cleanup_enabled}, retention={retention_days}d")
+ logger.info(
+ f"Admin {user.email} updated cleanup settings: auto_cleanup={auto_cleanup_enabled}, retention={retention_days}d"
+ )
return {
"auto_cleanup_enabled": settings.auto_cleanup_enabled,
"retention_days": settings.retention_days,
- "message": "Cleanup settings saved successfully"
+ "message": "Cleanup settings saved successfully",
}
@@ -526,7 +583,7 @@ async def get_speaker_configuration(user: User):
return {
"primary_speakers": user.primary_speakers,
"user_id": user.user_id,
- "status": "success"
+ "status": "success",
}
except Exception as e:
logger.exception(f"Error getting speaker configuration for user {user.user_id}")
@@ -540,32 +597,36 @@ async def update_speaker_configuration(user: User, primary_speakers: list[dict])
for speaker in primary_speakers:
if not isinstance(speaker, dict):
raise ValueError("Each speaker must be a dictionary")
-
+
required_fields = ["speaker_id", "name", "user_id"]
for field in required_fields:
if field not in speaker:
raise ValueError(f"Missing required field: {field}")
-
+
# Enforce server-side user_id and add timestamp to each speaker
for speaker in primary_speakers:
speaker["user_id"] = user.user_id # Override client-supplied user_id
speaker["selected_at"] = datetime.now(UTC).isoformat()
-
+
# Update user model
user.primary_speakers = primary_speakers
await user.save()
-
- logger.info(f"Updated primary speakers configuration for user {user.user_id}: {len(primary_speakers)} speakers")
-
+
+ logger.info(
+ f"Updated primary speakers configuration for user {user.user_id}: {len(primary_speakers)} speakers"
+ )
+
return {
"message": "Primary speakers configuration updated successfully",
"primary_speakers": primary_speakers,
"count": len(primary_speakers),
- "status": "success"
+ "status": "success",
}
-
+
except Exception as e:
- logger.exception(f"Error updating speaker configuration for user {user.user_id}")
+ logger.exception(
+ f"Error updating speaker configuration for user {user.user_id}"
+ )
raise e
@@ -578,25 +639,25 @@ async def get_enrolled_speakers(user: User):
# Initialize speaker recognition client
speaker_client = SpeakerRecognitionClient()
-
+
if not speaker_client.enabled:
return {
"speakers": [],
"service_available": False,
"message": "Speaker recognition service is not configured or disabled",
- "status": "success"
+ "status": "success",
}
-
+
# Get enrolled speakers - using hardcoded user_id=1 for now (as noted in speaker_recognition_client.py)
speakers = await speaker_client.get_enrolled_speakers(user_id="1")
-
+
return {
"speakers": speakers.get("speakers", []) if speakers else [],
"service_available": True,
"message": "Successfully retrieved enrolled speakers",
- "status": "success"
+ "status": "success",
}
-
+
except Exception as e:
logger.exception(f"Error getting enrolled speakers for user {user.user_id}")
raise e
@@ -611,25 +672,25 @@ async def get_speaker_service_status():
# Initialize speaker recognition client
speaker_client = SpeakerRecognitionClient()
-
+
if not speaker_client.enabled:
return {
"service_available": False,
"healthy": False,
"message": "Speaker recognition service is not configured or disabled",
- "status": "disabled"
+ "status": "disabled",
}
-
+
# Perform health check
health_result = await speaker_client.health_check()
-
+
if health_result:
return {
"service_available": True,
"healthy": True,
"message": "Speaker recognition service is healthy",
"service_url": speaker_client.service_url,
- "status": "healthy"
+ "status": "healthy",
}
else:
return {
@@ -637,17 +698,17 @@ async def get_speaker_service_status():
"healthy": False,
"message": "Speaker recognition service is not responding",
"service_url": speaker_client.service_url,
- "status": "unhealthy"
+ "status": "unhealthy",
}
-
+
except Exception as e:
logger.exception("Error checking speaker service status")
raise e
-
# Memory Configuration Management Functions
+
async def get_memory_config_raw():
"""Get current memory configuration (memory section of config.yml) as YAML."""
try:
@@ -655,7 +716,7 @@ async def get_memory_config_raw():
if not os.path.exists(cfg_path):
raise FileNotFoundError(f"Config file not found: {cfg_path}")
- with open(cfg_path, 'r') as f:
+ with open(cfg_path, "r") as f:
data = _yaml.load(f) or {}
memory_section = data.get("memory", {})
stream = StringIO()
@@ -691,10 +752,10 @@ async def update_memory_config_raw(config_yaml: str):
shutil.copy2(cfg_path, backup_path)
# Update memory section and write file
- with open(cfg_path, 'r') as f:
+ with open(cfg_path, "r") as f:
data = _yaml.load(f) or {}
data["memory"] = new_mem
- with open(cfg_path, 'w') as f:
+ with open(cfg_path, "w") as f:
_yaml.dump(data, f)
# Reload registry
@@ -717,9 +778,13 @@ async def validate_memory_config(config_yaml: str):
try:
parsed = _yaml.load(config_yaml)
except Exception as e:
- raise HTTPException(status_code=400, detail=f"Invalid YAML syntax: {str(e)}")
+ raise HTTPException(
+ status_code=400, detail=f"Invalid YAML syntax: {str(e)}"
+ )
if not isinstance(parsed, dict):
- raise HTTPException(status_code=400, detail="Configuration must be a YAML object")
+ raise HTTPException(
+ status_code=400, detail="Configuration must be a YAML object"
+ )
# Minimal checks
# provider optional; timeout_seconds optional; extraction enabled/prompt optional
return {"message": "Configuration is valid", "status": "success"}
@@ -728,7 +793,9 @@ async def validate_memory_config(config_yaml: str):
raise
except Exception as e:
logger.exception("Error validating memory config")
- raise HTTPException(status_code=500, detail=f"Error validating memory config: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Error validating memory config: {str(e)}"
+ )
async def reload_memory_config():
@@ -736,7 +803,11 @@ async def reload_memory_config():
try:
cfg_path = _find_config_path()
load_models_config(force_reload=True)
- return {"message": "Configuration reloaded", "config_path": str(cfg_path), "status": "success"}
+ return {
+ "message": "Configuration reloaded",
+ "config_path": str(cfg_path),
+ "status": "success",
+ }
except Exception as e:
logger.exception("Error reloading config")
raise e
@@ -758,7 +829,7 @@ async def delete_all_user_memories(user: User):
"message": f"Successfully deleted {deleted_count} memories",
"deleted_count": deleted_count,
"user_id": user.user_id,
- "status": "success"
+ "status": "success",
}
except Exception as e:
@@ -768,6 +839,7 @@ async def delete_all_user_memories(user: User):
# Memory Provider Configuration Functions
+
async def get_memory_provider():
"""Get current memory provider configuration."""
try:
@@ -782,7 +854,7 @@ async def get_memory_provider():
return {
"current_provider": current_provider,
"available_providers": available_providers,
- "status": "success"
+ "status": "success",
}
except Exception as e:
@@ -798,7 +870,9 @@ async def set_memory_provider(provider: str):
valid_providers = ["chronicle", "openmemory_mcp"]
if provider not in valid_providers:
- raise ValueError(f"Invalid provider '{provider}'. Valid providers: {', '.join(valid_providers)}")
+ raise ValueError(
+ f"Invalid provider '{provider}'. Valid providers: {', '.join(valid_providers)}"
+ )
# Path to .env file (assuming we're running from backends/advanced/)
env_path = os.path.join(os.getcwd(), ".env")
@@ -807,7 +881,7 @@ async def set_memory_provider(provider: str):
raise FileNotFoundError(f".env file not found at {env_path}")
# Read current .env file
- with open(env_path, 'r') as file:
+ with open(env_path, "r") as file:
lines = file.readlines()
# Update or add MEMORY_PROVIDER line
@@ -823,7 +897,9 @@ async def set_memory_provider(provider: str):
# If MEMORY_PROVIDER wasn't found, add it
if not provider_found:
- updated_lines.append(f"\n# Memory Provider Configuration\nMEMORY_PROVIDER={provider}\n")
+ updated_lines.append(
+ f"\n# Memory Provider Configuration\nMEMORY_PROVIDER={provider}\n"
+ )
# Create backup
backup_path = f"{env_path}.bak"
@@ -831,7 +907,7 @@ async def set_memory_provider(provider: str):
logger.info(f"Created .env backup at {backup_path}")
# Write updated .env file
- with open(env_path, 'w') as file:
+ with open(env_path, "w") as file:
file.writelines(updated_lines)
# Update environment variable for current process
@@ -845,7 +921,7 @@ async def set_memory_provider(provider: str):
"env_path": env_path,
"backup_created": True,
"requires_restart": True,
- "status": "success"
+ "status": "success",
}
except Exception as e:
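
The .env handling in set_memory_provider() above follows the usual replace-or-append pattern for KEY=value lines. A standalone sketch of that pattern, assuming a throwaway example.env in the working directory (file name and key are illustrative); the controller additionally writes a .bak backup before rewriting, which is omitted here for brevity:

from pathlib import Path


def set_env_var(env_path: Path, key: str, value: str) -> None:
    """Replace an existing KEY=... line in a dotenv-style file, or append one."""
    lines = env_path.read_text().splitlines(keepends=True) if env_path.exists() else []
    updated, found = [], False
    for line in lines:
        if line.strip().startswith(f"{key}="):
            updated.append(f"{key}={value}\n")
            found = True
        else:
            updated.append(line)
    if not found:
        updated.append(f"\n{key}={value}\n")
    env_path.write_text("".join(updated))


set_env_var(Path("example.env"), "MEMORY_PROVIDER", "openmemory_mcp")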
@@ -855,6 +931,7 @@ async def set_memory_provider(provider: str):
# LLM Operations Configuration Functions
+
async def get_llm_operations():
"""Get LLM operation configurations and available models."""
try:
@@ -902,29 +979,49 @@ async def save_llm_operations(operations: dict):
for op_name, op_value in operations.items():
if not isinstance(op_value, dict):
- raise HTTPException(status_code=400, detail=f"Operation '{op_name}' must be a dict")
+ raise HTTPException(
+ status_code=400, detail=f"Operation '{op_name}' must be a dict"
+ )
extra_keys = set(op_value.keys()) - valid_keys
if extra_keys:
- raise HTTPException(status_code=400, detail=f"Invalid keys for '{op_name}': {extra_keys}")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid keys for '{op_name}': {extra_keys}",
+ )
if "temperature" in op_value and op_value["temperature"] is not None:
t = op_value["temperature"]
if not isinstance(t, (int, float)) or t < 0 or t > 2:
- raise HTTPException(status_code=400, detail=f"Invalid temperature for '{op_name}': must be 0-2")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid temperature for '{op_name}': must be 0-2",
+ )
if "max_tokens" in op_value and op_value["max_tokens"] is not None:
mt = op_value["max_tokens"]
if not isinstance(mt, int) or mt <= 0:
- raise HTTPException(status_code=400, detail=f"Invalid max_tokens for '{op_name}': must be positive int")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid max_tokens for '{op_name}': must be positive int",
+ )
if "model" in op_value and op_value["model"] is not None:
if not registry.get_by_name(op_value["model"]):
- raise HTTPException(status_code=400, detail=f"Model '{op_value['model']}' not found in registry")
-
- if "response_format" in op_value and op_value["response_format"] is not None:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Model '{op_value['model']}' not found in registry",
+ )
+
+ if (
+ "response_format" in op_value
+ and op_value["response_format"] is not None
+ ):
if op_value["response_format"] != "json":
- raise HTTPException(status_code=400, detail=f"response_format must be 'json' or null")
+ raise HTTPException(
+ status_code=400,
+                        detail="response_format must be 'json' or null",
+ )
if save_config_section("llm_operations", operations):
load_models_config(force_reload=True)
@@ -958,11 +1055,21 @@ async def test_llm_model(model_name: Optional[str]):
if model_name:
model_def = registry.get_by_name(model_name)
if not model_def:
- return {"success": False, "model_name": model_name, "error": f"Model '{model_name}' not found", "status": "error"}
+ return {
+ "success": False,
+ "model_name": model_name,
+ "error": f"Model '{model_name}' not found",
+ "status": "error",
+ }
else:
model_def = registry.get_default("llm")
if not model_def:
- return {"success": False, "model_name": None, "error": "No default LLM configured", "status": "error"}
+ return {
+ "success": False,
+ "model_name": None,
+ "error": "No default LLM configured",
+ "status": "error",
+ }
client = create_openai_client(
api_key=model_def.api_key or "",
@@ -998,6 +1105,7 @@ async def test_llm_model(model_name: Optional[str]):
# Chat Configuration Management Functions
+
async def get_chat_config_yaml() -> str:
"""Get chat system prompt as plain text."""
try:
@@ -1012,11 +1120,11 @@ async def get_chat_config_yaml() -> str:
if not os.path.exists(config_path):
return default_prompt
- with open(config_path, 'r') as f:
+ with open(config_path, "r") as f:
full_config = _yaml.load(f) or {}
- chat_config = full_config.get('chat', {})
- system_prompt = chat_config.get('system_prompt', default_prompt)
+ chat_config = full_config.get("chat", {})
+ system_prompt = chat_config.get("system_prompt", default_prompt)
# Return just the prompt text, not the YAML structure
return system_prompt
@@ -1042,26 +1150,26 @@ async def save_chat_config_yaml(prompt_text: str) -> dict:
raise ValueError("Prompt too long (maximum 10000 characters)")
# Create chat config dict
- chat_config = {'system_prompt': prompt_text}
+ chat_config = {"system_prompt": prompt_text}
# Load full config
if os.path.exists(config_path):
- with open(config_path, 'r') as f:
+ with open(config_path, "r") as f:
full_config = _yaml.load(f) or {}
else:
full_config = {}
# Backup existing config
if os.path.exists(config_path):
- backup_path = str(config_path) + '.backup'
+ backup_path = str(config_path) + ".backup"
shutil.copy2(config_path, backup_path)
logger.info(f"Created config backup at {backup_path}")
# Update chat section
- full_config['chat'] = chat_config
+ full_config["chat"] = chat_config
# Save
- with open(config_path, 'w') as f:
+ with open(config_path, "w") as f:
_yaml.dump(full_config, f)
# Reload config in memory (hot-reload)
@@ -1087,7 +1195,10 @@ async def validate_chat_config_yaml(prompt_text: str) -> dict:
if len(prompt_text) < 10:
return {"valid": False, "error": "Prompt too short (minimum 10 characters)"}
if len(prompt_text) > 10000:
- return {"valid": False, "error": "Prompt too long (maximum 10000 characters)"}
+ return {
+ "valid": False,
+ "error": "Prompt too long (maximum 10000 characters)",
+ }
return {"valid": True, "message": "Configuration is valid"}
@@ -1098,6 +1209,7 @@ async def validate_chat_config_yaml(prompt_text: str) -> dict:
# Plugin Configuration Management Functions
+
async def get_plugins_config_yaml() -> str:
"""Get plugins configuration as YAML text."""
try:
@@ -1120,7 +1232,7 @@ async def get_plugins_config_yaml() -> str:
if not plugins_yml_path.exists():
return default_config
- with open(plugins_yml_path, 'r') as f:
+ with open(plugins_yml_path, "r") as f:
yaml_content = f.read()
return yaml_content
@@ -1142,7 +1254,7 @@ async def save_plugins_config_yaml(yaml_content: str) -> dict:
raise ValueError("Configuration must be a YAML dictionary")
# Validate has 'plugins' key
- if 'plugins' not in parsed_config:
+ if "plugins" not in parsed_config:
raise ValueError("Configuration must contain 'plugins' key")
except ValueError:
@@ -1155,12 +1267,12 @@ async def save_plugins_config_yaml(yaml_content: str) -> dict:
# Backup existing config
if plugins_yml_path.exists():
- backup_path = str(plugins_yml_path) + '.backup'
+ backup_path = str(plugins_yml_path) + ".backup"
shutil.copy2(plugins_yml_path, backup_path)
logger.info(f"Created plugins config backup at {backup_path}")
# Save new config
- with open(plugins_yml_path, 'w') as f:
+ with open(plugins_yml_path, "w") as f:
f.write(yaml_content)
# Hot-reload plugins and signal worker restart
@@ -1201,35 +1313,55 @@ async def validate_plugins_config_yaml(yaml_content: str) -> dict:
if not isinstance(parsed_config, dict):
return {"valid": False, "error": "Configuration must be a YAML dictionary"}
- if 'plugins' not in parsed_config:
+ if "plugins" not in parsed_config:
return {"valid": False, "error": "Configuration must contain 'plugins' key"}
- plugins = parsed_config['plugins']
+ plugins = parsed_config["plugins"]
if not isinstance(plugins, dict):
return {"valid": False, "error": "'plugins' must be a dictionary"}
# Validate each plugin
- valid_access_levels = ['transcript', 'conversation', 'memory']
- valid_trigger_types = ['wake_word', 'always', 'conditional']
+ valid_access_levels = ["transcript", "conversation", "memory"]
+ valid_trigger_types = ["wake_word", "always", "conditional"]
for plugin_id, plugin_config in plugins.items():
if not isinstance(plugin_config, dict):
- return {"valid": False, "error": f"Plugin '{plugin_id}' config must be a dictionary"}
+ return {
+ "valid": False,
+ "error": f"Plugin '{plugin_id}' config must be a dictionary",
+ }
# Check required fields
- if 'enabled' in plugin_config and not isinstance(plugin_config['enabled'], bool):
- return {"valid": False, "error": f"Plugin '{plugin_id}': 'enabled' must be boolean"}
+ if "enabled" in plugin_config and not isinstance(
+ plugin_config["enabled"], bool
+ ):
+ return {
+ "valid": False,
+ "error": f"Plugin '{plugin_id}': 'enabled' must be boolean",
+ }
- if 'access_level' in plugin_config and plugin_config['access_level'] not in valid_access_levels:
- return {"valid": False, "error": f"Plugin '{plugin_id}': invalid access_level (must be one of {valid_access_levels})"}
+ if (
+ "access_level" in plugin_config
+ and plugin_config["access_level"] not in valid_access_levels
+ ):
+ return {
+ "valid": False,
+ "error": f"Plugin '{plugin_id}': invalid access_level (must be one of {valid_access_levels})",
+ }
- if 'trigger' in plugin_config:
- trigger = plugin_config['trigger']
+ if "trigger" in plugin_config:
+ trigger = plugin_config["trigger"]
if not isinstance(trigger, dict):
- return {"valid": False, "error": f"Plugin '{plugin_id}': 'trigger' must be a dictionary"}
+ return {
+ "valid": False,
+ "error": f"Plugin '{plugin_id}': 'trigger' must be a dictionary",
+ }
- if 'type' in trigger and trigger['type'] not in valid_trigger_types:
- return {"valid": False, "error": f"Plugin '{plugin_id}': invalid trigger type (must be one of {valid_trigger_types})"}
+ if "type" in trigger and trigger["type"] not in valid_trigger_types:
+ return {
+ "valid": False,
+ "error": f"Plugin '{plugin_id}': invalid trigger type (must be one of {valid_trigger_types})",
+ }
return {"valid": True, "message": "Configuration is valid"}
@@ -1314,9 +1446,11 @@ async def reload_plugins_controller(app=None) -> dict:
return {
"success": reload_result.get("success", False),
- "message": "Plugins reloaded and worker restart signaled"
- if worker_signal_sent
- else "Plugins reloaded but worker restart signal failed",
+ "message": (
+ "Plugins reloaded and worker restart signaled"
+ if worker_signal_sent
+ else "Plugins reloaded but worker restart signal failed"
+ ),
"reload": reload_result,
"worker_signal_sent": worker_signal_sent,
}
@@ -1324,6 +1458,7 @@ async def reload_plugins_controller(app=None) -> dict:
# Structured Plugin Configuration Management Functions (Form-based UI)
+
async def get_plugins_metadata() -> dict:
"""Get plugin metadata for form-based configuration UI.
@@ -1350,30 +1485,28 @@ async def get_plugins_metadata() -> dict:
orchestration_configs = {}
if plugins_yml_path.exists():
- with open(plugins_yml_path, 'r') as f:
+ with open(plugins_yml_path, "r") as f:
plugins_data = _yaml.load(f) or {}
- orchestration_configs = plugins_data.get('plugins', {})
+ orchestration_configs = plugins_data.get("plugins", {})
# Build metadata for each plugin
plugins_metadata = []
for plugin_id, plugin_class in discovered_plugins.items():
# Get orchestration config (or empty dict if not configured)
- orchestration_config = orchestration_configs.get(plugin_id, {
- 'enabled': False,
- 'events': [],
- 'condition': {'type': 'always'}
- })
+ orchestration_config = orchestration_configs.get(
+ plugin_id,
+ {"enabled": False, "events": [], "condition": {"type": "always"}},
+ )
# Get complete metadata including schema
- metadata = get_plugin_metadata(plugin_id, plugin_class, orchestration_config)
+ metadata = get_plugin_metadata(
+ plugin_id, plugin_class, orchestration_config
+ )
plugins_metadata.append(metadata)
logger.info(f"Retrieved metadata for {len(plugins_metadata)} plugins")
- return {
- "plugins": plugins_metadata,
- "status": "success"
- }
+ return {"plugins": plugins_metadata, "status": "success"}
except Exception as e:
logger.exception("Error getting plugins metadata")
@@ -1396,7 +1529,10 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict:
Success message with list of updated files
"""
try:
- from advanced_omi_backend.services.plugin_service import _get_plugins_dir, discover_plugins
+ from advanced_omi_backend.services.plugin_service import (
+ _get_plugins_dir,
+ discover_plugins,
+ )
# Validate plugin exists
discovered_plugins = discover_plugins()
@@ -1406,84 +1542,87 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict:
updated_files = []
# 1. Update config/plugins.yml (orchestration)
- if 'orchestration' in config:
+ if "orchestration" in config:
plugins_yml_path = get_plugins_yml_path()
# Load current plugins.yml
if plugins_yml_path.exists():
- with open(plugins_yml_path, 'r') as f:
+ with open(plugins_yml_path, "r") as f:
plugins_data = _yaml.load(f) or {}
else:
plugins_data = {}
- if 'plugins' not in plugins_data:
- plugins_data['plugins'] = {}
+ if "plugins" not in plugins_data:
+ plugins_data["plugins"] = {}
# Update orchestration config
- orchestration = config['orchestration']
- plugins_data['plugins'][plugin_id] = {
- 'enabled': orchestration.get('enabled', False),
- 'events': orchestration.get('events', []),
- 'condition': orchestration.get('condition', {'type': 'always'})
+ orchestration = config["orchestration"]
+ plugins_data["plugins"][plugin_id] = {
+ "enabled": orchestration.get("enabled", False),
+ "events": orchestration.get("events", []),
+ "condition": orchestration.get("condition", {"type": "always"}),
}
# Create backup
if plugins_yml_path.exists():
- backup_path = str(plugins_yml_path) + '.backup'
+ backup_path = str(plugins_yml_path) + ".backup"
shutil.copy2(plugins_yml_path, backup_path)
# Create config directory if needed
plugins_yml_path.parent.mkdir(parents=True, exist_ok=True)
# Write updated plugins.yml
- with open(plugins_yml_path, 'w') as f:
+ with open(plugins_yml_path, "w") as f:
_yaml.dump(plugins_data, f)
updated_files.append(str(plugins_yml_path))
- logger.info(f"Updated orchestration config for '{plugin_id}' in {plugins_yml_path}")
+ logger.info(
+ f"Updated orchestration config for '{plugin_id}' in {plugins_yml_path}"
+ )
# 2. Update plugins/{plugin_id}/config.yml (settings with env var references)
- if 'settings' in config:
+ if "settings" in config:
plugins_dir = _get_plugins_dir()
plugin_config_path = plugins_dir / plugin_id / "config.yml"
# Load current config.yml
if plugin_config_path.exists():
- with open(plugin_config_path, 'r') as f:
+ with open(plugin_config_path, "r") as f:
plugin_config_data = _yaml.load(f) or {}
else:
plugin_config_data = {}
# Update settings (preserve ${ENV_VAR} references)
- settings = config['settings']
+ settings = config["settings"]
plugin_config_data.update(settings)
# Create backup
if plugin_config_path.exists():
- backup_path = str(plugin_config_path) + '.backup'
+ backup_path = str(plugin_config_path) + ".backup"
shutil.copy2(plugin_config_path, backup_path)
# Write updated config.yml
- with open(plugin_config_path, 'w') as f:
+ with open(plugin_config_path, "w") as f:
_yaml.dump(plugin_config_data, f)
updated_files.append(str(plugin_config_path))
logger.info(f"Updated settings for '{plugin_id}' in {plugin_config_path}")
# 3. Update per-plugin .env (only changed env vars)
- if 'env_vars' in config and config['env_vars']:
+ if "env_vars" in config and config["env_vars"]:
from advanced_omi_backend.services.plugin_service import save_plugin_env
# Filter out masked values (unchanged secrets)
changed_vars = {
- k: v for k, v in config['env_vars'].items()
-            if v != '••••••••••••'
+            k: v for k, v in config["env_vars"].items() if v != "••••••••••••"
}
if changed_vars:
env_path = save_plugin_env(plugin_id, changed_vars)
updated_files.append(str(env_path))
- logger.info(f"Saved {len(changed_vars)} env var(s) to per-plugin .env for '{plugin_id}'")
+ logger.info(
+ f"Saved {len(changed_vars)} env var(s) to per-plugin .env for '{plugin_id}'"
+ )
# Update os.environ so hot-reload picks up changes immediately
for k, v in changed_vars.items():
@@ -1496,7 +1635,9 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict:
except Exception as reload_err:
logger.warning(f"Auto-reload failed, manual restart needed: {reload_err}")
- message = f"Plugin '{plugin_id}' configuration updated and reloaded successfully."
+ message = (
+ f"Plugin '{plugin_id}' configuration updated and reloaded successfully."
+ )
if reload_result is None:
message = f"Plugin '{plugin_id}' configuration updated. Restart backend for changes to take effect."
@@ -1505,7 +1646,7 @@ async def update_plugin_config_structured(plugin_id: str, config: dict) -> dict:
"message": message,
"updated_files": updated_files,
"reload": reload_result,
- "status": "success"
+ "status": "success",
}
except Exception as e:
@@ -1541,29 +1682,29 @@ async def test_plugin_connection(plugin_id: str, config: dict) -> dict:
plugin_class = discovered_plugins[plugin_id]
# Check if plugin supports testing
- if not hasattr(plugin_class, 'test_connection'):
+ if not hasattr(plugin_class, "test_connection"):
return {
"success": False,
"message": f"Plugin '{plugin_id}' does not support connection testing",
- "status": "unsupported"
+ "status": "unsupported",
}
# Build complete config from provided data
test_config = {}
# Merge settings
- if 'settings' in config:
- test_config.update(config['settings'])
+ if "settings" in config:
+ test_config.update(config["settings"])
# Load per-plugin env for resolving masked values
plugin_env = load_plugin_env(plugin_id)
# Add env vars (expand any ${ENV_VAR} references with test values)
- if 'env_vars' in config:
- for key, value in config['env_vars'].items():
+ if "env_vars" in config:
+ for key, value in config["env_vars"].items():
# For masked values, resolve from per-plugin .env then os.environ
-                if value == '••••••••••••':
- value = plugin_env.get(key) or os.getenv(key, '')
+                if value == "••••••••••••":
+ value = plugin_env.get(key) or os.getenv(key, "")
test_config[key.lower()] = value
# Expand any remaining env var references
@@ -1572,7 +1713,9 @@ async def test_plugin_connection(plugin_id: str, config: dict) -> dict:
# Call plugin's test_connection static method
result = await plugin_class.test_connection(test_config)
- logger.info(f"Test connection for '{plugin_id}': {result.get('message', 'No message')}")
+ logger.info(
+ f"Test connection for '{plugin_id}': {result.get('message', 'No message')}"
+ )
return result
@@ -1581,12 +1724,13 @@ async def test_plugin_connection(plugin_id: str, config: dict) -> dict:
return {
"success": False,
"message": f"Connection test failed: {str(e)}",
- "status": "error"
+ "status": "error",
}
# Plugin Lifecycle Management Functions (create / write-code / delete)
+
def _snake_to_pascal(snake_str: str) -> str:
"""Convert snake_case to PascalCase."""
return "".join(word.capitalize() for word in snake_str.split("_"))
@@ -1615,25 +1759,40 @@ async def create_plugin(
Returns:
Success dict with plugin_id and created_files list
"""
- from advanced_omi_backend.services.plugin_service import _get_plugins_dir, discover_plugins
+ from advanced_omi_backend.services.plugin_service import (
+ _get_plugins_dir,
+ discover_plugins,
+ )
# Validate name
if not plugin_name.replace("_", "").isalnum():
- return {"success": False, "error": "Plugin name must be alphanumeric with underscores only"}
+ return {
+ "success": False,
+ "error": "Plugin name must be alphanumeric with underscores only",
+ }
if not re.match(r"^[a-z][a-z0-9_]*$", plugin_name):
- return {"success": False, "error": "Plugin name must be lowercase snake_case starting with a letter"}
+ return {
+ "success": False,
+ "error": "Plugin name must be lowercase snake_case starting with a letter",
+ }
plugins_dir = _get_plugins_dir()
plugin_dir = plugins_dir / plugin_name
# Collision check
if plugin_dir.exists():
- return {"success": False, "error": f"Plugin '{plugin_name}' already exists at {plugin_dir}"}
+ return {
+ "success": False,
+ "error": f"Plugin '{plugin_name}' already exists at {plugin_dir}",
+ }
discovered = discover_plugins()
if plugin_name in discovered:
- return {"success": False, "error": f"Plugin '{plugin_name}' is already registered"}
+ return {
+ "success": False,
+ "error": f"Plugin '{plugin_name}' is already registered",
+ }
class_name = _snake_to_pascal(plugin_name) + "Plugin"
created_files: list[str] = []
@@ -1650,8 +1809,14 @@ async def create_plugin(
(plugin_dir / "plugin.py").write_text(plugin_code, encoding="utf-8")
else:
# Write standard boilerplate
- events_str = ", ".join(f'"{e}"' for e in events) if events else '"conversation.complete"'
- boilerplate = inspect.cleandoc(f'''
+ events_str = (
+ ", ".join(f'"{e}"' for e in events)
+ if events
+ else '"conversation.complete"'
+ )
+ boilerplate = (
+ inspect.cleandoc(
+ f'''
"""
{class_name} implementation.
@@ -1688,7 +1853,10 @@ async def cleanup(self):
async def on_conversation_complete(self, context: PluginContext) -> Optional[PluginResult]:
logger.info(f"Processing conversation for user: {{context.user_id}}")
return PluginResult(success=True, message="OK")
- ''') + "\n"
+ '''
+ )
+ + "\n"
+ )
(plugin_dir / "plugin.py").write_text(boilerplate, encoding="utf-8")
created_files.append("plugin.py")
@@ -1699,7 +1867,7 @@ async def on_conversation_complete(self, context: PluginContext) -> Optional[Plu
# config.yml
config_yml = {"description": description}
- with open(plugin_dir / "config.yml", 'w', encoding="utf-8") as f:
+ with open(plugin_dir / "config.yml", "w", encoding="utf-8") as f:
_yaml.dump(config_yml, f)
created_files.append("config.yml")
@@ -1765,7 +1933,10 @@ async def write_plugin_code(
plugin_dir = plugins_dir / plugin_id
if not plugin_dir.exists():
- return {"success": False, "error": f"Plugin '{plugin_id}' not found at {plugin_dir}"}
+ return {
+ "success": False,
+ "error": f"Plugin '{plugin_id}' not found at {plugin_dir}",
+ }
updated_files: list[str] = []
@@ -1848,9 +2019,14 @@ async def delete_plugin(plugin_id: str, remove_files: bool = False) -> dict:
logger.info(f"Removed plugin directory: {plugin_dir}")
if not removed_from_yml and not files_removed:
- return {"success": False, "error": f"Plugin '{plugin_id}' not found in plugins.yml or on disk"}
+ return {
+ "success": False,
+ "error": f"Plugin '{plugin_id}' not found in plugins.yml or on disk",
+ }
- logger.info(f"Deleted plugin '{plugin_id}' (yml={removed_from_yml}, files={files_removed})")
+ logger.info(
+ f"Deleted plugin '{plugin_id}' (yml={removed_from_yml}, files={files_removed})"
+ )
return {
"success": True,
"plugin_id": plugin_id,
diff --git a/backends/advanced/src/advanced_omi_backend/controllers/user_controller.py b/backends/advanced/src/advanced_omi_backend/controllers/user_controller.py
index ce801327..a0c10e4f 100644
--- a/backends/advanced/src/advanced_omi_backend/controllers/user_controller.py
+++ b/backends/advanced/src/advanced_omi_backend/controllers/user_controller.py
@@ -9,11 +9,7 @@
from fastapi import HTTPException
from fastapi.responses import JSONResponse
-from advanced_omi_backend.auth import (
- ADMIN_EMAIL,
- UserManager,
- get_user_db,
-)
+from advanced_omi_backend.auth import ADMIN_EMAIL, UserManager, get_user_db
from advanced_omi_backend.client_manager import get_user_clients_all
from advanced_omi_backend.database import db, users_col
from advanced_omi_backend.models.conversation import Conversation
@@ -50,7 +46,9 @@ async def create_user(user_data: UserCreate):
# If we get here, user exists
return JSONResponse(
status_code=409,
- content={"message": f"User with email {user_data.email} already exists"},
+ content={
+ "message": f"User with email {user_data.email} already exists"
+ },
)
except Exception:
# User doesn't exist, continue with creation
@@ -61,15 +59,17 @@ async def create_user(user_data: UserCreate):
# Return the full user object (serialized via UserRead schema)
from advanced_omi_backend.models.user import UserRead
+
user_read = UserRead.model_validate(user)
return JSONResponse(
status_code=201,
- content=user_read.model_dump(mode='json'),
+ content=user_read.model_dump(mode="json"),
)
except Exception as e:
import traceback
+
error_details = traceback.format_exc()
logger.error(f"Error creating user: {e}")
logger.error(f"Full traceback: {error_details}")
@@ -89,15 +89,16 @@ async def update_user(user_id: str, user_data: UserUpdate):
logger.error(f"Invalid ObjectId format for user_id {user_id}: {e}")
return JSONResponse(
status_code=400,
- content={"message": f"Invalid user_id format: {user_id}. Must be a valid ObjectId."},
+ content={
+ "message": f"Invalid user_id format: {user_id}. Must be a valid ObjectId."
+ },
)
# Check if user exists
existing_user = await users_col.find_one({"_id": object_id})
if not existing_user:
return JSONResponse(
- status_code=404,
- content={"message": f"User {user_id} not found"}
+ status_code=404, content={"message": f"User {user_id} not found"}
)
# Get user database and create user manager
@@ -114,15 +115,17 @@ async def update_user(user_id: str, user_data: UserUpdate):
# Return the full user object (serialized via UserRead schema)
from advanced_omi_backend.models.user import UserRead
+
user_read = UserRead.model_validate(updated_user)
return JSONResponse(
status_code=200,
- content=user_read.model_dump(mode='json'),
+ content=user_read.model_dump(mode="json"),
)
except Exception as e:
import traceback
+
error_details = traceback.format_exc()
logger.error(f"Error updating user: {e}")
logger.error(f"Full traceback: {error_details}")
@@ -154,7 +157,9 @@ async def delete_user(
# Check if user exists
existing_user = await users_col.find_one({"_id": object_id})
if not existing_user:
- return JSONResponse(status_code=404, content={"message": f"User {user_id} not found"})
+ return JSONResponse(
+ status_code=404, content={"message": f"User {user_id} not found"}
+ )
# Prevent deletion of administrator user
user_email = existing_user.get("email", "")
@@ -176,7 +181,9 @@ async def delete_user(
if delete_conversations:
# Delete all conversations for this user
- conversations_result = await Conversation.find(Conversation.user_id == user_id).delete()
+ conversations_result = await Conversation.find(
+ Conversation.user_id == user_id
+ ).delete()
deleted_data["conversations_deleted"] = conversations_result.deleted_count
if delete_memories:
@@ -196,7 +203,9 @@ async def delete_user(
message = f"User {user_id} deleted successfully"
deleted_items = []
if delete_conversations and deleted_data.get("conversations_deleted", 0) > 0:
- deleted_items.append(f"{deleted_data['conversations_deleted']} conversations")
+ deleted_items.append(
+ f"{deleted_data['conversations_deleted']} conversations"
+ )
if delete_memories and deleted_data.get("memories_deleted", 0) > 0:
deleted_items.append(f"{deleted_data['memories_deleted']} memories")
diff --git a/backends/advanced/src/advanced_omi_backend/cron_scheduler.py b/backends/advanced/src/advanced_omi_backend/cron_scheduler.py
index a496516f..3e3f6b8b 100644
--- a/backends/advanced/src/advanced_omi_backend/cron_scheduler.py
+++ b/backends/advanced/src/advanced_omi_backend/cron_scheduler.py
@@ -32,6 +32,7 @@
# Data classes
# ---------------------------------------------------------------------------
+
@dataclass
class CronJobConfig:
job_id: str
@@ -66,6 +67,7 @@ def _get_job_func(job_id: str) -> Optional[JobFunc]:
# Scheduler
# ---------------------------------------------------------------------------
+
class CronScheduler:
def __init__(self) -> None:
self.jobs: Dict[str, CronJobConfig] = {}
@@ -79,6 +81,7 @@ def __init__(self) -> None:
async def start(self) -> None:
"""Load config, restore state from Redis, and start the scheduler loop."""
import os
+
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
self._redis = aioredis.from_url(redis_url, decode_responses=True)
@@ -129,7 +132,9 @@ async def update_job(
if not croniter.is_valid(schedule):
raise ValueError(f"Invalid cron expression: {schedule}")
cfg.schedule = schedule
- cfg.next_run = croniter(schedule, datetime.now(timezone.utc)).get_next(datetime)
+ cfg.next_run = croniter(schedule, datetime.now(timezone.utc)).get_next(
+ datetime
+ )
if enabled is not None:
cfg.enabled = enabled
@@ -137,7 +142,11 @@ async def update_job(
# Persist changes to config.yml
save_config_section(
f"cron_jobs.{job_id}",
- {"enabled": cfg.enabled, "schedule": cfg.schedule, "description": cfg.description},
+ {
+ "enabled": cfg.enabled,
+ "schedule": cfg.schedule,
+ "description": cfg.description,
+ },
)
# Update next_run in Redis
@@ -147,22 +156,26 @@ async def update_job(
cfg.next_run.isoformat(),
)
- logger.info(f"Updated cron job '{job_id}': enabled={cfg.enabled}, schedule={cfg.schedule}")
+ logger.info(
+ f"Updated cron job '{job_id}': enabled={cfg.enabled}, schedule={cfg.schedule}"
+ )
async def get_all_jobs_status(self) -> List[dict]:
"""Return status of all registered cron jobs."""
result = []
for job_id, cfg in self.jobs.items():
- result.append({
- "job_id": job_id,
- "enabled": cfg.enabled,
- "schedule": cfg.schedule,
- "description": cfg.description,
- "last_run": cfg.last_run.isoformat() if cfg.last_run else None,
- "next_run": cfg.next_run.isoformat() if cfg.next_run else None,
- "running": cfg.running,
- "last_error": cfg.last_error,
- })
+ result.append(
+ {
+ "job_id": job_id,
+ "enabled": cfg.enabled,
+ "schedule": cfg.schedule,
+ "description": cfg.description,
+ "last_run": cfg.last_run.isoformat() if cfg.last_run else None,
+ "next_run": cfg.next_run.isoformat() if cfg.next_run else None,
+ "running": cfg.running,
+ "last_error": cfg.last_error,
+ }
+ )
return result
# -- internals -----------------------------------------------------------
@@ -175,7 +188,9 @@ def _load_jobs_from_config(self) -> None:
for job_id, job_cfg in cron_section.items():
schedule = str(job_cfg.get("schedule", "0 * * * *"))
if not croniter.is_valid(schedule):
-            logger.warning(f"Invalid cron expression for job '{job_id}': {schedule} - skipping")
+ logger.warning(
+                f"Invalid cron expression for job '{job_id}': {schedule} - skipping"
+ )
continue
now = datetime.now(timezone.utc)
self.jobs[job_id] = CronJobConfig(
diff --git a/backends/advanced/src/advanced_omi_backend/database.py b/backends/advanced/src/advanced_omi_backend/database.py
index 1b214b6d..392b22c8 100644
--- a/backends/advanced/src/advanced_omi_backend/database.py
+++ b/backends/advanced/src/advanced_omi_backend/database.py
@@ -44,5 +44,3 @@ def get_collections():
return {
"users_col": users_col,
}
-
-
diff --git a/backends/advanced/src/advanced_omi_backend/llm_client.py b/backends/advanced/src/advanced_omi_backend/llm_client.py
index 96ccc77b..4aaa223f 100644
--- a/backends/advanced/src/advanced_omi_backend/llm_client.py
+++ b/backends/advanced/src/advanced_omi_backend/llm_client.py
@@ -11,7 +11,10 @@
from typing import Any, Dict, Optional
from advanced_omi_backend.model_registry import get_models_registry
-from advanced_omi_backend.openai_factory import create_openai_client, is_langfuse_enabled
+from advanced_omi_backend.openai_factory import (
+ create_openai_client,
+ is_langfuse_enabled,
+)
from advanced_omi_backend.services.memory.config import (
load_config_yml as _load_root_config,
)
@@ -62,7 +65,9 @@ def __init__(
self.base_url = base_url
self.model = model
if not self.api_key or not self.base_url or not self.model:
- raise ValueError(f"LLM configuration incomplete: api_key={'set' if self.api_key else 'MISSING'}, base_url={'set' if self.base_url else 'MISSING'}, model={'set' if self.model else 'MISSING'}")
+ raise ValueError(
+ f"LLM configuration incomplete: api_key={'set' if self.api_key else 'MISSING'}, base_url={'set' if self.base_url else 'MISSING'}, model={'set' if self.model else 'MISSING'}"
+ )
# Initialize OpenAI client with optional Langfuse tracing
try:
@@ -71,14 +76,19 @@ def __init__(
)
self.logger.info(f"OpenAI client initialized, base_url: {self.base_url}")
except ImportError:
- self.logger.error("OpenAI library not installed. Install with: pip install openai")
+ self.logger.error(
+ "OpenAI library not installed. Install with: pip install openai"
+ )
raise
except Exception as e:
self.logger.error(f"Failed to initialize OpenAI client: {e}")
raise
def generate(
- self, prompt: str, model: str | None = None, temperature: float | None = None,
+ self,
+ prompt: str,
+ model: str | None = None,
+ temperature: float | None = None,
**langfuse_kwargs,
) -> str:
"""Generate text completion using OpenAI-compatible API."""
@@ -101,8 +111,12 @@ def generate(
raise
def chat_with_tools(
- self, messages: list, tools: list | None = None, model: str | None = None,
- temperature: float | None = None, **langfuse_kwargs,
+ self,
+ messages: list,
+ tools: list | None = None,
+ model: str | None = None,
+ temperature: float | None = None,
+ **langfuse_kwargs,
):
"""Chat completion with tool/function calling support. Returns raw response object."""
model_name = model or self.model
@@ -127,14 +141,18 @@ def health_check(self) -> Dict:
"status": "β
Connected",
"base_url": self.base_url,
"default_model": self.model,
- "api_key_configured": bool(self.api_key and self.api_key != "dummy"),
+ "api_key_configured": bool(
+ self.api_key and self.api_key != "dummy"
+ ),
}
else:
return {
"status": "β οΈ Configuration incomplete",
"base_url": self.base_url,
"default_model": self.model,
- "api_key_configured": bool(self.api_key and self.api_key != "dummy"),
+ "api_key_configured": bool(
+ self.api_key and self.api_key != "dummy"
+ ),
}
except Exception as e:
self.logger.error(f"Health check failed: {e}")
@@ -157,11 +175,13 @@ class LLMClientFactory:
def create_client() -> LLMClient:
"""Create an LLM client based on model registry configuration (config.yml)."""
registry = get_models_registry()
-
+
if registry:
llm_def = registry.get_default("llm")
if llm_def:
- logger.info(f"Creating LLM client from registry: {llm_def.name} ({llm_def.model_provider})")
+ logger.info(
+ f"Creating LLM client from registry: {llm_def.name} ({llm_def.model_provider})"
+ )
params = llm_def.model_params or {}
return OpenAILLMClient(
api_key=llm_def.api_key,
@@ -169,7 +189,7 @@ def create_client() -> LLMClient:
model=llm_def.model_name,
temperature=params.get("temperature", 0.1),
)
-
+
raise ValueError("No default LLM defined in config.yml")
@staticmethod
diff --git a/backends/advanced/src/advanced_omi_backend/main.py b/backends/advanced/src/advanced_omi_backend/main.py
index ee60696f..5b0f3f05 100644
--- a/backends/advanced/src/advanced_omi_backend/main.py
+++ b/backends/advanced/src/advanced_omi_backend/main.py
@@ -46,5 +46,5 @@
port=port,
reload=False, # Set to True for development
access_log=False, # Disabled - using custom RequestLoggingMiddleware instead
- log_level="info"
+ log_level="info",
)
diff --git a/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py b/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py
index 069d5239..21fa98d9 100644
--- a/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py
+++ b/backends/advanced/src/advanced_omi_backend/middleware/app_middleware.py
@@ -26,7 +26,9 @@ def setup_cors_middleware(app: FastAPI) -> None:
config = get_app_config()
logger.info(f"π CORS configured with origins: {config.allowed_origins}")
- logger.info(f"π CORS also allows Tailscale IPs via regex: {config.tailscale_regex}")
+ logger.info(
+ f"π CORS also allows Tailscale IPs via regex: {config.tailscale_regex}"
+ )
app.add_middleware(
CORSMiddleware,
@@ -157,6 +159,7 @@ async def dispatch(self, request: Request, call_next):
# Recreate response with the body we consumed
from starlette.responses import Response
+
return Response(
content=response_body,
status_code=response.status_code,
@@ -191,8 +194,8 @@ async def database_exception_handler(request: Request, exc: Exception):
content={
"detail": "Unable to connect to server. Please check your connection and try again.",
"error_type": "connection_failure",
- "error_category": "database"
- }
+ "error_category": "database",
+ },
)
@app.exception_handler(ConnectionError)
@@ -204,8 +207,8 @@ async def connection_exception_handler(request: Request, exc: ConnectionError):
content={
"detail": "Unable to connect to server. Please check your connection and try again.",
"error_type": "connection_failure",
- "error_category": "network"
- }
+ "error_category": "network",
+ },
)
@app.exception_handler(HTTPException)
@@ -218,15 +221,12 @@ async def http_exception_handler(request: Request, exc: HTTPException):
content={
"detail": exc.detail,
"error_type": "authentication_failure",
- "error_category": "security"
- }
+ "error_category": "security",
+ },
)
# For other HTTP exceptions, return as normal
- return JSONResponse(
- status_code=exc.status_code,
- content={"detail": exc.detail}
- )
+ return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
def setup_middleware(app: FastAPI) -> None:
@@ -236,4 +236,4 @@ def setup_middleware(app: FastAPI) -> None:
logger.info("π Request logging middleware enabled")
setup_cors_middleware(app)
- setup_exception_handlers(app)
\ No newline at end of file
+ setup_exception_handlers(app)
diff --git a/backends/advanced/src/advanced_omi_backend/model_registry.py b/backends/advanced/src/advanced_omi_backend/model_registry.py
index bc7e5fc5..e7a6ddf5 100644
--- a/backends/advanced/src/advanced_omi_backend/model_registry.py
+++ b/backends/advanced/src/advanced_omi_backend/model_registry.py
@@ -30,76 +30,95 @@
class ModelDef(BaseModel):
"""Model definition with validation.
-
+
Represents a single model configuration (LLM, embedding, STT, TTS, etc.)
from config.yml with automatic validation and type checking.
"""
-
+
model_config = ConfigDict(
- extra='allow', # Allow extra fields for extensibility
+ extra="allow", # Allow extra fields for extensibility
validate_assignment=True, # Validate on attribute assignment
arbitrary_types_allowed=True,
)
-
+
name: str = Field(..., min_length=1, description="Unique model identifier")
- model_type: str = Field(..., description="Model type: llm, embedding, stt, tts, etc.")
- model_provider: str = Field(default="unknown", description="Provider name: openai, ollama, deepgram, parakeet, vibevoice, etc.")
- api_family: str = Field(default="openai", description="API family: openai, http, websocket, etc.")
+ model_type: str = Field(
+ ..., description="Model type: llm, embedding, stt, tts, etc."
+ )
+ model_provider: str = Field(
+ default="unknown",
+ description="Provider name: openai, ollama, deepgram, parakeet, vibevoice, etc.",
+ )
+ api_family: str = Field(
+ default="openai", description="API family: openai, http, websocket, etc."
+ )
model_name: str = Field(default="", description="Provider-specific model name")
model_url: str = Field(default="", description="Base URL for API requests")
- api_key: Optional[str] = Field(default=None, description="API key or authentication token")
- description: Optional[str] = Field(default=None, description="Human-readable description")
- model_params: Dict[str, Any] = Field(default_factory=dict, description="Model-specific parameters")
- model_output: Optional[str] = Field(default=None, description="Output format: json, text, vector, etc.")
- embedding_dimensions: Optional[int] = Field(default=None, ge=1, description="Embedding vector dimensions")
- operations: Dict[str, Any] = Field(default_factory=dict, description="API operation definitions")
+ api_key: Optional[str] = Field(
+ default=None, description="API key or authentication token"
+ )
+ description: Optional[str] = Field(
+ default=None, description="Human-readable description"
+ )
+ model_params: Dict[str, Any] = Field(
+ default_factory=dict, description="Model-specific parameters"
+ )
+ model_output: Optional[str] = Field(
+ default=None, description="Output format: json, text, vector, etc."
+ )
+ embedding_dimensions: Optional[int] = Field(
+ default=None, ge=1, description="Embedding vector dimensions"
+ )
+ operations: Dict[str, Any] = Field(
+ default_factory=dict, description="API operation definitions"
+ )
capabilities: List[str] = Field(
default_factory=list,
- description="Provider capabilities: word_timestamps, segments, diarization (for STT providers)"
+ description="Provider capabilities: word_timestamps, segments, diarization (for STT providers)",
)
-
- @field_validator('model_name', mode='before')
+
+ @field_validator("model_name", mode="before")
@classmethod
def default_model_name(cls, v: Any, info) -> str:
"""Default model_name to name if not provided."""
- if not v and info.data.get('name'):
- return info.data['name']
+ if not v and info.data.get("name"):
+ return info.data["name"]
return v or ""
-
- @field_validator('model_url', mode='before')
+
+ @field_validator("model_url", mode="before")
@classmethod
def validate_url(cls, v: Any) -> str:
"""Ensure URL doesn't have trailing whitespace."""
if isinstance(v, str):
return v.strip()
return v or ""
-
- @field_validator('api_key', mode='before')
+
+ @field_validator("api_key", mode="before")
@classmethod
def sanitize_api_key(cls, v: Any) -> Optional[str]:
"""Sanitize API key, treat empty strings as None."""
if isinstance(v, str):
v = v.strip()
- if not v or v.lower() in ['dummy', 'none', 'null']:
+ if not v or v.lower() in ["dummy", "none", "null"]:
return None
return v
return v
-
- @model_validator(mode='after')
+
+ @model_validator(mode="after")
def validate_model(self) -> ModelDef:
"""Cross-field validation."""
# Ensure embedding models have dimensions specified
- if self.model_type == 'embedding' and not self.embedding_dimensions:
+ if self.model_type == "embedding" and not self.embedding_dimensions:
# Common defaults
defaults = {
- 'text-embedding-3-small': 1536,
- 'text-embedding-3-large': 3072,
- 'text-embedding-ada-002': 1536,
- 'nomic-embed-text-v1.5': 768,
+ "text-embedding-3-small": 1536,
+ "text-embedding-3-large": 3072,
+ "text-embedding-ada-002": 1536,
+ "nomic-embed-text-v1.5": 768,
}
if self.model_name in defaults:
self.embedding_dimensions = defaults[self.model_name]
-
+
return self
@@ -180,52 +199,49 @@ class AppModels(BaseModel):
"""
model_config = ConfigDict(
- extra='allow',
+ extra="allow",
validate_assignment=True,
)
defaults: Dict[str, str] = Field(
- default_factory=dict,
- description="Default model names for each model_type"
+ default_factory=dict, description="Default model names for each model_type"
)
models: Dict[str, ModelDef] = Field(
default_factory=dict,
- description="All available model definitions keyed by name"
+ description="All available model definitions keyed by name",
)
memory: Dict[str, Any] = Field(
- default_factory=dict,
- description="Memory service configuration"
+ default_factory=dict, description="Memory service configuration"
)
speaker_recognition: Dict[str, Any] = Field(
- default_factory=dict,
- description="Speaker recognition service configuration"
+ default_factory=dict, description="Speaker recognition service configuration"
)
chat: Dict[str, Any] = Field(
default_factory=dict,
- description="Chat service configuration including system prompt"
+ description="Chat service configuration including system prompt",
)
llm_operations: Dict[str, LLMOperationConfig] = Field(
default_factory=dict,
- description="Per-operation LLM configuration (temperature, model override, etc.)"
+ description="Per-operation LLM configuration (temperature, model override, etc.)",
)
-
+
def get_by_name(self, name: str) -> Optional[ModelDef]:
"""Get a model by its unique name.
-
+
Args:
name: Model name to look up
-
+
Returns:
ModelDef if found, None otherwise
"""
return self.models.get(name)
-
+
def get_default(self, model_type: str) -> Optional[ModelDef]:
"""Get the default model for a given type.
-
+
Args:
model_type: Type of model (llm, embedding, stt, tts, etc.)
-
+
Returns:
Default ModelDef for the type, or first available model of that type,
or None if no models of that type exist
@@ -236,25 +252,25 @@ def get_default(self, model_type: str) -> Optional[ModelDef]:
model = self.get_by_name(name)
if model:
return model
-
+
# Fallback: first model of that type
for m in self.models.values():
if m.model_type == model_type:
return m
-
+
return None
-
+
def get_all_by_type(self, model_type: str) -> List[ModelDef]:
"""Get all models of a specific type.
-
+
Args:
model_type: Type of model to filter by
-
+
Returns:
List of ModelDef objects matching the type
"""
return [m for m in self.models.values() if m.model_type == model_type]
-
+
def list_model_types(self) -> List[str]:
"""Get all unique model types in the registry.
@@ -340,6 +356,7 @@ def _find_config_path() -> Path:
Path to config.yml
"""
from advanced_omi_backend.config import get_config_yml_path
+
return get_config_yml_path()
@@ -413,13 +430,13 @@ def load_models_config(force_reload: bool = False) -> Optional[AppModels]:
def get_models_registry() -> Optional[AppModels]:
"""Get the global models registry.
-
+
This is the primary interface for accessing model configurations.
The registry is loaded once and cached for performance.
-
+
Returns:
AppModels instance, or None if config.yml not found
-
+
Example:
>>> registry = get_models_registry()
>>> if registry:
diff --git a/backends/advanced/src/advanced_omi_backend/models/__init__.py b/backends/advanced/src/advanced_omi_backend/models/__init__.py
index a19fa0db..4dbedc18 100644
--- a/backends/advanced/src/advanced_omi_backend/models/__init__.py
+++ b/backends/advanced/src/advanced_omi_backend/models/__init__.py
@@ -7,4 +7,4 @@
# Models can be imported directly from their files
# e.g. from .job import TranscriptionJob
-# e.g. from .conversation import Conversation, create_conversation
\ No newline at end of file
+# e.g. from .conversation import Conversation, create_conversation
diff --git a/backends/advanced/src/advanced_omi_backend/models/audio_chunk.py b/backends/advanced/src/advanced_omi_backend/models/audio_chunk.py
index 5f3b4c1d..d8ed0125 100644
--- a/backends/advanced/src/advanced_omi_backend/models/audio_chunk.py
+++ b/backends/advanced/src/advanced_omi_backend/models/audio_chunk.py
@@ -41,10 +41,7 @@ class AudioChunkDocument(Document):
conversation_id: Indexed(str) = Field(
description="Parent conversation ID (UUID format)"
)
- chunk_index: int = Field(
- description="Sequential chunk number (0-based)",
- ge=0
- )
+ chunk_index: int = Field(description="Sequential chunk number (0-based)", ge=0)
# Audio data
audio_data: bytes = Field(
@@ -53,61 +50,48 @@ class AudioChunkDocument(Document):
# Size tracking
original_size: int = Field(
- description="Original PCM size in bytes (before compression)",
- gt=0
+ description="Original PCM size in bytes (before compression)", gt=0
)
compressed_size: int = Field(
- description="Opus-encoded size in bytes (after compression)",
- gt=0
+ description="Opus-encoded size in bytes (after compression)", gt=0
)
# Time boundaries
start_time: float = Field(
- description="Start time in seconds from conversation start",
- ge=0.0
+ description="Start time in seconds from conversation start", ge=0.0
)
end_time: float = Field(
- description="End time in seconds from conversation start",
- gt=0.0
+ description="End time in seconds from conversation start", gt=0.0
)
duration: float = Field(
- description="Chunk duration in seconds (typically 10.0)",
- gt=0.0
+ description="Chunk duration in seconds (typically 10.0)", gt=0.0
)
# Audio format
- sample_rate: int = Field(
- default=16000,
- description="Original PCM sample rate (Hz)"
- )
+ sample_rate: int = Field(default=16000, description="Original PCM sample rate (Hz)")
channels: int = Field(
- default=1,
- description="Number of audio channels (1=mono, 2=stereo)"
+ default=1, description="Number of audio channels (1=mono, 2=stereo)"
)
# Optional analysis
has_speech: Optional[bool] = Field(
- default=None,
- description="Voice Activity Detection result (if available)"
+ default=None, description="Voice Activity Detection result (if available)"
)
# Metadata
created_at: datetime = Field(
- default_factory=datetime.utcnow,
- description="Chunk creation timestamp"
+ default_factory=datetime.utcnow, description="Chunk creation timestamp"
)
# Soft delete fields
deleted: bool = Field(
- default=False,
- description="Whether this chunk was soft-deleted"
+ default=False, description="Whether this chunk was soft-deleted"
)
deleted_at: Optional[datetime] = Field(
- default=None,
- description="When the chunk was marked as deleted"
+ default=None, description="When the chunk was marked as deleted"
)
- @field_serializer('audio_data')
+ @field_serializer("audio_data")
def serialize_audio_data(self, v: bytes) -> Binary:
"""
Convert bytes to BSON Binary for MongoDB storage.
@@ -121,20 +105,18 @@ def serialize_audio_data(self, v: bytes) -> Binary:
class Settings:
"""Beanie document settings."""
+
name = "audio_chunks"
indexes = [
# Primary query: Retrieve chunks in order for a conversation
[("conversation_id", 1), ("chunk_index", 1)],
-
# Conversation lookup and counting
"conversation_id",
-
# Maintenance queries (cleanup, monitoring)
"created_at",
-
# Soft delete filtering
- "deleted"
+ "deleted",
]
@property
diff --git a/backends/advanced/src/advanced_omi_backend/models/conversation.py b/backends/advanced/src/advanced_omi_backend/models/conversation.py
index 79b6d798..62d0286e 100644
--- a/backends/advanced/src/advanced_omi_backend/models/conversation.py
+++ b/backends/advanced/src/advanced_omi_backend/models/conversation.py
@@ -22,179 +22,238 @@ class Conversation(Document):
class MemoryProvider(str, Enum):
"""Supported memory providers."""
+
CHRONICLE = "chronicle"
OPENMEMORY_MCP = "openmemory_mcp"
FRIEND_LITE = "friend_lite" # Legacy value
class ConversationStatus(str, Enum):
"""Conversation processing status."""
+
ACTIVE = "active" # Has running jobs or open websocket
COMPLETED = "completed" # All jobs succeeded
FAILED = "failed" # One or more jobs failed
class EndReason(str, Enum):
"""Reason for conversation ending."""
+
USER_STOPPED = "user_stopped" # User manually stopped recording
- INACTIVITY_TIMEOUT = "inactivity_timeout" # No speech detected for threshold period
- WEBSOCKET_DISCONNECT = "websocket_disconnect" # Connection lost (Bluetooth, network, etc.)
+ INACTIVITY_TIMEOUT = (
+ "inactivity_timeout" # No speech detected for threshold period
+ )
+ WEBSOCKET_DISCONNECT = (
+ "websocket_disconnect" # Connection lost (Bluetooth, network, etc.)
+ )
MAX_DURATION = "max_duration" # Hit maximum conversation duration
- CLOSE_REQUESTED = "close_requested" # External close signal (API, plugin, button)
+ CLOSE_REQUESTED = (
+ "close_requested" # External close signal (API, plugin, button)
+ )
ERROR = "error" # Processing error forced conversation end
UNKNOWN = "unknown" # Unknown or legacy reason
# Nested Models
class Word(BaseModel):
"""Individual word with timestamp in a transcript."""
+
word: str = Field(description="Word text")
start: float = Field(description="Start time in seconds")
end: float = Field(description="End time in seconds")
confidence: Optional[float] = Field(None, description="Confidence score (0-1)")
speaker: Optional[int] = Field(None, description="Speaker ID from diarization")
- speaker_confidence: Optional[float] = Field(None, description="Speaker diarization confidence")
+ speaker_confidence: Optional[float] = Field(
+ None, description="Speaker diarization confidence"
+ )
class SegmentType(str, Enum):
"""Type of transcript segment."""
+
SPEECH = "speech"
- EVENT = "event" # Non-speech: [laughter], [music], etc.
- NOTE = "note" # User-inserted annotation/tag
+ EVENT = "event" # Non-speech: [laughter], [music], etc.
+ NOTE = "note" # User-inserted annotation/tag
class SpeakerSegment(BaseModel):
"""Individual speaker segment in a transcript."""
+
start: float = Field(description="Start time in seconds")
end: float = Field(description="End time in seconds")
text: str = Field(description="Transcript text for this segment")
speaker: str = Field(description="Speaker identifier")
segment_type: str = Field(
default="speech",
- description="Type: speech, event (non-speech from ASR), or note (user-inserted)"
+ description="Type: speech, event (non-speech from ASR), or note (user-inserted)",
+ )
+ identified_as: Optional[str] = Field(
+ None,
+ description="Speaker name from speaker recognition (None if not identified)",
)
- identified_as: Optional[str] = Field(None, description="Speaker name from speaker recognition (None if not identified)")
confidence: Optional[float] = Field(None, description="Confidence score (0-1)")
- words: List["Conversation.Word"] = Field(default_factory=list, description="Word-level timestamps for this segment")
+ words: List["Conversation.Word"] = Field(
+ default_factory=list, description="Word-level timestamps for this segment"
+ )
class TranscriptVersion(BaseModel):
"""Version of a transcript with processing metadata."""
+
version_id: str = Field(description="Unique version identifier")
transcript: Optional[str] = Field(None, description="Full transcript text")
words: List["Conversation.Word"] = Field(
default_factory=list,
- description="Word-level timestamps for entire transcript"
+ description="Word-level timestamps for entire transcript",
)
segments: List["Conversation.SpeakerSegment"] = Field(
default_factory=list,
- description="Speaker segments (filled by speaker recognition)"
+ description="Speaker segments (filled by speaker recognition)",
+ )
+ provider: Optional[str] = Field(
+ None,
+ description="Transcription provider used (deepgram, parakeet, vibevoice, etc.)",
+ )
+ model: Optional[str] = Field(
+ None, description="Model used (e.g., nova-3, parakeet)"
)
- provider: Optional[str] = Field(None, description="Transcription provider used (deepgram, parakeet, vibevoice, etc.)")
- model: Optional[str] = Field(None, description="Model used (e.g., nova-3, parakeet)")
created_at: datetime = Field(description="When this version was created")
- processing_time_seconds: Optional[float] = Field(None, description="Time taken to process")
+ processing_time_seconds: Optional[float] = Field(
+ None, description="Time taken to process"
+ )
diarization_source: Optional[str] = Field(
None,
- description="Source of speaker diarization: 'provider' (transcription service), 'pyannote' (speaker recognition), or None"
+ description="Source of speaker diarization: 'provider' (transcription service), 'pyannote' (speaker recognition), or None",
+ )
+ metadata: Dict[str, Any] = Field(
+ default_factory=dict, description="Additional provider-specific metadata"
)
- metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional provider-specific metadata")
class MemoryVersion(BaseModel):
"""Version of memory extraction with processing metadata."""
+
version_id: str = Field(description="Unique version identifier")
memory_count: int = Field(description="Number of memories extracted")
- transcript_version_id: str = Field(description="Which transcript version was used")
- provider: "Conversation.MemoryProvider" = Field(description="Memory provider used")
- model: Optional[str] = Field(None, description="Model used (e.g., gpt-4o-mini, llama3)")
+ transcript_version_id: str = Field(
+ description="Which transcript version was used"
+ )
+ provider: "Conversation.MemoryProvider" = Field(
+ description="Memory provider used"
+ )
+ model: Optional[str] = Field(
+ None, description="Model used (e.g., gpt-4o-mini, llama3)"
+ )
created_at: datetime = Field(description="When this version was created")
- processing_time_seconds: Optional[float] = Field(None, description="Time taken to process")
- metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional provider-specific metadata")
+ processing_time_seconds: Optional[float] = Field(
+ None, description="Time taken to process"
+ )
+ metadata: Dict[str, Any] = Field(
+ default_factory=dict, description="Additional provider-specific metadata"
+ )
# Core identifiers
- conversation_id: Indexed(str, unique=True) = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique conversation identifier")
+ conversation_id: Indexed(str, unique=True) = Field(
+ default_factory=lambda: str(uuid.uuid4()),
+ description="Unique conversation identifier",
+ )
user_id: Indexed(str) = Field(description="User who owns this conversation")
client_id: Indexed(str) = Field(description="Client device identifier")
# External file tracking (for deduplication of imported files)
external_source_id: Optional[str] = Field(
None,
- description="External file identifier (e.g., Google Drive file_id) for deduplication"
+ description="External file identifier (e.g., Google Drive file_id) for deduplication",
)
external_source_type: Optional[str] = Field(
- None,
- description="Type of external source (gdrive, dropbox, s3, etc.)"
+ None, description="Type of external source (gdrive, dropbox, s3, etc.)"
)
# MongoDB chunk-based audio storage (new system)
audio_chunks_count: Optional[int] = Field(
- None,
- description="Total number of 10-second audio chunks stored in MongoDB"
+ None, description="Total number of 10-second audio chunks stored in MongoDB"
)
audio_total_duration: Optional[float] = Field(
- None,
- description="Total audio duration in seconds (sum of all chunks)"
+ None, description="Total audio duration in seconds (sum of all chunks)"
)
audio_compression_ratio: Optional[float] = Field(
None,
- description="Compression ratio (compressed_size / original_size), typically ~0.047 for Opus"
+ description="Compression ratio (compressed_size / original_size), typically ~0.047 for Opus",
)
# Markers (e.g., button events) captured during the session
markers: List[Dict[str, Any]] = Field(
default_factory=list,
- description="Markers captured during audio session (button events, bookmarks, etc.)"
+ description="Markers captured during audio session (button events, bookmarks, etc.)",
)
# Creation metadata
- created_at: Indexed(datetime) = Field(default_factory=datetime.utcnow, description="When the conversation was created")
+ created_at: Indexed(datetime) = Field(
+ default_factory=datetime.utcnow, description="When the conversation was created"
+ )
# Processing status tracking
- deleted: bool = Field(False, description="Whether this conversation was deleted due to processing failure")
- deletion_reason: Optional[str] = Field(None, description="Reason for deletion (no_meaningful_speech, audio_file_not_ready, etc.)")
- deleted_at: Optional[datetime] = Field(None, description="When the conversation was marked as deleted")
+ deleted: bool = Field(
+ False,
+ description="Whether this conversation was deleted due to processing failure",
+ )
+ deletion_reason: Optional[str] = Field(
+ None,
+ description="Reason for deletion (no_meaningful_speech, audio_file_not_ready, etc.)",
+ )
+ deleted_at: Optional[datetime] = Field(
+ None, description="When the conversation was marked as deleted"
+ )
# Always persist audio flag and processing status
processing_status: Optional[str] = Field(
None,
- description="Processing status: pending_transcription, transcription_failed, completed"
+ description="Processing status: pending_transcription, transcription_failed, completed",
)
always_persist: bool = Field(
default=False,
- description="Flag indicating conversation was created for audio persistence"
+ description="Flag indicating conversation was created for audio persistence",
)
# Conversation completion tracking
- end_reason: Optional["Conversation.EndReason"] = Field(None, description="Reason why the conversation ended")
- completed_at: Optional[datetime] = Field(None, description="When the conversation was completed/closed")
+ end_reason: Optional["Conversation.EndReason"] = Field(
+ None, description="Reason why the conversation ended"
+ )
+ completed_at: Optional[datetime] = Field(
+ None, description="When the conversation was completed/closed"
+ )
# Star/favorite
- starred: bool = Field(False, description="Whether this conversation is starred/favorited")
- starred_at: Optional[datetime] = Field(None, description="When the conversation was starred")
+ starred: bool = Field(
+ False, description="Whether this conversation is starred/favorited"
+ )
+ starred_at: Optional[datetime] = Field(
+ None, description="When the conversation was starred"
+ )
# Summary fields (auto-generated from transcript)
title: Optional[str] = Field(None, description="Auto-generated conversation title")
- summary: Optional[str] = Field(None, description="Auto-generated short summary (1-2 sentences)")
- detailed_summary: Optional[str] = Field(None, description="Auto-generated detailed summary (comprehensive, corrected content)")
+ summary: Optional[str] = Field(
+ None, description="Auto-generated short summary (1-2 sentences)"
+ )
+ detailed_summary: Optional[str] = Field(
+ None,
+ description="Auto-generated detailed summary (comprehensive, corrected content)",
+ )
# Versioned processing
transcript_versions: List["Conversation.TranscriptVersion"] = Field(
- default_factory=list,
- description="All transcript processing attempts"
+ default_factory=list, description="All transcript processing attempts"
)
memory_versions: List["Conversation.MemoryVersion"] = Field(
- default_factory=list,
- description="All memory extraction attempts"
+ default_factory=list, description="All memory extraction attempts"
)
# Active version pointers
active_transcript_version: Optional[str] = Field(
- None,
- description="Version ID of currently active transcript"
+ None, description="Version ID of currently active transcript"
)
active_memory_version: Optional[str] = Field(
- None,
- description="Version ID of currently active memory extraction"
+ None, description="Version ID of currently active memory extraction"
)
# Legacy fields removed - use transcript_versions[active_transcript_version] and memory_versions[active_memory_version]
# Frontend should access: conversation.active_transcript.segments, conversation.active_transcript.transcript
- @model_validator(mode='before')
+ @model_validator(mode="before")
@classmethod
def clean_legacy_data(cls, data: Any) -> Any:
"""Clean up legacy/malformed data before Pydantic validation."""
@@ -203,26 +262,32 @@ def clean_legacy_data(cls, data: Any) -> Any:
return data
# Fix malformed transcript_versions (from old schema versions)
- if 'transcript_versions' in data and isinstance(data['transcript_versions'], list):
- for version in data['transcript_versions']:
+ if "transcript_versions" in data and isinstance(
+ data["transcript_versions"], list
+ ):
+ for version in data["transcript_versions"]:
if isinstance(version, dict):
# If segments is not a list, clear it
- if 'segments' in version and not isinstance(version['segments'], list):
- version['segments'] = []
+ if "segments" in version and not isinstance(
+ version["segments"], list
+ ):
+ version["segments"] = []
# If transcript is a dict, clear it
- if 'transcript' in version and isinstance(version['transcript'], dict):
- version['transcript'] = None
+ if "transcript" in version and isinstance(
+ version["transcript"], dict
+ ):
+ version["transcript"] = None
# Normalize provider to lowercase (legacy data had "Deepgram" instead of "deepgram")
- if 'provider' in version and isinstance(version['provider'], str):
- version['provider'] = version['provider'].lower()
+ if "provider" in version and isinstance(version["provider"], str):
+ version["provider"] = version["provider"].lower()
# Fix speaker IDs in segments (legacy data had integers, need strings)
- if 'segments' in version and isinstance(version['segments'], list):
- for segment in version['segments']:
- if isinstance(segment, dict) and 'speaker' in segment:
- if isinstance(segment['speaker'], int):
- segment['speaker'] = f"Speaker {segment['speaker']}"
- elif not isinstance(segment['speaker'], str):
- segment['speaker'] = "unknown"
+ if "segments" in version and isinstance(version["segments"], list):
+ for segment in version["segments"]:
+ if isinstance(segment, dict) and "speaker" in segment:
+ if isinstance(segment["speaker"], int):
+ segment["speaker"] = f"Speaker {segment['speaker']}"
+ elif not isinstance(segment["speaker"], str):
+ segment["speaker"] = "unknown"
return data
@@ -325,7 +390,7 @@ def add_transcript_version(
model: Optional[str] = None,
processing_time_seconds: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None,
- set_as_active: bool = True
+ set_as_active: bool = True,
) -> "Conversation.TranscriptVersion":
"""Add a new transcript version and optionally set it as active."""
new_version = Conversation.TranscriptVersion(
@@ -337,7 +402,7 @@ def add_transcript_version(
model=model,
created_at=datetime.now(),
processing_time_seconds=processing_time_seconds,
- metadata=metadata or {}
+ metadata=metadata or {},
)
self.transcript_versions.append(new_version)
@@ -356,7 +421,7 @@ def add_memory_version(
model: Optional[str] = None,
processing_time_seconds: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None,
- set_as_active: bool = True
+ set_as_active: bool = True,
) -> "Conversation.MemoryVersion":
"""Add a new memory version and optionally set it as active."""
new_version = Conversation.MemoryVersion(
@@ -367,7 +432,7 @@ def add_memory_version(
model=model,
created_at=datetime.now(),
processing_time_seconds=processing_time_seconds,
- metadata=metadata or {}
+ metadata=metadata or {},
)
self.memory_versions.append(new_version)
@@ -399,12 +464,27 @@ class Settings:
"conversation_id",
"user_id",
"created_at",
- [("user_id", 1), ("deleted", 1), ("created_at", -1)], # Compound index for paginated list queries
- IndexModel([("external_source_id", 1)], sparse=True), # Sparse index for deduplication
+ [
+ ("user_id", 1),
+ ("deleted", 1),
+ ("created_at", -1),
+ ], # Compound index for paginated list queries
+ IndexModel(
+ [("external_source_id", 1)], sparse=True
+ ), # Sparse index for deduplication
IndexModel(
- [("title", "text"), ("summary", "text"), ("detailed_summary", "text"),
- ("transcript_versions.transcript", "text")],
- weights={"title": 10, "summary": 5, "detailed_summary": 3, "transcript_versions.transcript": 1},
+ [
+ ("title", "text"),
+ ("summary", "text"),
+ ("detailed_summary", "text"),
+ ("transcript_versions.transcript", "text"),
+ ],
+ weights={
+ "title": 10,
+ "summary": 5,
+ "detailed_summary": 3,
+ "transcript_versions.transcript": 1,
+ },
name="conversation_text_search",
),
]
@@ -458,4 +538,4 @@ def create_conversation(
if conversation_id is not None:
conv_data["conversation_id"] = conversation_id
- return Conversation(**conv_data)
\ No newline at end of file
+ return Conversation(**conv_data)
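
For reference, the weighted "conversation_text_search" index defined above is queried through MongoDB's $text operator. A minimal sketch, assuming Beanie is already initialized for the Conversation document (helper name and limit are illustrative):

    from beanie.operators import Text

    async def search_conversations_text(query: str, user_id: str, limit: int = 20):
        # Runs against the "conversation_text_search" index declared in Settings.indexes,
        # restricted to the caller's non-deleted conversations.
        return (
            await Conversation.find(
                Text(query),
                Conversation.user_id == user_id,
                Conversation.deleted == False,
            )
            .limit(limit)
            .to_list()
        )
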
diff --git a/backends/advanced/src/advanced_omi_backend/models/job.py b/backends/advanced/src/advanced_omi_backend/models/job.py
index f7f44d4c..91870643 100644
--- a/backends/advanced/src/advanced_omi_backend/models/job.py
+++ b/backends/advanced/src/advanced_omi_backend/models/job.py
@@ -27,11 +27,12 @@
_beanie_initialized = False
_beanie_init_lock = asyncio.Lock()
+
async def _ensure_beanie_initialized():
"""Ensure Beanie is initialized in the current process (for RQ workers)."""
global _beanie_initialized
async with _beanie_init_lock:
- if _beanie_initialized:
+ if _beanie_initialized:
return
try:
import os
@@ -85,10 +86,11 @@ class JobPriority(str, Enum):
- NORMAL: 5 minutes timeout (default)
- LOW: 3 minutes timeout
"""
- URGENT = "urgent" # 1 - Process immediately
- HIGH = "high" # 2 - Process before normal
- NORMAL = "normal" # 3 - Default priority
- LOW = "low" # 4 - Process when idle
+
+ URGENT = "urgent" # 1 - Process immediately
+ HIGH = "high" # 2 - Process before normal
+ NORMAL = "normal" # 3 - Default priority
+ LOW = "low" # 4 - Process when idle
class BaseRQJob(ABC):
@@ -188,6 +190,7 @@ def run(self, **kwargs) -> Dict[str, Any]:
asyncio.set_event_loop(loop)
try:
+
async def process():
await self._setup()
try:
@@ -207,11 +210,15 @@ async def process():
except Exception as e:
elapsed = time.time() - self.job_start_time
-                logger.error(f"❌ {job_name} failed after {elapsed:.2f}s: {e}", exc_info=True)
+ logger.error(
+                    f"❌ {job_name} failed after {elapsed:.2f}s: {e}", exc_info=True
+ )
raise
-def async_job(redis: bool = True, beanie: bool = True, timeout: int = 300, result_ttl: int = 3600):
+def async_job(
+ redis: bool = True, beanie: bool = True, timeout: int = 300, result_ttl: int = 3600
+):
"""
Decorator to convert async functions into RQ-compatible job functions.
@@ -239,6 +246,7 @@ async def my_job(arg1, arg2, redis_client=None):
queue.enqueue(my_job, arg1_value, arg2_value) # Uses timeout=600
queue.enqueue(my_job, arg1_value, arg2_value, job_timeout=1200) # Override
"""
+
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> Dict[str, Any]:
@@ -254,6 +262,7 @@ def wrapper(*args, **kwargs) -> Dict[str, Any]:
asyncio.set_event_loop(loop)
try:
+
async def process():
nonlocal redis_client
@@ -267,8 +276,9 @@ async def process():
from advanced_omi_backend.controllers.queue_controller import (
REDIS_URL,
)
+
redis_client = redis_async.from_url(REDIS_URL)
- kwargs['redis_client'] = redis_client
+ kwargs["redis_client"] = redis_client
logger.debug(f"Redis client created")
try:
@@ -292,7 +302,9 @@ async def process():
except Exception as e:
elapsed = time.time() - start_time
-        logger.error(f"❌ {job_name} failed after {elapsed:.2f}s: {e}", exc_info=True)
+ logger.error(
+            f"❌ {job_name} failed after {elapsed:.2f}s: {e}", exc_info=True
+ )
raise
# Store default job configuration as attributes for RQ introspection
@@ -300,4 +312,5 @@ async def process():
wrapper.result_ttl = result_ttl
return wrapper
- return decorator
\ No newline at end of file
+
+ return decorator
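
A minimal usage sketch for the @async_job decorator above, following the pattern shown in its docstring (the job body, Redis key, and queue calls are illustrative assumptions):

    from advanced_omi_backend.models.job import async_job

    @async_job(redis=True, beanie=True, timeout=600)
    async def my_job(conversation_id: str, redis_client=None):
        # The wrapper injects `redis_client` when redis=True and ensures Beanie
        # is initialized before this coroutine runs in the RQ worker process.
        await redis_client.set(f"job:last_seen:{conversation_id}", "started")
        return {"conversation_id": conversation_id, "status": "done"}

    # Enqueued like any RQ job; the decorator's timeout is the default:
    # queue.enqueue(my_job, "abc123")                    # uses timeout=600
    # queue.enqueue(my_job, "abc123", job_timeout=1200)  # per-call override
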
diff --git a/backends/advanced/src/advanced_omi_backend/models/user.py b/backends/advanced/src/advanced_omi_backend/models/user.py
index 7291f9bb..77e0cc38 100644
--- a/backends/advanced/src/advanced_omi_backend/models/user.py
+++ b/backends/advanced/src/advanced_omi_backend/models/user.py
@@ -80,7 +80,9 @@ def user_id(self) -> str:
"""Return string representation of MongoDB ObjectId for backward compatibility."""
return str(self.id)
- def register_client(self, client_id: str, device_name: Optional[str] = None) -> None:
+ def register_client(
+ self, client_id: str, device_name: Optional[str] = None
+ ) -> None:
"""Register a new client for this user."""
# Check if client already exists
if client_id in self.registered_clients:
diff --git a/backends/advanced/src/advanced_omi_backend/models/waveform.py b/backends/advanced/src/advanced_omi_backend/models/waveform.py
index caf6fd49..ccf1158c 100644
--- a/backends/advanced/src/advanced_omi_backend/models/waveform.py
+++ b/backends/advanced/src/advanced_omi_backend/models/waveform.py
@@ -32,12 +32,10 @@ class WaveformData(Document):
# Metadata
duration_seconds: float = Field(description="Total audio duration in seconds")
created_at: datetime = Field(
- default_factory=datetime.utcnow,
- description="When this waveform was generated"
+ default_factory=datetime.utcnow, description="When this waveform was generated"
)
processing_time_seconds: Optional[float] = Field(
- None,
- description="Time taken to generate waveform"
+ None, description="Time taken to generate waveform"
)
class Settings:
diff --git a/backends/advanced/src/advanced_omi_backend/plugins/__init__.py b/backends/advanced/src/advanced_omi_backend/plugins/__init__.py
index 90c47460..35b0a2b0 100644
--- a/backends/advanced/src/advanced_omi_backend/plugins/__init__.py
+++ b/backends/advanced/src/advanced_omi_backend/plugins/__init__.py
@@ -20,13 +20,13 @@
from .services import PluginServices
__all__ = [
- 'BasePlugin',
- 'ButtonActionType',
- 'ButtonState',
- 'ConversationCloseReason',
- 'PluginContext',
- 'PluginEvent',
- 'PluginResult',
- 'PluginRouter',
- 'PluginServices',
+ "BasePlugin",
+ "ButtonActionType",
+ "ButtonState",
+ "ConversationCloseReason",
+ "PluginContext",
+ "PluginEvent",
+ "PluginResult",
+ "PluginRouter",
+ "PluginServices",
]
diff --git a/backends/advanced/src/advanced_omi_backend/routers/api_router.py b/backends/advanced/src/advanced_omi_backend/routers/api_router.py
index e4c89531..63e7180d 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/api_router.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/api_router.py
@@ -47,12 +47,15 @@
router.include_router(obsidian_router)
router.include_router(system_router)
router.include_router(queue_router)
-router.include_router(health_router) # Also include under /api for frontend compatibility
+router.include_router(
+ health_router
+) # Also include under /api for frontend compatibility
# Conditionally include test routes (only in test environments)
if os.getenv("DEBUG_DIR"):
try:
from .modules.test_routes import router as test_router
+
router.include_router(test_router)
        logger.info("✅ Test routes loaded (test environment detected)")
except Exception as e:
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py b/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py
index 501377fc..a65463d7 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/__init__.py
@@ -35,19 +35,19 @@
from .websocket_routes import router as websocket_router
__all__ = [
- "admin_router",
- "annotation_router",
- "audio_router",
- "chat_router",
- "client_router",
- "conversation_router",
- "finetuning_router",
- "health_router",
- "knowledge_graph_router",
- "memory_router",
- "obsidian_router",
- "queue_router",
- "system_router",
- "user_router",
- "websocket_router",
+ "admin_router",
+ "annotation_router",
+ "audio_router",
+ "chat_router",
+ "client_router",
+ "conversation_router",
+ "finetuning_router",
+ "health_router",
+ "knowledge_graph_router",
+ "memory_router",
+ "obsidian_router",
+ "queue_router",
+ "system_router",
+ "user_router",
+ "websocket_router",
]
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/admin_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/admin_routes.py
index 49594dd0..2b6f0dd9 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/admin_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/admin_routes.py
@@ -21,32 +21,29 @@
def require_admin(current_user: User = Depends(current_active_user)) -> User:
"""Dependency to require admin/superuser permissions."""
if not current_user.is_superuser:
- raise HTTPException(
- status_code=403,
- detail="Admin permissions required"
- )
+ raise HTTPException(status_code=403, detail="Admin permissions required")
return current_user
@router.get("/cleanup/settings")
-async def get_cleanup_settings_admin(
- admin: User = Depends(require_admin)
-):
+async def get_cleanup_settings_admin(admin: User = Depends(require_admin)):
"""Get current cleanup settings (admin only)."""
from advanced_omi_backend.config import get_cleanup_settings
settings = get_cleanup_settings()
return {
**settings,
- "note": "Cleanup settings are stored in /app/data/cleanup_config.json"
+ "note": "Cleanup settings are stored in /app/data/cleanup_config.json",
}
@router.post("/cleanup")
async def trigger_cleanup(
dry_run: bool = Query(False, description="Preview what would be deleted"),
- retention_days: Optional[int] = Query(None, description="Override retention period"),
- admin: User = Depends(require_admin)
+ retention_days: Optional[int] = Query(
+ None, description="Override retention period"
+ ),
+ admin: User = Depends(require_admin),
):
"""Manually trigger cleanup of soft-deleted conversations (admin only)."""
try:
@@ -64,7 +61,9 @@ async def trigger_cleanup(
job_timeout="30m",
)
- logger.info(f"Admin {admin.email} triggered cleanup job {job.id} (dry_run={dry_run}, retention={retention_days or 'default'})")
+ logger.info(
+ f"Admin {admin.email} triggered cleanup job {job.id} (dry_run={dry_run}, retention={retention_days or 'default'})"
+ )
return JSONResponse(
status_code=200,
@@ -73,22 +72,23 @@ async def trigger_cleanup(
"job_id": job.id,
"retention_days": retention_days or "default (from config)",
"dry_run": dry_run,
- "note": "Check job status at /api/queue/jobs/{job_id}"
- }
+ "note": "Check job status at /api/queue/jobs/{job_id}",
+ },
)
except Exception as e:
logger.error(f"Failed to trigger cleanup: {e}")
return JSONResponse(
- status_code=500,
- content={"error": f"Failed to trigger cleanup: {str(e)}"}
+ status_code=500, content={"error": f"Failed to trigger cleanup: {str(e)}"}
)
@router.get("/cleanup/preview")
async def preview_cleanup(
- retention_days: Optional[int] = Query(None, description="Preview with specific retention period"),
- admin: User = Depends(require_admin)
+ retention_days: Optional[int] = Query(
+ None, description="Preview with specific retention period"
+ ),
+ admin: User = Depends(require_admin),
):
"""Preview what would be deleted by cleanup (admin only)."""
try:
@@ -100,26 +100,24 @@ async def preview_cleanup(
# Use provided retention or default from config
if retention_days is None:
settings_dict = get_cleanup_settings()
- retention_days = settings_dict['retention_days']
+ retention_days = settings_dict["retention_days"]
cutoff_date = datetime.utcnow() - timedelta(days=retention_days)
# Count conversations that would be deleted
count = await Conversation.find(
- Conversation.deleted == True,
- Conversation.deleted_at < cutoff_date
+ Conversation.deleted == True, Conversation.deleted_at < cutoff_date
).count()
return {
"retention_days": retention_days,
"cutoff_date": cutoff_date.isoformat(),
"conversations_to_delete": count,
- "note": f"Conversations deleted before {cutoff_date.date()} would be purged"
+ "note": f"Conversations deleted before {cutoff_date.date()} would be purged",
}
except Exception as e:
logger.error(f"Failed to preview cleanup: {e}")
return JSONResponse(
- status_code=500,
- content={"error": f"Failed to preview cleanup: {str(e)}"}
+ status_code=500, content={"error": f"Failed to preview cleanup: {str(e)}"}
)
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py
index fd1c659f..81065a8e 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/audio_routes.py
@@ -49,7 +49,10 @@ def _safe_filename(conversation: "Conversation") -> str:
@router.post("/upload_audio_from_gdrive")
async def upload_audio_from_drive_folder(
- gdrive_folder_id: str = Query(..., description="Google Drive Folder ID containing audio files (e.g., the string after /folders/ in the URL)"),
+ gdrive_folder_id: str = Query(
+ ...,
+ description="Google Drive Folder ID containing audio files (e.g., the string after /folders/ in the URL)",
+ ),
current_user: User = Depends(current_superuser),
device_name: str = Query(default="upload"),
):
@@ -67,7 +70,9 @@ async def upload_audio_from_drive_folder(
async def get_conversation_audio(
conversation_id: str,
request: Request,
- token: Optional[str] = Query(default=None, description="JWT token for audio element access"),
+ token: Optional[str] = Query(
+ default=None, description="JWT token for audio element access"
+ ),
current_user: Optional[User] = Depends(current_active_user_optional),
):
"""
@@ -111,7 +116,9 @@ async def get_conversation_audio(
raise HTTPException(status_code=404, detail="Conversation not found")
# Check ownership (admins can access all)
- if not current_user.is_superuser and conversation.user_id != str(current_user.user_id):
+ if not current_user.is_superuser and conversation.user_id != str(
+ current_user.user_id
+ ):
raise HTTPException(status_code=403, detail="Access denied")
# Reconstruct WAV from MongoDB chunks
@@ -123,8 +130,7 @@ async def get_conversation_audio(
except Exception as e:
# Reconstruction failed
raise HTTPException(
- status_code=500,
- detail=f"Failed to reconstruct audio: {str(e)}"
+ status_code=500, detail=f"Failed to reconstruct audio: {str(e)}"
)
# Handle Range requests for seeking support
@@ -143,7 +149,7 @@ async def get_conversation_audio(
"Accept-Ranges": "bytes",
"X-Audio-Source": "mongodb-chunks",
"X-Chunk-Count": str(conversation.audio_chunks_count or 0),
- }
+ },
)
# Parse Range header (e.g., "bytes=0-1023")
@@ -159,7 +165,7 @@ async def get_conversation_audio(
content_length = range_end - range_start + 1
# Extract requested byte range
- range_data = wav_data[range_start:range_end + 1]
+ range_data = wav_data[range_start : range_end + 1]
# Return 206 Partial Content with Range headers
return Response(
@@ -172,22 +178,21 @@ async def get_conversation_audio(
"Accept-Ranges": "bytes",
"Content-Disposition": f'inline; filename="{filename}.wav"',
"X-Audio-Source": "mongodb-chunks",
- }
+ },
)
except (ValueError, IndexError) as e:
# Invalid Range header, return 416 Range Not Satisfiable
return Response(
- status_code=416,
- headers={
- "Content-Range": f"bytes */{file_size}"
- }
+ status_code=416, headers={"Content-Range": f"bytes */{file_size}"}
)
@router.get("/stream_audio/{conversation_id}")
async def stream_conversation_audio(
conversation_id: str,
- token: Optional[str] = Query(default=None, description="JWT token for audio element access"),
+ token: Optional[str] = Query(
+ default=None, description="JWT token for audio element access"
+ ),
current_user: Optional[User] = Depends(current_active_user_optional),
):
"""
@@ -230,12 +235,16 @@ async def stream_conversation_audio(
raise HTTPException(status_code=404, detail="Conversation not found")
# Check ownership (admins can access all)
- if not current_user.is_superuser and conversation.user_id != str(current_user.user_id):
+ if not current_user.is_superuser and conversation.user_id != str(
+ current_user.user_id
+ ):
raise HTTPException(status_code=403, detail="Access denied")
# Check if chunks exist
if not conversation.audio_chunks_count or conversation.audio_chunks_count == 0:
- raise HTTPException(status_code=404, detail="No audio data for this conversation")
+ raise HTTPException(
+ status_code=404, detail="No audio data for this conversation"
+ )
async def stream_chunks():
"""Generator that yields WAV data in batches."""
@@ -249,6 +258,7 @@ async def stream_chunks():
# We'll write a placeholder size since we're streaming
wav_header = io.BytesIO()
import wave
+
with wave.open(wav_header, "wb") as wav:
wav.setnchannels(CHANNELS)
wav.setsampwidth(SAMPLE_WIDTH)
@@ -268,7 +278,7 @@ async def stream_chunks():
chunks = await retrieve_audio_chunks(
conversation_id=conversation_id,
start_index=start_index,
- limit=batch_size
+ limit=batch_size,
)
if not chunks:
@@ -292,7 +302,7 @@ async def stream_chunks():
"X-Audio-Source": "mongodb-chunks-stream",
"X-Chunk-Count": str(conversation.audio_chunks_count or 0),
"X-Total-Duration": str(conversation.audio_total_duration or 0),
- }
+ },
)
@@ -301,7 +311,9 @@ async def get_audio_chunk_range(
conversation_id: str,
start_time: float = Query(..., description="Start time in seconds"),
end_time: float = Query(..., description="End time in seconds"),
- token: Optional[str] = Query(default=None, description="JWT token for audio element access"),
+ token: Optional[str] = Query(
+ default=None, description="JWT token for audio element access"
+ ),
current_user: Optional[User] = Depends(current_active_user_optional),
):
"""
@@ -331,8 +343,11 @@ async def get_audio_chunk_range(
400: If time range is invalid
"""
import logging
+
logger = logging.getLogger(__name__)
-    logger.info(f"🎵 Audio chunk request: conversation={conversation_id[:8]}..., start={start_time:.2f}s, end={end_time:.2f}s")
+ logger.info(
+        f"🎵 Audio chunk request: conversation={conversation_id[:8]}..., start={start_time:.2f}s, end={end_time:.2f}s"
+ )
# Try token param if header auth failed
if not current_user and token:
@@ -350,27 +365,38 @@ async def get_audio_chunk_range(
raise HTTPException(status_code=404, detail="Conversation not found")
# Check ownership (admins can access all)
- if not current_user.is_superuser and conversation.user_id != str(current_user.user_id):
+ if not current_user.is_superuser and conversation.user_id != str(
+ current_user.user_id
+ ):
raise HTTPException(status_code=403, detail="Access denied")
# Validate time range
if start_time < 0 or end_time <= start_time:
raise HTTPException(status_code=400, detail="Invalid time range")
- if conversation.audio_total_duration and end_time > conversation.audio_total_duration:
+ if (
+ conversation.audio_total_duration
+ and end_time > conversation.audio_total_duration
+ ):
end_time = conversation.audio_total_duration
# Use the dedicated segment reconstruction function
from advanced_omi_backend.utils.audio_chunk_utils import reconstruct_audio_segment
try:
- wav_data = await reconstruct_audio_segment(conversation_id, start_time, end_time)
-        logger.info(f"✅ Returning WAV: {len(wav_data)} bytes for range {start_time:.2f}s - {end_time:.2f}s")
+ wav_data = await reconstruct_audio_segment(
+ conversation_id, start_time, end_time
+ )
+ logger.info(
+            f"✅ Returning WAV: {len(wav_data)} bytes for range {start_time:.2f}s - {end_time:.2f}s"
+ )
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f"Failed to reconstruct audio segment: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to reconstruct audio: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to reconstruct audio: {str(e)}"
+ )
return StreamingResponse(
io.BytesIO(wav_data),
@@ -381,7 +407,7 @@ async def get_audio_chunk_range(
"X-Audio-Duration": str(end_time - start_time),
"X-Start-Time": str(start_time),
"X-End-Time": str(end_time),
- }
+ },
)
@@ -389,7 +415,9 @@ async def get_audio_chunk_range(
async def upload_audio_files(
current_user: User = Depends(current_superuser),
files: list[UploadFile] = File(...),
- device_name: str = Query(default="upload", description="Device name for uploaded files"),
+ device_name: str = Query(
+ default="upload", description="Device name for uploaded files"
+ ),
):
"""
Upload and process audio files. Admin only.
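
The Range handling reformatted above follows the standard HTTP byte-range pattern: parse "Range: bytes=start-end", slice the reconstructed WAV, and reply with 206 Partial Content or 416 Range Not Satisfiable. A self-contained sketch of that pattern (function name and framing are illustrative, not part of this change):

    from fastapi import Response

    def slice_wav_for_range(wav_data: bytes, range_header: str) -> Response:
        file_size = len(wav_data)
        try:
            # e.g. "bytes=0-1023"; an open end ("bytes=500-") means "to EOF".
            start_str, _, end_str = range_header.replace("bytes=", "", 1).partition("-")
            range_start = int(start_str)
            range_end = int(end_str) if end_str else file_size - 1
            if range_start > range_end or range_end >= file_size:
                raise ValueError("unsatisfiable range")
            return Response(
                content=wav_data[range_start : range_end + 1],
                status_code=206,
                media_type="audio/wav",
                headers={
                    "Content-Range": f"bytes {range_start}-{range_end}/{file_size}",
                    "Accept-Ranges": "bytes",
                },
            )
        except (ValueError, IndexError):
            return Response(
                status_code=416, headers={"Content-Range": f"bytes */{file_size}"}
            )
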
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/chat_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/chat_routes.py
index fdb73e5d..0d296340 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/chat_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/chat_routes.py
@@ -31,18 +31,31 @@
# --- OpenAI-compatible chat completion models ---
+
class ChatCompletionMessage(BaseModel):
- role: str = Field(..., description="The role of the message author (system, user, assistant)")
+ role: str = Field(
+ ..., description="The role of the message author (system, user, assistant)"
+ )
content: str = Field(..., description="The message content")
class ChatCompletionRequest(BaseModel):
- messages: List[ChatCompletionMessage] = Field(..., min_length=1, description="List of messages in the conversation")
- model: Optional[str] = Field(None, description="Model to use (ignored, uses server-configured model)")
+ messages: List[ChatCompletionMessage] = Field(
+ ..., min_length=1, description="List of messages in the conversation"
+ )
+ model: Optional[str] = Field(
+ None, description="Model to use (ignored, uses server-configured model)"
+ )
stream: Optional[bool] = Field(True, description="Whether to stream the response")
- temperature: Optional[float] = Field(None, description="Sampling temperature (ignored, uses server config)")
- session_id: Optional[str] = Field(None, description="Chronicle session ID (creates new if not provided)")
- include_obsidian_memory: Optional[bool] = Field(False, description="Whether to include Obsidian vault context")
+ temperature: Optional[float] = Field(
+ None, description="Sampling temperature (ignored, uses server config)"
+ )
+ session_id: Optional[str] = Field(
+ None, description="Chronicle session ID (creates new if not provided)"
+ )
+ include_obsidian_memory: Optional[bool] = Field(
+ False, description="Whether to include Obsidian vault context"
+ )
class ChatCompletionChunkDelta(BaseModel):
@@ -104,7 +117,9 @@ class ChatSessionCreateRequest(BaseModel):
class ChatSessionUpdateRequest(BaseModel):
- title: str = Field(..., min_length=1, max_length=200, description="New session title")
+ title: str = Field(
+ ..., min_length=1, max_length=200, description="New session title"
+ )
class ChatStatisticsResponse(BaseModel):
@@ -115,99 +130,97 @@ class ChatStatisticsResponse(BaseModel):
@router.post("/sessions", response_model=ChatSessionResponse)
async def create_chat_session(
- request: ChatSessionCreateRequest,
- current_user: User = Depends(current_active_user)
+ request: ChatSessionCreateRequest, current_user: User = Depends(current_active_user)
):
"""Create a new chat session."""
try:
chat_service = get_chat_service()
session = await chat_service.create_session(
- user_id=str(current_user.id),
- title=request.title
+ user_id=str(current_user.id), title=request.title
)
-
+
return ChatSessionResponse(
session_id=session.session_id,
title=session.title,
created_at=session.created_at.isoformat(),
- updated_at=session.updated_at.isoformat()
+ updated_at=session.updated_at.isoformat(),
)
except Exception as e:
logger.error(f"Failed to create chat session for user {current_user.id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to create chat session"
+ detail="Failed to create chat session",
)
@router.get("/sessions", response_model=List[ChatSessionResponse])
async def get_chat_sessions(
- limit: int = 50,
- current_user: User = Depends(current_active_user)
+ limit: int = 50, current_user: User = Depends(current_active_user)
):
"""Get all chat sessions for the current user."""
try:
chat_service = get_chat_service()
sessions = await chat_service.get_user_sessions(
- user_id=str(current_user.id),
- limit=min(limit, 100) # Cap at 100
+ user_id=str(current_user.id), limit=min(limit, 100) # Cap at 100
)
-
+
# Get message counts for each session (this could be optimized with aggregation)
session_responses = []
for session in sessions:
messages = await chat_service.get_session_messages(
session_id=session.session_id,
user_id=str(current_user.id),
- limit=1 # We just need count, but MongoDB doesn't have efficient count
+ limit=1, # We just need count, but MongoDB doesn't have efficient count
)
-
- session_responses.append(ChatSessionResponse(
- session_id=session.session_id,
- title=session.title,
- created_at=session.created_at.isoformat(),
- updated_at=session.updated_at.isoformat(),
- message_count=len(messages) # This is approximate for efficiency
- ))
-
+
+ session_responses.append(
+ ChatSessionResponse(
+ session_id=session.session_id,
+ title=session.title,
+ created_at=session.created_at.isoformat(),
+ updated_at=session.updated_at.isoformat(),
+ message_count=len(messages), # This is approximate for efficiency
+ )
+ )
+
return session_responses
except Exception as e:
logger.error(f"Failed to get chat sessions for user {current_user.id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to retrieve chat sessions"
+ detail="Failed to retrieve chat sessions",
)
@router.get("/sessions/{session_id}", response_model=ChatSessionResponse)
async def get_chat_session(
- session_id: str,
- current_user: User = Depends(current_active_user)
+ session_id: str, current_user: User = Depends(current_active_user)
):
"""Get a specific chat session."""
try:
chat_service = get_chat_service()
session = await chat_service.get_session(session_id, str(current_user.id))
-
+
if not session:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Chat session not found"
+ status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found"
)
-
+
return ChatSessionResponse(
session_id=session.session_id,
title=session.title,
created_at=session.created_at.isoformat(),
- updated_at=session.updated_at.isoformat()
+ updated_at=session.updated_at.isoformat(),
)
except HTTPException:
raise
except Exception as e:
- logger.error(f"Failed to get chat session {session_id} for user {current_user.id}: {e}")
+ logger.error(
+ f"Failed to get chat session {session_id} for user {current_user.id}: {e}"
+ )
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to retrieve chat session"
+ detail="Failed to retrieve chat session",
)
@@ -215,100 +228,100 @@ async def get_chat_session(
async def update_chat_session(
session_id: str,
request: ChatSessionUpdateRequest,
- current_user: User = Depends(current_active_user)
+ current_user: User = Depends(current_active_user),
):
"""Update a chat session's title."""
try:
chat_service = get_chat_service()
-
+
# Verify session exists and belongs to user
session = await chat_service.get_session(session_id, str(current_user.id))
if not session:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Chat session not found"
+ status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found"
)
-
+
# Update the title
success = await chat_service.update_session_title(
session_id, str(current_user.id), request.title
)
-
+
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to update session title"
+ detail="Failed to update session title",
)
-
+
# Return updated session
- updated_session = await chat_service.get_session(session_id, str(current_user.id))
+ updated_session = await chat_service.get_session(
+ session_id, str(current_user.id)
+ )
return ChatSessionResponse(
session_id=updated_session.session_id,
title=updated_session.title,
created_at=updated_session.created_at.isoformat(),
- updated_at=updated_session.updated_at.isoformat()
+ updated_at=updated_session.updated_at.isoformat(),
)
except HTTPException:
raise
except Exception as e:
- logger.error(f"Failed to update chat session {session_id} for user {current_user.id}: {e}")
+ logger.error(
+ f"Failed to update chat session {session_id} for user {current_user.id}: {e}"
+ )
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to update chat session"
+ detail="Failed to update chat session",
)
@router.delete("/sessions/{session_id}")
async def delete_chat_session(
- session_id: str,
- current_user: User = Depends(current_active_user)
+ session_id: str, current_user: User = Depends(current_active_user)
):
"""Delete a chat session and all its messages."""
try:
chat_service = get_chat_service()
success = await chat_service.delete_session(session_id, str(current_user.id))
-
+
if not success:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Chat session not found"
+ status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found"
)
-
+
return {"message": "Chat session deleted successfully"}
except HTTPException:
raise
except Exception as e:
- logger.error(f"Failed to delete chat session {session_id} for user {current_user.id}: {e}")
+ logger.error(
+ f"Failed to delete chat session {session_id} for user {current_user.id}: {e}"
+ )
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to delete chat session"
+ detail="Failed to delete chat session",
)
@router.get("/sessions/{session_id}/messages", response_model=List[ChatMessageResponse])
async def get_session_messages(
- session_id: str,
- limit: int = 100,
- current_user: User = Depends(current_active_user)
+ session_id: str, limit: int = 100, current_user: User = Depends(current_active_user)
):
"""Get all messages in a chat session."""
try:
chat_service = get_chat_service()
-
+
# Verify session exists and belongs to user
session = await chat_service.get_session(session_id, str(current_user.id))
if not session:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Chat session not found"
+ status_code=status.HTTP_404_NOT_FOUND, detail="Chat session not found"
)
-
+
messages = await chat_service.get_session_messages(
session_id=session_id,
user_id=str(current_user.id),
- limit=min(limit, 200) # Cap at 200
+ limit=min(limit, 200), # Cap at 200
)
-
+
return [
ChatMessageResponse(
message_id=msg.message_id,
@@ -316,24 +329,25 @@ async def get_session_messages(
role=msg.role,
content=msg.content,
timestamp=msg.timestamp.isoformat(),
- memories_used=msg.memories_used
+ memories_used=msg.memories_used,
)
for msg in messages
]
except HTTPException:
raise
except Exception as e:
- logger.error(f"Failed to get messages for session {session_id}, user {current_user.id}: {e}")
+ logger.error(
+ f"Failed to get messages for session {session_id}, user {current_user.id}: {e}"
+ )
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to retrieve messages"
+ detail="Failed to retrieve messages",
)
@router.post("/completions")
async def chat_completions(
- request: ChatCompletionRequest,
- current_user: User = Depends(current_active_user)
+ request: ChatCompletionRequest, current_user: User = Depends(current_active_user)
):
"""OpenAI-compatible chat completions endpoint with streaming support."""
try:
@@ -349,7 +363,7 @@ async def chat_completions(
if not session:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
- detail="Chat session not found"
+ detail="Chat session not found",
)
# Extract the latest user message
@@ -357,7 +371,7 @@ async def chat_completions(
if not user_messages:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
- detail="At least one user message is required"
+ detail="At least one user message is required",
)
message_content = user_messages[-1].content
@@ -368,9 +382,14 @@ async def chat_completions(
if request.stream:
return StreamingResponse(
_stream_openai_format(
- chat_service, session_id, str(current_user.id),
- message_content, request.include_obsidian_memory,
- completion_id, created, model_name,
+ chat_service,
+ session_id,
+ str(current_user.id),
+ message_content,
+ request.include_obsidian_memory,
+ completion_id,
+ created,
+ model_name,
),
media_type="text/event-stream",
headers={
@@ -381,9 +400,14 @@ async def chat_completions(
)
else:
return await _non_streaming_response(
- chat_service, session_id, str(current_user.id),
- message_content, request.include_obsidian_memory,
- completion_id, created, model_name,
+ chat_service,
+ session_id,
+ str(current_user.id),
+ message_content,
+ request.include_obsidian_memory,
+ completion_id,
+ created,
+ model_name,
)
except HTTPException:
@@ -392,14 +416,19 @@ async def chat_completions(
logger.error(f"Failed to process message for user {current_user.id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to process message"
+ detail="Failed to process message",
)
async def _stream_openai_format(
- chat_service, session_id: str, user_id: str,
- message_content: str, include_obsidian_memory: bool,
- completion_id: str, created: int, model_name: str,
+ chat_service,
+ session_id: str,
+ user_id: str,
+ message_content: str,
+ include_obsidian_memory: bool,
+ completion_id: str,
+ created: int,
+ model_name: str,
):
"""Map internal streaming events to OpenAI SSE chunk format."""
previous_text = ""
@@ -415,10 +444,14 @@ async def _stream_openai_format(
if event_type == "memory_context":
# First chunk: send role + chronicle metadata
chunk = ChatCompletionChunk(
- id=completion_id, created=created, model=model_name,
- choices=[ChatCompletionChunkChoice(
- delta=ChatCompletionChunkDelta(role="assistant"),
- )],
+ id=completion_id,
+ created=created,
+ model=model_name,
+ choices=[
+ ChatCompletionChunkChoice(
+ delta=ChatCompletionChunkDelta(role="assistant"),
+ )
+ ],
chronicle_metadata={
"session_id": session_id,
**event["data"],
@@ -429,24 +462,32 @@ async def _stream_openai_format(
elif event_type == "token":
# Internal events carry accumulated text; compute delta
accumulated = event["data"]
- delta_text = accumulated[len(previous_text):]
+ delta_text = accumulated[len(previous_text) :]
previous_text = accumulated
if delta_text:
chunk = ChatCompletionChunk(
- id=completion_id, created=created, model=model_name,
- choices=[ChatCompletionChunkChoice(
- delta=ChatCompletionChunkDelta(content=delta_text),
- )],
+ id=completion_id,
+ created=created,
+ model=model_name,
+ choices=[
+ ChatCompletionChunkChoice(
+ delta=ChatCompletionChunkDelta(content=delta_text),
+ )
+ ],
)
yield f"data: {chunk.model_dump_json()}\n\n"
elif event_type == "complete":
chunk = ChatCompletionChunk(
- id=completion_id, created=created, model=model_name,
- choices=[ChatCompletionChunkChoice(
- delta=ChatCompletionChunkDelta(),
- finish_reason="stop",
- )],
+ id=completion_id,
+ created=created,
+ model=model_name,
+ choices=[
+ ChatCompletionChunkChoice(
+ delta=ChatCompletionChunkDelta(),
+ finish_reason="stop",
+ )
+ ],
chronicle_metadata={
"session_id": session_id,
"message_id": event["data"].get("message_id"),
@@ -473,9 +514,14 @@ async def _stream_openai_format(
async def _non_streaming_response(
- chat_service, session_id: str, user_id: str,
- message_content: str, include_obsidian_memory: bool,
- completion_id: str, created: int, model_name: str,
+ chat_service,
+ session_id: str,
+ user_id: str,
+ message_content: str,
+ include_obsidian_memory: bool,
+ completion_id: str,
+ created: int,
+ model_name: str,
) -> ChatCompletionResponse:
"""Collect all events and return a single ChatCompletionResponse."""
full_content = ""
@@ -506,9 +552,13 @@ async def _non_streaming_response(
id=completion_id,
created=created,
model=model_name,
- choices=[ChatCompletionChoice(
- message=ChatCompletionMessage(role="assistant", content=full_content.strip()),
- )],
+ choices=[
+ ChatCompletionChoice(
+ message=ChatCompletionMessage(
+ role="assistant", content=full_content.strip()
+ ),
+ )
+ ],
usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
session_id=session_id,
chronicle_metadata=metadata,
@@ -516,62 +566,62 @@ async def _non_streaming_response(
@router.get("/statistics", response_model=ChatStatisticsResponse)
-async def get_chat_statistics(
- current_user: User = Depends(current_active_user)
-):
+async def get_chat_statistics(current_user: User = Depends(current_active_user)):
"""Get chat statistics for the current user."""
try:
chat_service = get_chat_service()
stats = await chat_service.get_chat_statistics(str(current_user.id))
-
+
return ChatStatisticsResponse(
total_sessions=stats["total_sessions"],
total_messages=stats["total_messages"],
- last_chat=stats["last_chat"].isoformat() if stats["last_chat"] else None
+ last_chat=stats["last_chat"].isoformat() if stats["last_chat"] else None,
)
except Exception as e:
logger.error(f"Failed to get chat statistics for user {current_user.id}: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to retrieve chat statistics"
+ detail="Failed to retrieve chat statistics",
)
@router.post("/sessions/{session_id}/extract-memories")
async def extract_memories_from_session(
- session_id: str,
- current_user: User = Depends(current_active_user)
+ session_id: str, current_user: User = Depends(current_active_user)
):
"""Extract memories from a chat session."""
try:
chat_service = get_chat_service()
-
+
# Extract memories from the session
- success, memory_ids, memory_count = await chat_service.extract_memories_from_session(
- session_id=session_id,
- user_id=str(current_user.id)
+ success, memory_ids, memory_count = (
+ await chat_service.extract_memories_from_session(
+ session_id=session_id, user_id=str(current_user.id)
+ )
)
-
+
if success:
return {
"success": True,
"memory_ids": memory_ids,
"count": memory_count,
- "message": f"Successfully extracted {memory_count} memories from chat session"
+ "message": f"Successfully extracted {memory_count} memories from chat session",
}
else:
return {
"success": False,
"memory_ids": [],
"count": 0,
- "message": "Failed to extract memories from chat session"
+ "message": "Failed to extract memories from chat session",
}
-
+
except Exception as e:
- logger.error(f"Failed to extract memories from session {session_id} for user {current_user.id}: {e}")
+ logger.error(
+ f"Failed to extract memories from session {session_id} for user {current_user.id}: {e}"
+ )
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to extract memories from chat session"
+ detail="Failed to extract memories from chat session",
)
@@ -583,15 +633,11 @@ async def chat_health_check():
# Simple health check - verify service can be initialized
if not chat_service._initialized:
await chat_service.initialize()
-
- return {
- "status": "healthy",
- "service": "chat",
- "timestamp": time.time()
- }
+
+ return {"status": "healthy", "service": "chat", "timestamp": time.time()}
except Exception as e:
logger.error(f"Chat service health check failed: {e}")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
- detail="Chat service is not available"
- )
\ No newline at end of file
+ detail="Chat service is not available",
+ )
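
For the OpenAI-compatible streaming endpoint above, a client-side consumption sketch; the URL prefix, bearer auth, and "[DONE]" sentinel are assumptions here, while the chunk shape follows the ChatCompletionChunk models in this file:

    import json
    import httpx

    async def stream_chat(base_url: str, token: str, text: str) -> str:
        payload = {"messages": [{"role": "user", "content": text}], "stream": True}
        reply = ""
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST",
                f"{base_url}/api/chat/completions",
                json=payload,
                headers={"Authorization": f"Bearer {token}"},
            ) as response:
                async for line in response.aiter_lines():
                    if not line.startswith("data: ") or line.endswith("[DONE]"):
                        continue
                    chunk = json.loads(line[len("data: ") :])
                    delta = chunk["choices"][0]["delta"]
                    reply += delta.get("content") or ""
        return reply
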
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py
index 7a89fd5f..f46daf2e 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/conversation_routes.py
@@ -26,28 +26,44 @@ async def close_current_conversation(
current_user: User = Depends(current_active_user),
):
"""Close the current active conversation for a client. Works for both connected and disconnected clients."""
- return await conversation_controller.close_current_conversation(client_id, current_user)
+ return await conversation_controller.close_current_conversation(
+ client_id, current_user
+ )
@router.get("")
async def get_conversations(
- include_deleted: bool = Query(False, description="Include soft-deleted conversations"),
- include_unprocessed: bool = Query(False, description="Include orphan audio sessions (always_persist with failed/pending transcription)"),
- starred_only: bool = Query(False, description="Only return starred/favorited conversations"),
+ include_deleted: bool = Query(
+ False, description="Include soft-deleted conversations"
+ ),
+ include_unprocessed: bool = Query(
+ False,
+ description="Include orphan audio sessions (always_persist with failed/pending transcription)",
+ ),
+ starred_only: bool = Query(
+ False, description="Only return starred/favorited conversations"
+ ),
limit: int = Query(200, ge=1, le=500, description="Max conversations to return"),
offset: int = Query(0, ge=0, description="Number of conversations to skip"),
- sort_by: str = Query("created_at", description="Sort field: created_at, title, audio_total_duration"),
+ sort_by: str = Query(
+ "created_at", description="Sort field: created_at, title, audio_total_duration"
+ ),
sort_order: str = Query("desc", description="Sort direction: asc or desc"),
- current_user: User = Depends(current_active_user)
+ current_user: User = Depends(current_active_user),
):
"""Get conversations. Admins see all conversations, users see only their own."""
return await conversation_controller.get_conversations(
- current_user, include_deleted, include_unprocessed, starred_only, limit, offset,
- sort_by=sort_by, sort_order=sort_order,
+ current_user,
+ include_deleted,
+ include_unprocessed,
+ starred_only,
+ limit,
+ offset,
+ sort_by=sort_by,
+ sort_order=sort_order,
)
-
@router.get("/search")
async def search_conversations(
q: str = Query(..., min_length=1, description="Text search query"),
@@ -56,13 +72,14 @@ async def search_conversations(
current_user: User = Depends(current_active_user),
):
"""Full-text search across conversation titles, summaries, and transcripts."""
- return await conversation_controller.search_conversations(q, current_user, limit, offset)
+ return await conversation_controller.search_conversations(
+ q, current_user, limit, offset
+ )
@router.get("/{conversation_id}")
async def get_conversation_detail(
- conversation_id: str,
- current_user: User = Depends(current_active_user)
+ conversation_id: str, current_user: User = Depends(current_active_user)
):
"""Get a specific conversation with full transcript details."""
return await conversation_controller.get_conversation(conversation_id, current_user)
@@ -94,24 +111,28 @@ async def reprocess_transcript(
conversation_id: str, current_user: User = Depends(current_active_user)
):
"""Reprocess transcript for a conversation. Users can only reprocess their own conversations."""
- return await conversation_controller.reprocess_transcript(conversation_id, current_user)
+ return await conversation_controller.reprocess_transcript(
+ conversation_id, current_user
+ )
@router.post("/{conversation_id}/reprocess-memory")
async def reprocess_memory(
conversation_id: str,
current_user: User = Depends(current_active_user),
- transcript_version_id: str = Query(default="active")
+ transcript_version_id: str = Query(default="active"),
):
"""Reprocess memory extraction for a specific transcript version. Users can only reprocess their own conversations."""
- return await conversation_controller.reprocess_memory(conversation_id, transcript_version_id, current_user)
+ return await conversation_controller.reprocess_memory(
+ conversation_id, transcript_version_id, current_user
+ )
@router.post("/{conversation_id}/reprocess-speakers")
async def reprocess_speakers(
conversation_id: str,
current_user: User = Depends(current_active_user),
- transcript_version_id: str = Query(default="active")
+ transcript_version_id: str = Query(default="active"),
):
"""
Re-run speaker identification/diarization on existing transcript.
@@ -127,9 +148,7 @@ async def reprocess_speakers(
Job status with job_id and new version_id
"""
return await conversation_controller.reprocess_speakers(
- conversation_id,
- transcript_version_id,
- current_user
+ conversation_id, transcript_version_id, current_user
)
@@ -137,20 +156,24 @@ async def reprocess_speakers(
async def activate_transcript_version(
conversation_id: str,
version_id: str,
- current_user: User = Depends(current_active_user)
+ current_user: User = Depends(current_active_user),
):
"""Activate a specific transcript version. Users can only modify their own conversations."""
- return await conversation_controller.activate_transcript_version(conversation_id, version_id, current_user)
+ return await conversation_controller.activate_transcript_version(
+ conversation_id, version_id, current_user
+ )
@router.post("/{conversation_id}/activate-memory/{version_id}")
async def activate_memory_version(
conversation_id: str,
version_id: str,
- current_user: User = Depends(current_active_user)
+ current_user: User = Depends(current_active_user),
):
"""Activate a specific memory version. Users can only modify their own conversations."""
- return await conversation_controller.activate_memory_version(conversation_id, version_id, current_user)
+ return await conversation_controller.activate_memory_version(
+ conversation_id, version_id, current_user
+ )
@router.get("/{conversation_id}/versions")
@@ -158,13 +181,14 @@ async def get_conversation_version_history(
conversation_id: str, current_user: User = Depends(current_active_user)
):
"""Get version history for a conversation. Users can only access their own conversations."""
- return await conversation_controller.get_conversation_version_history(conversation_id, current_user)
+ return await conversation_controller.get_conversation_version_history(
+ conversation_id, current_user
+ )
@router.get("/{conversation_id}/waveform")
async def get_conversation_waveform(
- conversation_id: str,
- current_user: User = Depends(current_active_user)
+ conversation_id: str, current_user: User = Depends(current_active_user)
):
"""
Get or generate waveform visualization data for a conversation.
@@ -208,37 +232,38 @@ async def get_conversation_waveform(
# If waveform exists, return cached version
if waveform:
- logger.info(f"Returning cached waveform for conversation {conversation_id[:12]}")
+ logger.info(
+ f"Returning cached waveform for conversation {conversation_id[:12]}"
+ )
return waveform.model_dump(exclude={"id", "revision_id"})
# Generate waveform on-demand
- logger.info(f"Generating waveform on-demand for conversation {conversation_id[:12]}")
+ logger.info(
+ f"Generating waveform on-demand for conversation {conversation_id[:12]}"
+ )
waveform_dict = await generate_waveform_data(
- conversation_id=conversation_id,
- sample_rate=3
+ conversation_id=conversation_id, sample_rate=3
)
if not waveform_dict.get("success"):
error_msg = waveform_dict.get("error", "Unknown error")
logger.error(f"Waveform generation failed: {error_msg}")
raise HTTPException(
- status_code=500,
- detail=f"Waveform generation failed: {error_msg}"
+ status_code=500, detail=f"Waveform generation failed: {error_msg}"
)
# Return generated waveform (already saved to database by generator)
return {
"samples": waveform_dict["samples"],
"sample_rate": waveform_dict["sample_rate"],
- "duration_seconds": waveform_dict["duration_seconds"]
+ "duration_seconds": waveform_dict["duration_seconds"],
}
@router.get("/{conversation_id}/metadata")
async def get_conversation_metadata(
- conversation_id: str,
- current_user: User = Depends(current_active_user)
+ conversation_id: str, current_user: User = Depends(current_active_user)
) -> dict:
"""
Get conversation metadata (duration, etc.) without loading audio.
@@ -270,7 +295,7 @@ async def get_conversation_metadata(
"conversation_id": conversation_id,
"duration": conversation.audio_total_duration or 0.0,
"created_at": conversation.created_at,
- "has_audio": (conversation.audio_total_duration or 0.0) > 0
+ "has_audio": (conversation.audio_total_duration or 0.0) > 0,
}
@@ -278,8 +303,10 @@ async def get_conversation_metadata(
async def get_audio_segment(
conversation_id: str,
start: float = Query(0.0, description="Start time in seconds"),
- duration: Optional[float] = Query(None, description="Duration in seconds (omit for full audio)"),
- current_user: User = Depends(current_active_user)
+ duration: Optional[float] = Query(
+ None, description="Duration in seconds (omit for full audio)"
+ ),
+ current_user: User = Depends(current_active_user),
) -> Response:
"""
Get audio segment from a conversation.
@@ -297,6 +324,7 @@ async def get_audio_segment(
WAV audio bytes (16kHz, mono) for the requested time range
"""
import time
+
request_start = time.time()
# Verify conversation exists and user has access
@@ -314,7 +342,9 @@ async def get_audio_segment(
# Calculate end time
total_duration = conversation.audio_total_duration or 0.0
if total_duration == 0:
- raise HTTPException(status_code=404, detail="No audio available for this conversation")
+ raise HTTPException(
+ status_code=404, detail="No audio available for this conversation"
+ )
if duration is None:
end = total_duration
@@ -323,18 +353,23 @@ async def get_audio_segment(
# Validate time range
if start < 0 or start >= total_duration:
- raise HTTPException(status_code=400, detail=f"Invalid start time: {start}s (max: {total_duration}s)")
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid start time: {start}s (max: {total_duration}s)",
+ )
# Get audio chunks for time range
try:
wav_bytes = await reconstruct_audio_segment(
- conversation_id=conversation_id,
- start_time=start,
- end_time=end
+ conversation_id=conversation_id, start_time=start, end_time=end
)
except Exception as e:
- logger.error(f"Failed to reconstruct audio segment for {conversation_id[:12]}: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to reconstruct audio: {str(e)}")
+ logger.error(
+ f"Failed to reconstruct audio segment for {conversation_id[:12]}: {e}"
+ )
+ raise HTTPException(
+ status_code=500, detail=f"Failed to reconstruct audio: {str(e)}"
+ )
request_time = time.time() - request_start
logger.info(
@@ -351,15 +386,14 @@ async def get_audio_segment(
"Content-Disposition": f"attachment; filename=segment_{start}_{end}.wav",
"X-Audio-Start": str(start),
"X-Audio-End": str(end),
- "X-Audio-Duration": str(end - start)
- }
+ "X-Audio-Duration": str(end - start),
+ },
)
@router.post("/{conversation_id}/star")
async def toggle_star(
- conversation_id: str,
- current_user: User = Depends(current_active_user)
+ conversation_id: str, current_user: User = Depends(current_active_user)
):
"""Toggle the starred/favorite status of a conversation."""
return await conversation_controller.toggle_star(conversation_id, current_user)
@@ -369,16 +403,19 @@ async def toggle_star(
async def delete_conversation(
conversation_id: str,
permanent: bool = Query(False, description="Permanently delete (admin only)"),
- current_user: User = Depends(current_active_user)
+ current_user: User = Depends(current_active_user),
):
"""Soft delete a conversation (or permanently delete if admin)."""
- return await conversation_controller.delete_conversation(conversation_id, current_user, permanent)
+ return await conversation_controller.delete_conversation(
+ conversation_id, current_user, permanent
+ )
@router.post("/{conversation_id}/restore")
async def restore_conversation(
- conversation_id: str,
- current_user: User = Depends(current_active_user)
+ conversation_id: str, current_user: User = Depends(current_active_user)
):
"""Restore a soft-deleted conversation."""
- return await conversation_controller.restore_conversation(conversation_id, current_user)
+ return await conversation_controller.restore_conversation(
+ conversation_id, current_user
+ )
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py
index cd4bff97..4ea42139 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/health_routes.py
@@ -48,7 +48,7 @@
else:
_llm_def = _embed_def = _vs_def = None
-QDRANT_BASE_URL = (_vs_def.model_params.get("host") if _vs_def else "qdrant")
+QDRANT_BASE_URL = _vs_def.model_params.get("host") if _vs_def else "qdrant"
QDRANT_PORT = str(_vs_def.model_params.get("port") if _vs_def else "6333")
@@ -58,7 +58,7 @@ async def auth_health_check():
try:
# Test database connectivity
await mongo_client.admin.command("ping")
-
+
# Test memory service if available
if memory_service:
try:
@@ -69,12 +69,12 @@ async def auth_health_check():
memory_status = "degraded"
else:
memory_status = "unavailable"
-
+
return {
"status": "ok",
- "database": "ok",
+ "database": "ok",
"memory_service": memory_status,
- "timestamp": int(time.time())
+ "timestamp": int(time.time()),
}
except Exception as e:
logger.error(f"Auth health check failed: {e}")
@@ -84,8 +84,8 @@ async def auth_health_check():
"status": "error",
"detail": "Service connectivity check failed",
"error_type": "connection_failure",
- "timestamp": int(time.time())
- }
+ "timestamp": int(time.time()),
+ },
)
@@ -129,16 +129,18 @@ async def health_check():
else "Not configured"
),
"transcription_provider": (
- REGISTRY.get_default("stt").name if REGISTRY and REGISTRY.get_default("stt")
+ REGISTRY.get_default("stt").name
+ if REGISTRY and REGISTRY.get_default("stt")
else "not configured"
),
-
"provider_type": (
transcription_provider.mode if transcription_provider else "none"
),
"chunk_dir": str(os.getenv("CHUNK_DIR", "./audio_chunks")),
"active_clients": get_client_manager().get_client_count(),
- "new_conversation_timeout_minutes": float(os.getenv("NEW_CONVERSATION_TIMEOUT_MINUTES", "1.5")),
+ "new_conversation_timeout_minutes": float(
+ os.getenv("NEW_CONVERSATION_TIMEOUT_MINUTES", "1.5")
+ ),
"llm_provider": (_llm_def.model_provider if _llm_def else None),
"llm_model": (_llm_def.model_name if _llm_def else None),
"llm_base_url": (_llm_def.model_url if _llm_def else None),
@@ -204,14 +206,14 @@ async def health_check():
"worker_count": worker_count,
"active_workers": active_workers,
"idle_workers": idle_workers,
- "queues": queue_health.get("queues", {})
+ "queues": queue_health.get("queues", {}),
}
else:
health_status["services"]["redis"] = {
"status": f"β Connection Failed: {queue_health.get('redis_connection')}",
"healthy": False,
"critical": True,
- "worker_count": 0
+ "worker_count": 0,
}
overall_healthy = False
critical_services_healthy = False
@@ -221,7 +223,7 @@ async def health_check():
"status": "β Connection Timeout (5s)",
"healthy": False,
"critical": True,
- "worker_count": 0
+ "worker_count": 0,
}
overall_healthy = False
critical_services_healthy = False
@@ -230,7 +232,7 @@ async def health_check():
"status": f"β Connection Failed: {str(e)}",
"healthy": False,
"critical": True,
- "worker_count": 0
+ "worker_count": 0,
}
overall_healthy = False
critical_services_healthy = False
@@ -267,7 +269,9 @@ async def health_check():
if memory_provider == "chronicle":
try:
# Test Chronicle memory service connection with timeout
- test_success = await asyncio.wait_for(memory_service.test_connection(), timeout=8.0)
+ test_success = await asyncio.wait_for(
+ memory_service.test_connection(), timeout=8.0
+ )
if test_success:
health_status["services"]["memory_service"] = {
"status": "β
Chronicle Memory Connected",
@@ -361,7 +365,8 @@ async def health_check():
# Make a health check request to the speaker service
async with aiohttp.ClientSession() as session:
async with session.get(
- f"{speaker_service_url}/health", timeout=aiohttp.ClientTimeout(total=5)
+ f"{speaker_service_url}/health",
+ timeout=aiohttp.ClientTimeout(total=5),
) as response:
if response.status == 200:
health_status["services"]["speaker_recognition"] = {
@@ -401,7 +406,8 @@ async def health_check():
# Make a health check request to the OpenMemory MCP service
async with aiohttp.ClientSession() as session:
async with session.get(
- f"{openmemory_mcp_url}/api/v1/apps/", timeout=aiohttp.ClientTimeout(total=5)
+ f"{openmemory_mcp_url}/api/v1/apps/",
+ timeout=aiohttp.ClientTimeout(total=5),
) as response:
if response.status == 200:
health_status["services"]["openmemory_mcp"] = {
@@ -464,7 +470,9 @@ async def health_check():
if not service["healthy"] and not service.get("critical", True)
]
if unhealthy_optional:
- messages.append(f"Optional services unavailable: {', '.join(unhealthy_optional)}")
+ messages.append(
+ f"Optional services unavailable: {', '.join(unhealthy_optional)}"
+ )
health_status["message"] = "; ".join(messages)
@@ -476,15 +484,21 @@ async def readiness_check():
"""Simple readiness check for container orchestration."""
# Use debug level for health check to reduce log spam
logger.debug("Readiness check requested")
-
+
# Only check critical services for readiness
try:
# Quick MongoDB ping to ensure we can serve requests
await asyncio.wait_for(mongo_client.admin.command("ping"), timeout=2.0)
- return JSONResponse(content={"status": "ready", "timestamp": int(time.time())}, status_code=200)
+ return JSONResponse(
+ content={"status": "ready", "timestamp": int(time.time())}, status_code=200
+ )
except Exception as e:
logger.error(f"Readiness check failed: {e}")
return JSONResponse(
- content={"status": "not_ready", "error": str(e), "timestamp": int(time.time())},
- status_code=503
+ content={
+ "status": "not_ready",
+ "error": str(e),
+ "timestamp": int(time.time()),
+ },
+ status_code=503,
)
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py
index 1b8ae1cf..6fcebea1 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/knowledge_graph_routes.py
@@ -37,6 +37,7 @@
class UpdateEntityRequest(BaseModel):
"""Request model for updating entity fields."""
+
name: Optional[str] = None
details: Optional[str] = None
icon: Optional[str] = None
@@ -44,6 +45,7 @@ class UpdateEntityRequest(BaseModel):
class UpdatePromiseRequest(BaseModel):
"""Request model for updating promise status."""
+
status: str # pending, in_progress, completed, cancelled
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/memory_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/memory_routes.py
index 409f7b85..89ec6091 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/memory_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/memory_routes.py
@@ -21,6 +21,7 @@
class AddMemoryRequest(BaseModel):
"""Request model for adding a memory."""
+
content: str
source_id: Optional[str] = None
@@ -29,7 +30,9 @@ class AddMemoryRequest(BaseModel):
async def get_memories(
current_user: User = Depends(current_active_user),
limit: int = Query(default=50, ge=1, le=1000),
- user_id: Optional[str] = Query(default=None, description="User ID filter (admin only)"),
+ user_id: Optional[str] = Query(
+ default=None, description="User ID filter (admin only)"
+ ),
):
"""Get memories. Users see only their own memories, admins can see all or filter by user."""
return await memory_controller.get_memories(current_user, limit, user_id)
@@ -39,10 +42,14 @@ async def get_memories(
async def get_memories_with_transcripts(
current_user: User = Depends(current_active_user),
limit: int = Query(default=50, ge=1, le=1000),
- user_id: Optional[str] = Query(default=None, description="User ID filter (admin only)"),
+ user_id: Optional[str] = Query(
+ default=None, description="User ID filter (admin only)"
+ ),
):
"""Get memories with their source transcripts. Users see only their own memories, admins can see all or filter by user."""
- return await memory_controller.get_memories_with_transcripts(current_user, limit, user_id)
+ return await memory_controller.get_memories_with_transcripts(
+ current_user, limit, user_id
+ )
@router.get("/search")
@@ -50,30 +57,44 @@ async def search_memories(
query: str = Query(..., description="Search query"),
current_user: User = Depends(current_active_user),
limit: int = Query(default=20, ge=1, le=100),
- score_threshold: float = Query(default=0.0, ge=0.0, le=1.0, description="Minimum similarity score (0.0 = no threshold)"),
- user_id: Optional[str] = Query(default=None, description="User ID filter (admin only)"),
+ score_threshold: float = Query(
+ default=0.0,
+ ge=0.0,
+ le=1.0,
+ description="Minimum similarity score (0.0 = no threshold)",
+ ),
+ user_id: Optional[str] = Query(
+ default=None, description="User ID filter (admin only)"
+ ),
):
"""Search memories by text query with configurable similarity threshold. Users can only search their own memories, admins can search all or filter by user."""
- return await memory_controller.search_memories(query, current_user, limit, score_threshold, user_id)
+ return await memory_controller.search_memories(
+ query, current_user, limit, score_threshold, user_id
+ )
@router.post("")
async def add_memory(
- request: AddMemoryRequest,
- current_user: User = Depends(current_active_user)
+ request: AddMemoryRequest, current_user: User = Depends(current_active_user)
):
"""Add a memory directly from content text. The service will extract structured memories from the provided content."""
- return await memory_controller.add_memory(request.content, current_user, request.source_id)
+ return await memory_controller.add_memory(
+ request.content, current_user, request.source_id
+ )
@router.delete("/{memory_id}")
-async def delete_memory(memory_id: str, current_user: User = Depends(current_active_user)):
+async def delete_memory(
+ memory_id: str, current_user: User = Depends(current_active_user)
+):
"""Delete a memory by ID. Users can only delete their own memories, admins can delete any."""
return await memory_controller.delete_memory(memory_id, current_user)
@router.get("/admin")
-async def get_all_memories_admin(current_user: User = Depends(current_superuser), limit: int = 200):
+async def get_all_memories_admin(
+ current_user: User = Depends(current_superuser), limit: int = 200
+):
"""Get all memories across all users for admin review. Admin only."""
return await memory_controller.get_all_memories_admin(current_user, limit)
@@ -82,7 +103,9 @@ async def get_all_memories_admin(current_user: User = Depends(current_superuser)
async def get_memory_by_id(
memory_id: str,
current_user: User = Depends(current_active_user),
- user_id: Optional[str] = Query(default=None, description="User ID filter (admin only)"),
+ user_id: Optional[str] = Query(
+ default=None, description="User ID filter (admin only)"
+ ),
):
"""Get a single memory by ID. Users can only access their own memories, admins can access any."""
return await memory_controller.get_memory_by_id(memory_id, current_user, user_id)
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/obsidian_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/obsidian_routes.py
index b02ed426..8d76b947 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/obsidian_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/obsidian_routes.py
@@ -1,4 +1,3 @@
-
import json
import logging
import os
@@ -25,20 +24,23 @@
router = APIRouter(prefix="/obsidian", tags=["obsidian"])
+
class IngestRequest(BaseModel):
vault_path: str
+
@router.post("/ingest")
async def ingest_obsidian_vault(
- request: IngestRequest,
- current_user: User = Depends(current_active_user)
+ request: IngestRequest, current_user: User = Depends(current_active_user)
):
"""
Immediate/synchronous ingestion endpoint (legacy). Not recommended for UI.
Prefer the upload_zip + start endpoints to enable progress reporting.
"""
if not os.path.exists(request.vault_path):
- raise HTTPException(status_code=400, detail=f"Path not found: {request.vault_path}")
+ raise HTTPException(
+ status_code=400, detail=f"Path not found: {request.vault_path}"
+ )
try:
result = await obsidian_service.ingest_vault(request.vault_path)
@@ -50,15 +52,16 @@ async def ingest_obsidian_vault(
@router.post("/upload_zip")
async def upload_obsidian_zip(
- file: UploadFile = File(...),
- current_user: User = Depends(current_superuser)
+ file: UploadFile = File(...), current_user: User = Depends(current_superuser)
):
"""
Upload a zipped Obsidian vault. Returns a job_id that can be started later.
Uses upload_files_async pattern from upload_files.py for proper file handling.
"""
- if not file.filename.lower().endswith('.zip'):
- raise HTTPException(status_code=400, detail="Please upload a .zip file of your Obsidian vault")
+ if not file.filename.lower().endswith(".zip"):
+ raise HTTPException(
+ status_code=400, detail="Please upload a .zip file of your Obsidian vault"
+ )
job_id = str(uuid.uuid4())
base_dir = Path("/app/data/obsidian_jobs")
@@ -67,21 +70,23 @@ async def upload_obsidian_zip(
job_dir.mkdir(parents=True, exist_ok=True)
zip_path = job_dir / "vault.zip"
extract_dir = job_dir / "vault"
-
+
# Use upload_files_async pattern for proper file handling with cleanup
zip_file_handle = None
try:
# Read file content
file_content = await file.read()
-
+
# Save zip file using proper file handling pattern from upload_files_async
try:
- zip_file_handle = open(zip_path, 'wb')
+ zip_file_handle = open(zip_path, "wb")
zip_file_handle.write(file_content)
except IOError as e:
logger.error(f"Error writing zip file {zip_path}: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to save uploaded zip: {e}")
-
+ raise HTTPException(
+ status_code=500, detail=f"Failed to save uploaded zip: {e}"
+ )
+
# Extract zip file using utility function
try:
extract_zip(zip_path, extract_dir)
@@ -90,10 +95,12 @@ async def upload_obsidian_zip(
raise HTTPException(status_code=400, detail=f"Invalid zip file: {e}")
except ZipExtractionError as e:
logger.error(f"Error extracting zip file: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to extract zip file: {e}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to extract zip file: {e}"
+ )
total = count_markdown_files(str(extract_dir))
-
+
# Store pending job state in Redis
pending_state = {
"status": "ready",
@@ -101,16 +108,20 @@ async def upload_obsidian_zip(
"processed": 0,
"errors": [],
"vault_path": str(extract_dir),
- "job_id": job_id
+ "job_id": job_id,
}
- redis_conn.set(f"obsidian_pending:{job_id}", json.dumps(pending_state), ex=3600*24) # 24h expiry
+ redis_conn.set(
+ f"obsidian_pending:{job_id}", json.dumps(pending_state), ex=3600 * 24
+ ) # 24h expiry
return {"job_id": job_id, "vault_path": str(extract_dir), "total_files": total}
except HTTPException:
raise
except Exception as e:
logger.exception(f"Failed to process uploaded zip: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to process uploaded zip: {e}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to process uploaded zip: {e}"
+ )
finally:
# Ensure file handle is closed (following upload_files_async pattern)
if zip_file_handle:
@@ -123,17 +134,17 @@ async def upload_obsidian_zip(
@router.post("/start")
async def start_ingestion(
job_id: str = Body(..., embed=True),
- current_user: User = Depends(current_active_user)
+ current_user: User = Depends(current_active_user),
):
# Check if job is pending
pending_key = f"obsidian_pending:{job_id}"
pending_data = redis_conn.get(pending_key)
-
+
if pending_data:
try:
job_data = json.loads(pending_data)
vault_path = job_data.get("vault_path")
-
+
# Enqueue to RQ
rq_job = default_queue.enqueue(
ingest_obsidian_vault_job,
@@ -141,27 +152,31 @@ async def start_ingestion(
vault_path, # arg2
job_id=job_id, # Set RQ job ID to match our ID
description=f"Obsidian ingestion for job {job_id}",
- job_timeout=3600 # 1 hour timeout
+ job_timeout=3600, # 1 hour timeout
)
-
+
# Remove pending key
redis_conn.delete(pending_key)
-
- return {"message": "Ingestion started", "job_id": job_id, "rq_job_id": rq_job.id}
+
+ return {
+ "message": "Ingestion started",
+ "job_id": job_id,
+ "rq_job_id": rq_job.id,
+ }
except Exception as e:
logger.exception(f"Failed to start job {job_id}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to start job: {e}")
-
+
# Check if already in RQ
try:
job = Job.fetch(job_id, connection=redis_conn)
status = job.get_status()
if status in ("queued", "started", "deferred", "scheduled"):
- raise HTTPException(status_code=400, detail=f"Job already {status}")
-
+ raise HTTPException(status_code=400, detail=f"Job already {status}")
+
# If finished/failed, we could potentially restart? But for now let's say it's done.
raise HTTPException(status_code=400, detail=f"Job is in state: {status}")
-
+
except NoSuchJobError:
raise HTTPException(status_code=404, detail="Job not found")
@@ -171,7 +186,7 @@ async def get_status(job_id: str, current_user: User = Depends(current_active_us
# 1. Try RQ first
try:
job = Job.fetch(job_id, connection=redis_conn)
-
+
# Get status
status = job.get_status()
if status == "started":
@@ -181,13 +196,18 @@ async def get_status(job_id: str, current_user: User = Depends(current_active_us
meta = job.meta or {}
# If meta has status, prefer it (for granular updates)
- if "status" in meta and meta["status"] in ("running", "finished", "failed", "canceled"):
- status = meta["status"]
+ if "status" in meta and meta["status"] in (
+ "running",
+ "finished",
+ "failed",
+ "canceled",
+ ):
+ status = meta["status"]
total = meta.get("total_files", 0)
processed = meta.get("processed", 0)
percent = int((processed / total) * 100) if total else 0
-
+
return {
"job_id": job_id,
"status": status,
@@ -196,14 +216,14 @@ async def get_status(job_id: str, current_user: User = Depends(current_active_us
"percent": percent,
"errors": meta.get("errors", []),
"vault_path": meta.get("vault_path"),
- "rq_job_id": job.id
+ "rq_job_id": job.id,
}
-
+
except NoSuchJobError:
# 2. Check pending
pending_key = f"obsidian_pending:{job_id}"
pending_data = redis_conn.get(pending_key)
-
+
if pending_data:
try:
job_data = json.loads(pending_data)
@@ -214,10 +234,8 @@ async def get_status(job_id: str, current_user: User = Depends(current_active_us
"processed": 0,
"percent": 0,
"errors": [],
- "vault_path": job_data.get("vault_path")
+ "vault_path": job_data.get("vault_path"),
}
except:
raise HTTPException(status_code=500, detail="Failed to get job status")
raise HTTPException(status_code=404, detail="Job not found")
-
-
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py
index 29719566..745321f5 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/queue_routes.py
@@ -73,7 +73,9 @@ async def list_jobs(
@router.get("/jobs/{job_id}/status")
-async def get_job_status(job_id: str, current_user: User = Depends(current_active_user)):
+async def get_job_status(
+ job_id: str, current_user: User = Depends(current_active_user)
+):
"""Get just the status of a specific job (lightweight endpoint)."""
try:
job = Job.fetch(job_id, connection=redis_conn)
@@ -193,7 +195,9 @@ async def cancel_job(job_id: str, current_user: User = Depends(current_active_us
@router.get("/jobs/by-client/{client_id}")
-async def get_jobs_by_client(client_id: str, current_user: User = Depends(current_active_user)):
+async def get_jobs_by_client(
+ client_id: str, current_user: User = Depends(current_active_user)
+):
"""Get all jobs associated with a specific client device."""
try:
from rq.registry import (
@@ -242,11 +246,17 @@ def process_job_and_dependents(job, queue_name, base_status):
all_jobs.append(
{
"job_id": job.id,
- "job_type": (job.func_name.split(".")[-1] if job.func_name else "unknown"),
+ "job_type": (
+ job.func_name.split(".")[-1] if job.func_name else "unknown"
+ ),
"queue": queue_name,
"status": status,
- "created_at": (job.created_at.isoformat() if job.created_at else None),
- "started_at": (job.started_at.isoformat() if job.started_at else None),
+ "created_at": (
+ job.created_at.isoformat() if job.created_at else None
+ ),
+ "started_at": (
+ job.started_at.isoformat() if job.started_at else None
+ ),
"ended_at": job.ended_at.isoformat() if job.ended_at else None,
"description": job.description or "",
"result": job.result,
@@ -323,13 +333,17 @@ def process_job_and_dependents(job, queue_name, base_status):
# Sort by created_at
all_jobs.sort(key=lambda x: x["created_at"] or "", reverse=False)
- logger.info(f"Found {len(all_jobs)} jobs for client {client_id} (including dependents)")
+ logger.info(
+ f"Found {len(all_jobs)} jobs for client {client_id} (including dependents)"
+ )
return {"client_id": client_id, "jobs": all_jobs, "total": len(all_jobs)}
except Exception as e:
logger.error(f"Failed to get jobs for client {client_id}: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to get jobs for client: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to get jobs for client: {str(e)}"
+ )
@router.get("/events")
@@ -349,7 +363,9 @@ async def get_events(
if not router_instance:
return {"events": [], "total": 0}
- events = router_instance.get_recent_events(limit=limit, event_type=event_type or None)
+ events = router_instance.get_recent_events(
+ limit=limit, event_type=event_type or None
+ )
return {"events": events, "total": len(events)}
except Exception as e:
logger.error(f"Failed to get events: {e}")
@@ -467,7 +483,9 @@ async def get_queue_worker_details(current_user: User = Depends(current_active_u
except Exception as e:
logger.error(f"Failed to get queue worker details: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to get worker details: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to get worker details: {str(e)}"
+ )
@router.get("/streams")
@@ -498,7 +516,9 @@ async def get_stream_stats(
async def get_stream_info(stream_key):
try:
- stream_name = stream_key.decode() if isinstance(stream_key, bytes) else stream_key
+ stream_name = (
+ stream_key.decode() if isinstance(stream_key, bytes) else stream_key
+ )
# Get basic stream info
info = await audio_service.redis.xinfo_stream(stream_name)
@@ -558,7 +578,9 @@ async def get_stream_info(stream_key):
"name": group_dict.get("name", "unknown"),
"consumers": group_dict.get("consumers", 0),
"pending": group_dict.get("pending", 0),
- "last_delivered_id": group_dict.get("last-delivered-id", "N/A"),
+ "last_delivered_id": group_dict.get(
+ "last-delivered-id", "N/A"
+ ),
"consumer_details": consumers,
}
)
@@ -569,7 +591,9 @@ async def get_stream_info(stream_key):
"stream_name": stream_name,
"length": info[b"length"],
"first_entry_id": (
- info[b"first-entry"][0].decode() if info[b"first-entry"] else None
+ info[b"first-entry"][0].decode()
+ if info[b"first-entry"]
+ else None
),
"last_entry_id": (
info[b"last-entry"][0].decode() if info[b"last-entry"] else None
@@ -581,7 +605,9 @@ async def get_stream_info(stream_key):
return None
# Fetch all stream info in parallel
- streams_info_results = await asyncio.gather(*[get_stream_info(key) for key in stream_keys])
+ streams_info_results = await asyncio.gather(
+ *[get_stream_info(key) for key in stream_keys]
+ )
streams_info = [info for info in streams_info_results if info is not None]
return {
@@ -607,7 +633,9 @@ class FlushAllJobsRequest(BaseModel):
@router.post("/flush")
-async def flush_jobs(request: FlushJobsRequest, current_user: User = Depends(current_active_user)):
+async def flush_jobs(
+ request: FlushJobsRequest, current_user: User = Depends(current_active_user)
+):
"""Flush old inactive jobs based on age and status."""
if not current_user.is_superuser:
raise HTTPException(status_code=403, detail="Admin access required")
@@ -623,7 +651,9 @@ async def flush_jobs(request: FlushJobsRequest, current_user: User = Depends(cur
from advanced_omi_backend.controllers.queue_controller import get_queue
- cutoff_time = datetime.now(timezone.utc) - timedelta(hours=request.older_than_hours)
+ cutoff_time = datetime.now(timezone.utc) - timedelta(
+ hours=request.older_than_hours
+ )
total_removed = 0
# Get all queues
@@ -655,7 +685,9 @@ async def flush_jobs(request: FlushJobsRequest, current_user: User = Depends(cur
except Exception as e:
logger.error(f"Error deleting job {job_id}: {e}")
- if "canceled" in request.statuses: # RQ standard (US spelling), not "cancelled"
+ if (
+ "canceled" in request.statuses
+ ): # RQ standard (US spelling), not "cancelled"
registry = CanceledJobRegistry(queue=queue)
for job_id in registry.get_job_ids():
try:
@@ -737,8 +769,12 @@ async def flush_all_jobs(
registries.append(("finished", FinishedJobRegistry(queue=queue)))
for registry_name, registry in registries:
- job_ids = list(registry.get_job_ids()) # Convert to list to avoid iterator issues
- logger.info(f"Flushing {len(job_ids)} jobs from {queue_name}/{registry_name}")
+ job_ids = list(
+ registry.get_job_ids()
+ ) # Convert to list to avoid iterator issues
+ logger.info(
+ f"Flushing {len(job_ids)} jobs from {queue_name}/{registry_name}"
+ )
for job_id in job_ids:
try:
@@ -748,7 +784,9 @@ async def flush_all_jobs(
# Skip session-level jobs (e.g., speech_detection, audio_persistence)
# These run for the entire session and should not be killed by test cleanup
if job.meta and job.meta.get("session_level"):
- logger.info(f"Skipping session-level job {job_id} ({job.description})")
+ logger.info(
+ f"Skipping session-level job {job_id} ({job.description})"
+ )
continue
# Handle running jobs differently to avoid worker deadlock
@@ -759,7 +797,9 @@ async def flush_all_jobs(
from rq.command import send_stop_job_command
send_stop_job_command(redis_conn, job_id)
- logger.info(f"Sent stop command to worker for job {job_id}")
+ logger.info(
+ f"Sent stop command to worker for job {job_id}"
+ )
# Don't delete yet - let worker move it to canceled/failed registry
# It will be cleaned up on next flush or by worker cleanup
continue
@@ -770,9 +810,13 @@ async def flush_all_jobs(
# If stop fails, try to cancel it (may already be finishing)
try:
job.cancel()
- logger.info(f"Cancelled job {job_id} after stop failed")
+ logger.info(
+ f"Cancelled job {job_id} after stop failed"
+ )
except Exception as cancel_error:
- logger.warning(f"Could not cancel job {job_id}: {cancel_error}")
+ logger.warning(
+ f"Could not cancel job {job_id}: {cancel_error}"
+ )
# For non-running jobs, safe to delete immediately
job.delete()
@@ -787,7 +831,9 @@ async def flush_all_jobs(
f"Removed stale job reference {job_id} from {registry_name} registry"
)
except Exception as reg_error:
- logger.error(f"Could not remove {job_id} from registry: {reg_error}")
+ logger.error(
+ f"Could not remove {job_id} from registry: {reg_error}"
+ )
# Also clean up audio streams and consumer locks
deleted_keys = 0
@@ -801,7 +847,9 @@ async def flush_all_jobs(
# Delete audio streams
cursor = 0
while True:
- cursor, keys = await async_redis.scan(cursor, match="audio:*", count=1000)
+ cursor, keys = await async_redis.scan(
+ cursor, match="audio:*", count=1000
+ )
if keys:
await async_redis.delete(*keys)
deleted_keys += len(keys)
@@ -811,7 +859,9 @@ async def flush_all_jobs(
# Delete consumer locks
cursor = 0
while True:
- cursor, keys = await async_redis.scan(cursor, match="consumer:*", count=1000)
+ cursor, keys = await async_redis.scan(
+ cursor, match="consumer:*", count=1000
+ )
if keys:
await async_redis.delete(*keys)
deleted_keys += len(keys)
@@ -840,7 +890,9 @@ async def flush_all_jobs(
except Exception as e:
logger.error(f"Failed to flush all jobs: {e}")
- raise HTTPException(status_code=500, detail=f"Failed to flush all jobs: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to flush all jobs: {str(e)}"
+ )
@router.get("/sessions")
@@ -860,7 +912,9 @@ async def get_redis_sessions(
session_keys = []
cursor = b"0"
while cursor and len(session_keys) < limit:
- cursor, keys = await redis_client.scan(cursor, match="audio:session:*", count=limit)
+ cursor, keys = await redis_client.scan(
+ cursor, match="audio:session:*", count=limit
+ )
session_keys.extend(keys[: limit - len(session_keys)])
# Get session info
@@ -872,8 +926,12 @@ async def get_redis_sessions(
session_id = key.decode().replace("audio:session:", "")
# Get conversation count for this session
- conversation_count_key = f"session:conversation_count:{session_id}"
- conversation_count_bytes = await redis_client.get(conversation_count_key)
+ conversation_count_key = (
+ f"session:conversation_count:{session_id}"
+ )
+ conversation_count_bytes = await redis_client.get(
+ conversation_count_key
+ )
conversation_count = (
int(conversation_count_bytes.decode())
if conversation_count_bytes
@@ -884,16 +942,25 @@ async def get_redis_sessions(
{
"session_id": session_id,
"user_id": session_data.get(b"user_id", b"").decode(),
- "client_id": session_data.get(b"client_id", b"").decode(),
- "stream_name": session_data.get(b"stream_name", b"").decode(),
+ "client_id": session_data.get(
+ b"client_id", b""
+ ).decode(),
+ "stream_name": session_data.get(
+ b"stream_name", b""
+ ).decode(),
"provider": session_data.get(b"provider", b"").decode(),
"mode": session_data.get(b"mode", b"").decode(),
"status": session_data.get(b"status", b"").decode(),
- "started_at": session_data.get(b"started_at", b"").decode(),
+ "started_at": session_data.get(
+ b"started_at", b""
+ ).decode(),
"chunks_published": int(
- session_data.get(b"chunks_published", b"0").decode() or 0
+ session_data.get(b"chunks_published", b"0").decode()
+ or 0
),
- "last_chunk_at": session_data.get(b"last_chunk_at", b"").decode(),
+ "last_chunk_at": session_data.get(
+ b"last_chunk_at", b""
+ ).decode(),
"conversation_count": conversation_count,
}
)
@@ -936,7 +1003,9 @@ async def clear_old_sessions(
session_keys = []
cursor = b"0"
while cursor:
- cursor, keys = await redis_client.scan(cursor, match="audio:session:*", count=100)
+ cursor, keys = await redis_client.scan(
+ cursor, match="audio:session:*", count=100
+ )
session_keys.extend(keys)
# Check each session and delete if old
@@ -955,13 +1024,18 @@ async def clear_old_sessions(
except Exception as e:
logger.error(f"Error processing session {key}: {e}")
- return {"deleted_count": deleted_count, "cutoff_seconds": older_than_seconds}
+ return {
+ "deleted_count": deleted_count,
+ "cutoff_seconds": older_than_seconds,
+ }
finally:
await redis_client.aclose()
except Exception as e:
logger.error(f"Failed to clear sessions: {e}", exc_info=True)
- raise HTTPException(status_code=500, detail=f"Failed to clear sessions: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to clear sessions: {str(e)}"
+ )
@router.get("/dashboard")
@@ -1013,11 +1087,17 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100):
if status_name == "queued":
job_ids = queue.job_ids[:limit]
elif status_name == "started": # RQ standard, not "processing"
- job_ids = list(StartedJobRegistry(queue=queue).get_job_ids())[:limit]
+ job_ids = list(StartedJobRegistry(queue=queue).get_job_ids())[
+ :limit
+ ]
elif status_name == "finished": # RQ standard, not "completed"
- job_ids = list(FinishedJobRegistry(queue=queue).get_job_ids())[:limit]
+ job_ids = list(FinishedJobRegistry(queue=queue).get_job_ids())[
+ :limit
+ ]
elif status_name == "failed":
- job_ids = list(FailedJobRegistry(queue=queue).get_job_ids())[:limit]
+ job_ids = list(FailedJobRegistry(queue=queue).get_job_ids())[
+ :limit
+ ]
else:
continue
@@ -1028,7 +1108,9 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100):
# Check user permission
if not current_user.is_superuser:
- job_user_id = job.kwargs.get("user_id") if job.kwargs else None
+ job_user_id = (
+ job.kwargs.get("user_id") if job.kwargs else None
+ )
if job_user_id != str(current_user.user_id):
continue
@@ -1037,24 +1119,38 @@ async def fetch_jobs_by_status(status_name: str, limit: int = 100):
{
"job_id": job.id,
"job_type": (
- job.func_name.split(".")[-1] if job.func_name else "unknown"
+ job.func_name.split(".")[-1]
+ if job.func_name
+ else "unknown"
+ ),
+ "user_id": (
+ job.kwargs.get("user_id")
+ if job.kwargs
+ else None
),
- "user_id": (job.kwargs.get("user_id") if job.kwargs else None),
"status": status_name,
"priority": "normal", # RQ doesn't have priority concept
"data": {"description": job.description or ""},
"result": job.result,
"meta": job.meta if job.meta else {},
"kwargs": job.kwargs if job.kwargs else {},
- "error_message": (str(job.exc_info) if job.exc_info else None),
+ "error_message": (
+ str(job.exc_info) if job.exc_info else None
+ ),
"created_at": (
- job.created_at.isoformat() if job.created_at else None
+ job.created_at.isoformat()
+ if job.created_at
+ else None
),
"started_at": (
- job.started_at.isoformat() if job.started_at else None
+ job.started_at.isoformat()
+ if job.started_at
+ else None
),
"ended_at": (
- job.ended_at.isoformat() if job.ended_at else None
+ job.ended_at.isoformat()
+ if job.ended_at
+ else None
),
"retry_count": 0, # RQ doesn't track this by default
"max_retries": 0,
@@ -1170,7 +1266,11 @@ def get_job_status(job):
# Check user permission
if not current_user.is_superuser:
- job_user_id = job.kwargs.get("user_id") if job.kwargs else None
+ job_user_id = (
+ job.kwargs.get("user_id")
+ if job.kwargs
+ else None
+ )
if job_user_id != str(current_user.user_id):
continue
@@ -1186,13 +1286,19 @@ def get_job_status(job):
"queue": queue_name,
"status": get_job_status(job),
"created_at": (
- job.created_at.isoformat() if job.created_at else None
+ job.created_at.isoformat()
+ if job.created_at
+ else None
),
"started_at": (
- job.started_at.isoformat() if job.started_at else None
+ job.started_at.isoformat()
+ if job.started_at
+ else None
),
"ended_at": (
- job.ended_at.isoformat() if job.ended_at else None
+ job.ended_at.isoformat()
+ if job.ended_at
+ else None
),
"description": job.description or "",
"result": job.result,
@@ -1255,12 +1361,20 @@ async def fetch_events():
)
queued_jobs = results[0] if not isinstance(results[0], Exception) else []
- started_jobs = results[1] if not isinstance(results[1], Exception) else [] # RQ standard
- finished_jobs = results[2] if not isinstance(results[2], Exception) else [] # RQ standard
+ started_jobs = (
+ results[1] if not isinstance(results[1], Exception) else []
+ ) # RQ standard
+ finished_jobs = (
+ results[2] if not isinstance(results[2], Exception) else []
+ ) # RQ standard
failed_jobs = results[3] if not isinstance(results[3], Exception) else []
- stats = results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0}
+ stats = (
+ results[4] if not isinstance(results[4], Exception) else {"total_jobs": 0}
+ )
streaming_status = (
- results[5] if not isinstance(results[5], Exception) else {"active_sessions": []}
+ results[5]
+ if not isinstance(results[5], Exception)
+ else {"active_sessions": []}
)
events = results[6] if not isinstance(results[6], Exception) else []
recent_conversations = []
@@ -1279,7 +1393,9 @@ async def fetch_events():
{
"conversation_id": conv.conversation_id,
"user_id": str(conv.user_id) if conv.user_id else None,
- "created_at": (conv.created_at.isoformat() if conv.created_at else None),
+ "created_at": (
+ conv.created_at.isoformat() if conv.created_at else None
+ ),
"title": conv.title,
"summary": conv.summary,
"transcript_text": (
@@ -1307,4 +1423,6 @@ async def fetch_events():
except Exception as e:
logger.error(f"Failed to get dashboard data: {e}", exc_info=True)
- raise HTTPException(status_code=500, detail=f"Failed to get dashboard data: {str(e)}")
+ raise HTTPException(
+ status_code=500, detail=f"Failed to get dashboard data: {str(e)}"
+ )
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py
index 277d7dc1..0f580be7 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/system_routes.py
@@ -29,6 +29,7 @@
# Request models for memory config endpoints
class MemoryConfigRequest(BaseModel):
"""Request model for memory configuration validation and updates."""
+
config_yaml: str
@@ -68,8 +69,7 @@ async def get_diarization_settings(current_user: User = Depends(current_superuse
@router.post("/diarization-settings")
async def save_diarization_settings(
- settings: dict,
- current_user: User = Depends(current_superuser)
+ settings: dict, current_user: User = Depends(current_superuser)
):
"""Save diarization settings. Admin only."""
return await system_controller.save_diarization_settings_controller(settings)
@@ -83,32 +83,36 @@ async def get_misc_settings(current_user: User = Depends(current_superuser)):
@router.post("/misc-settings")
async def save_misc_settings(
- settings: dict,
- current_user: User = Depends(current_superuser)
+ settings: dict, current_user: User = Depends(current_superuser)
):
"""Save miscellaneous configuration settings. Admin only."""
return await system_controller.save_misc_settings_controller(settings)
@router.get("/cleanup-settings")
-async def get_cleanup_settings(
- current_user: User = Depends(current_superuser)
-):
+async def get_cleanup_settings(current_user: User = Depends(current_superuser)):
"""Get cleanup configuration settings. Admin only."""
return await system_controller.get_cleanup_settings_controller(current_user)
@router.post("/cleanup-settings")
async def save_cleanup_settings(
- auto_cleanup_enabled: bool = Body(..., description="Enable automatic cleanup of soft-deleted conversations"),
- retention_days: int = Body(..., ge=1, le=365, description="Number of days to keep soft-deleted conversations"),
- current_user: User = Depends(current_superuser)
+ auto_cleanup_enabled: bool = Body(
+ ..., description="Enable automatic cleanup of soft-deleted conversations"
+ ),
+ retention_days: int = Body(
+ ...,
+ ge=1,
+ le=365,
+ description="Number of days to keep soft-deleted conversations",
+ ),
+ current_user: User = Depends(current_superuser),
):
"""Save cleanup configuration settings. Admin only."""
return await system_controller.save_cleanup_settings_controller(
auto_cleanup_enabled=auto_cleanup_enabled,
retention_days=retention_days,
- user=current_user
+ user=current_user,
)
@@ -120,11 +124,12 @@ async def get_speaker_configuration(current_user: User = Depends(current_active_
@router.post("/speaker-configuration")
async def update_speaker_configuration(
- primary_speakers: list[dict],
- current_user: User = Depends(current_active_user)
+ primary_speakers: list[dict], current_user: User = Depends(current_active_user)
):
"""Update current user's primary speakers configuration."""
- return await system_controller.update_speaker_configuration(current_user, primary_speakers)
+ return await system_controller.update_speaker_configuration(
+ current_user, primary_speakers
+ )
@router.get("/enrolled-speakers")
@@ -141,6 +146,7 @@ async def get_speaker_service_status(current_user: User = Depends(current_superu
# LLM Operations Configuration Endpoints
+
@router.get("/admin/llm-operations")
async def get_llm_operations(current_user: User = Depends(current_superuser)):
"""Get LLM operation configurations. Admin only."""
@@ -149,8 +155,7 @@ async def get_llm_operations(current_user: User = Depends(current_superuser)):
@router.post("/admin/llm-operations")
async def save_llm_operations(
- operations: dict,
- current_user: User = Depends(current_superuser)
+ operations: dict, current_user: User = Depends(current_superuser)
):
"""Save LLM operation configurations. Admin only."""
return await system_controller.save_llm_operations(operations)
@@ -159,7 +164,7 @@ async def save_llm_operations(
@router.post("/admin/llm-operations/test")
async def test_llm_model(
model_name: Optional[str] = Body(None, embed=True),
- current_user: User = Depends(current_superuser)
+ current_user: User = Depends(current_superuser),
):
"""Test an LLM model connection with a trivial prompt. Admin only."""
return await system_controller.test_llm_model(model_name)
@@ -171,10 +176,11 @@ async def get_memory_config_raw(current_user: User = Depends(current_superuser))
"""Get memory configuration YAML from config.yml. Admin only."""
return await system_controller.get_memory_config_raw()
+
@router.post("/admin/memory/config/raw")
async def update_memory_config_raw(
config_yaml: str = Body(..., media_type="text/plain"),
- current_user: User = Depends(current_superuser)
+ current_user: User = Depends(current_superuser),
):
"""Save memory YAML to config.yml and hot-reload. Admin only."""
return await system_controller.update_memory_config_raw(config_yaml)
@@ -191,8 +197,7 @@ async def validate_memory_config_raw(
@router.post("/admin/memory/config/validate")
async def validate_memory_config(
- request: MemoryConfigRequest,
- current_user: User = Depends(current_superuser)
+ request: MemoryConfigRequest, current_user: User = Depends(current_superuser)
):
"""Validate memory configuration YAML sent as JSON (used by tests). Admin only."""
return await system_controller.validate_memory_config(request.config_yaml)
@@ -212,6 +217,7 @@ async def delete_all_user_memories(current_user: User = Depends(current_active_u
# Chat Configuration Management Endpoints
+
@router.get("/admin/chat/config", response_class=Response)
async def get_chat_config(current_user: User = Depends(current_superuser)):
"""Get chat configuration as YAML. Admin only."""
@@ -225,13 +231,12 @@ async def get_chat_config(current_user: User = Depends(current_superuser)):
@router.post("/admin/chat/config")
async def save_chat_config(
- request: Request,
- current_user: User = Depends(current_superuser)
+ request: Request, current_user: User = Depends(current_superuser)
):
"""Save chat configuration from YAML. Admin only."""
try:
yaml_content = await request.body()
- yaml_str = yaml_content.decode('utf-8')
+ yaml_str = yaml_content.decode("utf-8")
result = await system_controller.save_chat_config_yaml(yaml_str)
return JSONResponse(content=result)
except ValueError as e:
@@ -243,13 +248,12 @@ async def save_chat_config(
@router.post("/admin/chat/config/validate")
async def validate_chat_config(
- request: Request,
- current_user: User = Depends(current_superuser)
+ request: Request, current_user: User = Depends(current_superuser)
):
"""Validate chat configuration YAML. Admin only."""
try:
yaml_content = await request.body()
- yaml_str = yaml_content.decode('utf-8')
+ yaml_str = yaml_content.decode("utf-8")
result = await system_controller.validate_chat_config_yaml(yaml_str)
return JSONResponse(content=result)
except Exception as e:
@@ -259,6 +263,7 @@ async def validate_chat_config(
# Plugin Configuration Management Endpoints
+
@router.get("/admin/plugins/config", response_class=Response)
async def get_plugins_config(current_user: User = Depends(current_superuser)):
"""Get plugins configuration as YAML. Admin only."""
@@ -272,13 +277,12 @@ async def get_plugins_config(current_user: User = Depends(current_superuser)):
@router.post("/admin/plugins/config")
async def save_plugins_config(
- request: Request,
- current_user: User = Depends(current_superuser)
+ request: Request, current_user: User = Depends(current_superuser)
):
"""Save plugins configuration from YAML. Admin only."""
try:
yaml_content = await request.body()
- yaml_str = yaml_content.decode('utf-8')
+ yaml_str = yaml_content.decode("utf-8")
result = await system_controller.save_plugins_config_yaml(yaml_str)
return JSONResponse(content=result)
except ValueError as e:
@@ -290,13 +294,12 @@ async def save_plugins_config(
@router.post("/admin/plugins/config/validate")
async def validate_plugins_config(
- request: Request,
- current_user: User = Depends(current_superuser)
+ request: Request, current_user: User = Depends(current_superuser)
):
"""Validate plugins configuration YAML. Admin only."""
try:
yaml_content = await request.body()
- yaml_str = yaml_content.decode('utf-8')
+ yaml_str = yaml_content.decode("utf-8")
result = await system_controller.validate_plugins_config_yaml(yaml_str)
return JSONResponse(content=result)
except Exception as e:
@@ -306,6 +309,7 @@ async def validate_plugins_config(
# Structured Plugin Configuration Endpoints (Form-based UI)
+
@router.post("/admin/plugins/reload")
async def reload_plugins(
request: Request,
@@ -360,9 +364,16 @@ async def get_plugins_health(current_user: User = Depends(current_superuser)):
"""Get plugin health status for all registered plugins. Admin only."""
try:
from advanced_omi_backend.services.plugin_service import get_plugin_router
+
plugin_router = get_plugin_router()
if not plugin_router:
- return {"total": 0, "initialized": 0, "failed": 0, "registered": 0, "plugins": []}
+ return {
+ "total": 0,
+ "initialized": 0,
+ "failed": 0,
+ "registered": 0,
+ "plugins": [],
+ }
return plugin_router.get_health_summary()
except Exception as e:
logger.error(f"Failed to get plugins health: {e}")
@@ -377,6 +388,7 @@ async def get_plugins_connectivity(current_user: User = Depends(current_superuse
"""
try:
from advanced_omi_backend.services.plugin_service import get_plugin_router
+
plugin_router = get_plugin_router()
if not plugin_router:
return {"plugins": {}}
@@ -406,6 +418,7 @@ async def get_plugins_metadata(current_user: User = Depends(current_superuser)):
class PluginConfigRequest(BaseModel):
"""Request model for structured plugin configuration updates."""
+
orchestration: Optional[dict] = None
settings: Optional[dict] = None
env_vars: Optional[dict] = None
@@ -413,6 +426,7 @@ class PluginConfigRequest(BaseModel):
class CreatePluginRequest(BaseModel):
"""Request model for creating a new plugin."""
+
plugin_name: str
description: str
events: list[str] = []
@@ -421,12 +435,14 @@ class CreatePluginRequest(BaseModel):
class WritePluginCodeRequest(BaseModel):
"""Request model for writing plugin code."""
+
code: str
config_yml: Optional[str] = None
class PluginAssistantRequest(BaseModel):
"""Request model for plugin assistant chat."""
+
messages: list[dict]
@@ -434,7 +450,7 @@ class PluginAssistantRequest(BaseModel):
async def update_plugin_config_structured(
plugin_id: str,
config: PluginConfigRequest,
- current_user: User = Depends(current_superuser)
+ current_user: User = Depends(current_superuser),
):
"""Update plugin configuration from structured JSON (form data). Admin only.
@@ -445,7 +461,9 @@ async def update_plugin_config_structured(
"""
try:
config_dict = config.dict(exclude_none=True)
- result = await system_controller.update_plugin_config_structured(plugin_id, config_dict)
+ result = await system_controller.update_plugin_config_structured(
+ plugin_id, config_dict
+ )
return JSONResponse(content=result)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -458,7 +476,7 @@ async def update_plugin_config_structured(
async def test_plugin_connection(
plugin_id: str,
config: PluginConfigRequest,
- current_user: User = Depends(current_superuser)
+ current_user: User = Depends(current_superuser),
):
"""Test plugin connection/configuration without saving. Admin only.
@@ -490,7 +508,9 @@ async def create_plugin(
plugin_code=request.plugin_code,
)
if not result.get("success"):
- raise HTTPException(status_code=400, detail=result.get("error", "Unknown error"))
+ raise HTTPException(
+ status_code=400, detail=result.get("error", "Unknown error")
+ )
return JSONResponse(content=result)
except HTTPException:
raise
@@ -513,7 +533,9 @@ async def write_plugin_code(
config_yml=request.config_yml,
)
if not result.get("success"):
- raise HTTPException(status_code=400, detail=result.get("error", "Unknown error"))
+ raise HTTPException(
+ status_code=400, detail=result.get("error", "Unknown error")
+ )
return JSONResponse(content=result)
except HTTPException:
raise
@@ -535,7 +557,9 @@ async def delete_plugin(
remove_files=remove_files,
)
if not result.get("success"):
- raise HTTPException(status_code=400, detail=result.get("error", "Unknown error"))
+ raise HTTPException(
+ status_code=400, detail=result.get("error", "Unknown error")
+ )
return JSONResponse(content=result)
except HTTPException:
raise
@@ -572,25 +596,34 @@ async def event_stream():
@router.get("/streaming/status")
-async def get_streaming_status(request: Request, current_user: User = Depends(current_superuser)):
+async def get_streaming_status(
+ request: Request, current_user: User = Depends(current_superuser)
+):
"""Get status of active streaming sessions and Redis Streams health. Admin only."""
return await session_controller.get_streaming_status(request)
@router.post("/streaming/cleanup")
-async def cleanup_stuck_stream_workers(request: Request, current_user: User = Depends(current_superuser)):
+async def cleanup_stuck_stream_workers(
+ request: Request, current_user: User = Depends(current_superuser)
+):
"""Clean up stuck Redis Stream workers and pending messages. Admin only."""
return await queue_controller.cleanup_stuck_stream_workers(request)
@router.post("/streaming/cleanup-sessions")
-async def cleanup_old_sessions(request: Request, max_age_seconds: int = 3600, current_user: User = Depends(current_superuser)):
+async def cleanup_old_sessions(
+ request: Request,
+ max_age_seconds: int = 3600,
+ current_user: User = Depends(current_superuser),
+):
"""Clean up old session tracking metadata. Admin only."""
return await session_controller.cleanup_old_sessions(request, max_age_seconds)
# Memory Provider Configuration Endpoints
+
@router.get("/admin/memory/provider")
async def get_memory_provider(current_user: User = Depends(current_superuser)):
"""Get current memory provider configuration. Admin only."""
@@ -600,7 +633,7 @@ async def get_memory_provider(current_user: User = Depends(current_superuser)):
@router.post("/admin/memory/provider")
async def set_memory_provider(
provider: str = Body(..., embed=True),
- current_user: User = Depends(current_superuser)
+ current_user: User = Depends(current_superuser),
):
"""Set memory provider and restart backend services. Admin only."""
return await system_controller.set_memory_provider(provider)
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py
index 349fe33d..a5488450 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/test_routes.py
@@ -37,7 +37,7 @@ async def clear_test_plugin_events():
# Clear events from all plugins that have storage
for plugin_id, plugin in plugin_router.plugins.items():
- if hasattr(plugin, 'storage') and plugin.storage:
+ if hasattr(plugin, "storage") and plugin.storage:
try:
cleared = await plugin.storage.clear_events()
total_cleared += cleared
@@ -45,10 +45,7 @@ async def clear_test_plugin_events():
except Exception as e:
logger.error(f"Error clearing events from plugin '{plugin_id}': {e}")
- return {
- "message": "Test plugin events cleared",
- "events_cleared": total_cleared
- }
+ return {"message": "Test plugin events cleared", "events_cleared": total_cleared}
@router.get("/plugins/events/count")
@@ -65,23 +62,33 @@ async def get_test_plugin_event_count(event_type: Optional[str] = None):
plugin_router = get_plugin_router()
if not plugin_router:
- return {"count": 0, "event_type": event_type, "message": "No plugin router initialized"}
+ return {
+ "count": 0,
+ "event_type": event_type,
+ "message": "No plugin router initialized",
+ }
# Get count from first plugin with storage (usually test_event plugin)
for plugin_id, plugin in plugin_router.plugins.items():
- if hasattr(plugin, 'storage') and plugin.storage:
+ if hasattr(plugin, "storage") and plugin.storage:
try:
count = await plugin.storage.get_event_count(event_type)
return {
"count": count,
"event_type": event_type,
- "plugin_id": plugin_id
+ "plugin_id": plugin_id,
}
except Exception as e:
- logger.error(f"Error getting event count from plugin '{plugin_id}': {e}")
+ logger.error(
+ f"Error getting event count from plugin '{plugin_id}': {e}"
+ )
raise HTTPException(status_code=500, detail=str(e))
- return {"count": 0, "event_type": event_type, "message": "No plugin with storage found"}
+ return {
+ "count": 0,
+ "event_type": event_type,
+ "message": "No plugin with storage found",
+ }
@router.get("/plugins/events")
@@ -102,7 +109,7 @@ async def get_test_plugin_events(event_type: Optional[str] = None):
# Get events from first plugin with storage
for plugin_id, plugin in plugin_router.plugins.items():
- if hasattr(plugin, 'storage') and plugin.storage:
+ if hasattr(plugin, "storage") and plugin.storage:
try:
if event_type:
events = await plugin.storage.get_events_by_type(event_type)
@@ -113,7 +120,7 @@ async def get_test_plugin_events(event_type: Optional[str] = None):
"events": events,
"count": len(events),
"event_type": event_type,
- "plugin_id": plugin_id
+ "plugin_id": plugin_id,
}
except Exception as e:
logger.error(f"Error getting events from plugin '{plugin_id}': {e}")
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/user_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/user_routes.py
index 12ed5c63..f66b5d39 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/user_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/user_routes.py
@@ -24,13 +24,17 @@ async def get_users(current_user: User = Depends(current_superuser)):
@router.post("")
-async def create_user(user_data: UserCreate, current_user: User = Depends(current_superuser)):
+async def create_user(
+ user_data: UserCreate, current_user: User = Depends(current_superuser)
+):
"""Create a new user. Admin only."""
return await user_controller.create_user(user_data)
@router.put("/{user_id}")
-async def update_user(user_id: str, user_data: UserUpdate, current_user: User = Depends(current_superuser)):
+async def update_user(
+ user_id: str, user_data: UserUpdate, current_user: User = Depends(current_superuser)
+):
"""Update a user. Admin only."""
return await user_controller.update_user(user_id, user_data)
@@ -43,4 +47,6 @@ async def delete_user(
delete_memories: bool = False,
):
"""Delete a user and optionally their associated data. Admin only."""
- return await user_controller.delete_user(user_id, delete_conversations, delete_memories)
+ return await user_controller.delete_user(
+ user_id, delete_conversations, delete_memories
+ )
diff --git a/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py b/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py
index 4b244343..11a24f88 100644
--- a/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py
+++ b/backends/advanced/src/advanced_omi_backend/routers/modules/websocket_routes.py
@@ -19,6 +19,7 @@
# Create router
router = APIRouter(tags=["websocket"])
+
@router.websocket("/ws")
async def ws_endpoint(
ws: WebSocket,
@@ -42,11 +43,13 @@ async def ws_endpoint(
codec = codec.lower()
if codec not in ["pcm", "opus"]:
logger.warning(f"Unsupported codec requested: {codec}")
- await ws.close(code=1008, reason=f"Unsupported codec: {codec}. Supported: pcm, opus")
+ await ws.close(
+ code=1008, reason=f"Unsupported codec: {codec}. Supported: pcm, opus"
+ )
return
# Route to appropriate handler
if codec == "opus":
await handle_omi_websocket(ws, token, device_name)
else:
- await handle_pcm_websocket(ws, token, device_name)
\ No newline at end of file
+ await handle_pcm_websocket(ws, token, device_name)
diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_service.py b/backends/advanced/src/advanced_omi_backend/services/audio_service.py
index 992ede75..09914e24 100644
--- a/backends/advanced/src/advanced_omi_backend/services/audio_service.py
+++ b/backends/advanced/src/advanced_omi_backend/services/audio_service.py
@@ -42,9 +42,9 @@ def __init__(self, redis_url: Optional[str] = None):
self.memory_events_stream = "memory:events"
# Consumer group names (action verbs - what they DO)
- self.audio_writer = "audio-file-writer" # Writes audio chunks to WAV files
- self.memory_enqueuer = "memory-job-enqueuer" # Enqueues memory extraction jobs
- self.event_listener = "event-listener" # Listens for completion events
+ self.audio_writer = "audio-file-writer" # Writes audio chunks to WAV files
+ self.memory_enqueuer = "memory-job-enqueuer" # Enqueues memory extraction jobs
+ self.event_listener = "event-listener" # Listens for completion events
async def connect(self):
"""Connect to Redis with connection pooling."""
@@ -55,7 +55,7 @@ async def connect(self):
max_connections=20, # Allow multiple concurrent operations
socket_keepalive=True,
socket_connect_timeout=5,
- retry_on_timeout=True
+ retry_on_timeout=True,
)
logger.info(f"Audio stream service connected to Redis at {self.redis_url}")
@@ -84,7 +84,7 @@ async def publish_audio_chunk(
user_email: str,
audio_chunk: AudioChunk,
audio_uuid: Optional[str] = None,
- timestamp: Optional[int] = None
+ timestamp: Optional[int] = None,
) -> str:
"""
Publish audio chunk to Redis Stream.
@@ -135,12 +135,11 @@ async def publish_audio_chunk(
# Ensure consumer group exists for this stream
try:
await self.redis.xgroup_create(
- stream_name,
- self.audio_writer,
- id="0",
- mkstream=True
+ stream_name, self.audio_writer, id="0", mkstream=True
+ )
+ audio_logger.debug(
+ f"Created consumer group {self.audio_writer} for {stream_name}"
)
- audio_logger.debug(f"Created consumer group {self.audio_writer} for {stream_name}")
except aioredis.ResponseError as e:
if "BUSYGROUP" not in str(e):
raise
@@ -152,7 +151,7 @@ async def publish_transcript_event(
audio_uuid: str,
conversation_id: str,
status: str,
- error: Optional[str] = None
+ error: Optional[str] = None,
):
"""
Publish transcript completion event.
@@ -189,7 +188,7 @@ async def publish_transcript_event(
self.transcript_events_stream,
self.memory_enqueuer,
id="0",
- mkstream=True
+ mkstream=True,
)
except aioredis.ResponseError as e:
if "BUSYGROUP" not in str(e):
@@ -200,7 +199,7 @@ async def publish_memory_event(
conversation_id: str,
status: str,
memory_count: int = 0,
- error: Optional[str] = None
+ error: Optional[str] = None,
):
"""
Publish memory processing event.
@@ -234,21 +233,14 @@ async def publish_memory_event(
# Ensure consumer group exists
try:
await self.redis.xgroup_create(
- self.memory_events_stream,
- self.event_listener,
- id="0",
- mkstream=True
+ self.memory_events_stream, self.event_listener, id="0", mkstream=True
)
except aioredis.ResponseError as e:
if "BUSYGROUP" not in str(e):
raise
async def consume_audio_stream(
- self,
- consumer_name: str,
- callback,
- block_ms: int = 5000,
- count: int = 10
+ self, consumer_name: str, callback, block_ms: int = 5000, count: int = 10
):
"""
Consume audio chunks from all client streams.
@@ -288,7 +280,7 @@ async def consume_audio_stream(
consumer_name,
streams_dict,
count=count,
- block=block_ms
+ block=block_ms,
)
for stream_name, stream_messages in messages:
@@ -299,15 +291,13 @@ async def consume_audio_stream(
# Acknowledge message
await self.redis.xack(
- stream_name,
- self.audio_writer,
- message_id
+ stream_name, self.audio_writer, message_id
)
except Exception as e:
logger.error(
f"Error processing audio message {message_id.decode()}: {e}",
- exc_info=True
+ exc_info=True,
)
except Exception as e:
diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py b/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py
index 19b76874..4cfd5a90 100644
--- a/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py
+++ b/backends/advanced/src/advanced_omi_backend/services/audio_stream/aggregator.py
@@ -49,8 +49,12 @@ async def get_session_results(self, session_id: str) -> list[dict]:
"text": fields[b"text"].decode(),
"confidence": float(fields[b"confidence"].decode()),
"provider": fields[b"provider"].decode(),
- "chunk_id": fields.get(b"chunk_id", b"unknown").decode(), # Handle missing chunk_id gracefully
- "processing_time": float(fields.get(b"processing_time", b"0.0").decode()),
+ "chunk_id": fields.get(
+ b"chunk_id", b"unknown"
+ ).decode(), # Handle missing chunk_id gracefully
+ "processing_time": float(
+ fields.get(b"processing_time", b"0.0").decode()
+ ),
"timestamp": float(fields[b"timestamp"].decode()),
}
@@ -104,7 +108,7 @@ async def get_combined_results(self, session_id: str) -> dict:
"segments": [],
"chunk_count": 0,
"total_confidence": 0.0,
- "provider": None
+ "provider": None,
}
# Combine ALL final results for cumulative speech detection
@@ -150,7 +154,7 @@ async def get_combined_results(self, session_id: str) -> dict:
"segments": all_segments,
"chunk_count": len(results),
"total_confidence": avg_confidence,
- "provider": provider
+ "provider": provider,
}
logger.info(
@@ -162,10 +166,7 @@ async def get_combined_results(self, session_id: str) -> dict:
return combined
async def get_realtime_results(
- self,
- session_id: str,
- last_id: str = "0",
- timeout_ms: int = 1000
+ self, session_id: str, last_id: str = "0", timeout_ms: int = 1000
) -> tuple[list[dict], str]:
"""
Get new results since last_id (for real-time streaming).
@@ -183,9 +184,7 @@ async def get_realtime_results(
try:
# Read new messages since last_id
messages = await self.redis_client.xread(
- {stream_name: last_id},
- count=10,
- block=timeout_ms
+ {stream_name: last_id}, count=10, block=timeout_ms
)
results = []
@@ -199,14 +198,18 @@ async def get_realtime_results(
"text": fields[b"text"].decode(),
"confidence": float(fields[b"confidence"].decode()),
"provider": fields[b"provider"].decode(),
- "chunk_id": fields.get(b"chunk_id", b"unknown").decode(), # Handle missing chunk_id gracefully
+ "chunk_id": fields.get(
+ b"chunk_id", b"unknown"
+ ).decode(), # Handle missing chunk_id gracefully
}
# Optional fields
if b"words" in fields:
result["words"] = json.loads(fields[b"words"].decode())
if b"segments" in fields:
- result["segments"] = json.loads(fields[b"segments"].decode())
+ result["segments"] = json.loads(
+ fields[b"segments"].decode()
+ )
results.append(result)
new_last_id = message_id.decode()
@@ -214,5 +217,7 @@ async def get_realtime_results(
return results, new_last_id
except Exception as e:
- logger.error(f"π Error getting realtime results for session {session_id}: {e}")
+ logger.error(
+ f"π Error getting realtime results for session {session_id}: {e}"
+ )
return [], last_id
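`get_realtime_results` above is only reflowed, but its contract is worth spelling out: it XREADs new entries from `transcription:results:{session_id}` starting at the caller-supplied `last_id` and returns the advanced cursor together with the decoded results. A rough caller loop under that contract (how the aggregator instance is constructed is assumed to happen elsewhere):

```python
# Sketch of a caller that tails transcription results via get_realtime_results().
# Only the cursor protocol matters here; construction of `aggregator` is assumed.
import asyncio


async def follow_session(aggregator, session_id: str) -> None:
    last_id = "0"  # start from the beginning of the results stream
    while True:
        # Returns (new_results, new_last_id); the server blocks up to timeout_ms.
        results, last_id = await aggregator.get_realtime_results(
            session_id, last_id=last_id, timeout_ms=1000
        )
        for result in results:
            print(f"[{result['provider']}] {result['text']}")
        if not results:
            await asyncio.sleep(0.1)  # nothing new; avoid a tight loop
```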
diff --git a/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py b/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py
index 455ebebe..8902ae5d 100644
--- a/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py
+++ b/backends/advanced/src/advanced_omi_backend/services/audio_stream/consumer.py
@@ -23,7 +23,9 @@ class BaseAudioStreamConsumer(ABC):
Writes results to transcription:results:{session_id}.
"""
- def __init__(self, provider_name: str, redis_client: redis.Redis, buffer_chunks: int = 30):
+ def __init__(
+ self, provider_name: str, redis_client: redis.Redis, buffer_chunks: int = 30
+ ):
"""
Initialize consumer.
@@ -50,7 +52,9 @@ def __init__(self, provider_name: str, redis_client: redis.Redis, buffer_chunks:
self.active_streams = {} # {stream_name: True}
# Buffering: accumulate chunks per session
- self.session_buffers = {} # {session_id: {"chunks": [], "chunk_ids": [], "sample_rate": int}}
+ self.session_buffers = (
+ {}
+ ) # {session_id: {"chunks": [], "chunk_ids": [], "sample_rate": int}}
async def discover_streams(self) -> list[str]:
"""
@@ -67,7 +71,9 @@ async def discover_streams(self) -> list[str]:
cursor, match=self.stream_pattern, count=100
)
if keys:
- streams.extend([k.decode() if isinstance(k, bytes) else k for k in keys])
+ streams.extend(
+ [k.decode() if isinstance(k, bytes) else k for k in keys]
+ )
return streams
@@ -76,16 +82,17 @@ async def setup_consumer_group(self, stream_name: str):
# Create consumer group (ignore error if already exists)
try:
await self.redis_client.xgroup_create(
- stream_name,
- self.group_name,
- "0",
- mkstream=True
+ stream_name, self.group_name, "0", mkstream=True
+ )
+ logger.debug(
+ f"β‘οΈ Created consumer group {self.group_name} for {stream_name}"
)
- logger.debug(f"β‘οΈ Created consumer group {self.group_name} for {stream_name}")
except redis_exceptions.ResponseError as e:
if "BUSYGROUP" not in str(e):
raise
- logger.debug(f"β‘οΈ Consumer group {self.group_name} already exists for {stream_name}")
+ logger.debug(
+ f"β‘οΈ Consumer group {self.group_name} already exists for {stream_name}"
+ )
async def cleanup_dead_consumers(self, idle_threshold_ms: int = 30000):
"""
@@ -100,7 +107,7 @@ async def cleanup_dead_consumers(self, idle_threshold_ms: int = 30000):
try:
# Get all consumers in the group
consumers = await self.redis_client.execute_command(
- 'XINFO', 'CONSUMERS', self.input_stream, self.group_name
+ "XINFO", "CONSUMERS", self.input_stream, self.group_name
)
if not consumers:
@@ -115,9 +122,13 @@ async def cleanup_dead_consumers(self, idle_threshold_ms: int = 30000):
# Parse consumer fields (flat key-value pairs within each consumer)
for j in range(0, len(consumer_info), 2):
- if j+1 < len(consumer_info):
- key = consumer_info[j].decode() if isinstance(consumer_info[j], bytes) else str(consumer_info[j])
- value = consumer_info[j+1]
+ if j + 1 < len(consumer_info):
+ key = (
+ consumer_info[j].decode()
+ if isinstance(consumer_info[j], bytes)
+ else str(consumer_info[j])
+ )
+ value = consumer_info[j + 1]
if isinstance(value, bytes):
try:
value = value.decode()
@@ -142,11 +153,19 @@ async def cleanup_dead_consumers(self, idle_threshold_ms: int = 30000):
if is_dead:
# If consumer has pending messages, claim and ACK them first
if consumer_pending > 0:
- logger.info(f"π Claiming {consumer_pending} pending messages from dead consumer {consumer_name}")
+ logger.info(
+ f"π Claiming {consumer_pending} pending messages from dead consumer {consumer_name}"
+ )
try:
pending_messages = await self.redis_client.execute_command(
- 'XPENDING', self.input_stream, self.group_name, '-', '+', str(consumer_pending), consumer_name
+ "XPENDING",
+ self.input_stream,
+ self.group_name,
+ "-",
+ "+",
+ str(consumer_pending),
+ consumer_name,
)
# Parse pending messages (groups of 4: msg_id, consumer, idle_ms, delivery_count)
@@ -159,28 +178,49 @@ async def cleanup_dead_consumers(self, idle_threshold_ms: int = 30000):
# Claim to ourselves and ACK immediately
try:
await self.redis_client.execute_command(
- 'XCLAIM', self.input_stream, self.group_name, self.consumer_name, '0', msg_id
+ "XCLAIM",
+ self.input_stream,
+ self.group_name,
+ self.consumer_name,
+ "0",
+ msg_id,
+ )
+ await self.redis_client.xack(
+ self.input_stream, self.group_name, msg_id
)
- await self.redis_client.xack(self.input_stream, self.group_name, msg_id)
claimed_count += 1
except Exception as claim_error:
- logger.warning(f"Failed to claim/ack message {msg_id}: {claim_error}")
+ logger.warning(
+ f"Failed to claim/ack message {msg_id}: {claim_error}"
+ )
except Exception as pending_error:
- logger.warning(f"Failed to process pending messages for {consumer_name}: {pending_error}")
+ logger.warning(
+ f"Failed to process pending messages for {consumer_name}: {pending_error}"
+ )
# Delete the dead consumer
try:
await self.redis_client.execute_command(
- 'XGROUP', 'DELCONSUMER', self.input_stream, self.group_name, consumer_name
+ "XGROUP",
+ "DELCONSUMER",
+ self.input_stream,
+ self.group_name,
+ consumer_name,
)
deleted_count += 1
- logger.info(f"π§Ή Deleted dead consumer {consumer_name} (idle: {consumer_idle_ms}ms)")
+ logger.info(
+ f"π§Ή Deleted dead consumer {consumer_name} (idle: {consumer_idle_ms}ms)"
+ )
except Exception as delete_error:
- logger.warning(f"Failed to delete consumer {consumer_name}: {delete_error}")
+ logger.warning(
+ f"Failed to delete consumer {consumer_name}: {delete_error}"
+ )
if deleted_count > 0 or claimed_count > 0:
- logger.info(f"✅ Cleanup complete: deleted {deleted_count} dead consumers, claimed {claimed_count} pending messages")
+ logger.info(
+ f"✅ Cleanup complete: deleted {deleted_count} dead consumers, claimed {claimed_count} pending messages"
+ )
except Exception as e:
logger.error(f"β Failed to cleanup dead consumers: {e}", exc_info=True)
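The cleanup path above walks XINFO CONSUMERS, claims pending entries from consumers idle past the threshold, ACKs them, and finally removes each dead consumer with XGROUP DELCONSUMER. The raw `execute_command` calls map onto redis-py's typed helpers; a condensed sketch under that assumption, with placeholder stream and group names:

```python
# Condensed sketch of the dead-consumer cleanup above using redis-py helpers
# instead of execute_command(); stream/group/consumer names are placeholders.
import redis.asyncio as aioredis


async def reap_idle_consumers(
    client: aioredis.Redis,
    stream: str = "audio:chunks:demo",
    group: str = "transcription-demo",
    me: str = "consumer-1",
    idle_threshold_ms: int = 30_000,
) -> None:
    for info in await client.xinfo_consumers(stream, group):
        name = info["name"].decode() if isinstance(info["name"], bytes) else info["name"]
        if name == me or info["idle"] < idle_threshold_ms:
            continue  # skip ourselves and consumers that are still active
        if info["pending"]:
            # Claim the dead consumer's pending entries to ourselves, then ACK them
            # so they are not stuck in its pending-entries list forever.
            entries = await client.xpending_range(
                stream, group, "-", "+", info["pending"], consumername=name
            )
            ids = [e["message_id"] for e in entries]
            if ids:
                await client.xclaim(stream, group, me, min_idle_time=0, message_ids=ids)
                await client.xack(stream, group, *ids)
        # Finally drop the dead consumer from the group.
        await client.xgroup_delconsumer(stream, group, name)
```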
@@ -204,7 +244,9 @@ async def transcribe_audio(self, audio_data: bytes, sample_rate: int) -> dict:
async def start_consuming(self):
"""Discover and consume from multiple streams using Redis consumer groups."""
self.running = True
- logger.info(f"β‘οΈ Starting dynamic stream consumer: {self.consumer_name} (group: {self.group_name})")
+ logger.info(
+ f"β‘οΈ Starting dynamic stream consumer: {self.consumer_name} (group: {self.group_name})"
+ )
last_discovery = 0
discovery_interval = 10 # Discover new streams every 10 seconds
@@ -223,7 +265,9 @@ async def start_consuming(self):
# Setup consumer group for this stream (no manual lock needed)
await self.setup_consumer_group(stream_name)
self.active_streams[stream_name] = True
- logger.info(f"✅ Now consuming from {stream_name} (group: {self.group_name})")
+ logger.info(
+ f"✅ Now consuming from {stream_name} (group: {self.group_name})"
+ )
last_discovery = current_time
@@ -241,14 +285,18 @@ async def start_consuming(self):
self.consumer_name,
streams_dict,
count=1,
- block=1000 # Block for 1 second
+ block=1000, # Block for 1 second
)
if not messages:
continue
for stream_name, msgs in messages:
- stream_name_str = stream_name.decode() if isinstance(stream_name, bytes) else stream_name
+ stream_name_str = (
+ stream_name.decode()
+ if isinstance(stream_name, bytes)
+ else stream_name
+ )
for message_id, fields in msgs:
await self.process_message(message_id, fields, stream_name_str)
@@ -260,11 +308,15 @@ async def start_consuming(self):
# Extract stream name from error message
for stream_name in list(self.active_streams.keys()):
if stream_name in error_msg:
- logger.warning(f"β‘οΈ [{self.consumer_name}] Stream {stream_name} was deleted, removing from active streams")
+ logger.warning(
+ f"β‘οΈ [{self.consumer_name}] Stream {stream_name} was deleted, removing from active streams"
+ )
# Remove from active streams
del self.active_streams[stream_name]
- logger.info(f"β‘οΈ [{self.consumer_name}] Removed {stream_name}, {len(self.active_streams)} streams remaining")
+ logger.info(
+ f"β‘οΈ [{self.consumer_name}] Removed {stream_name}, {len(self.active_streams)} streams remaining"
+ )
break
else:
# Other ResponseError - log and continue
@@ -273,7 +325,10 @@ async def start_consuming(self):
await asyncio.sleep(1)
except Exception as e:
- logger.error(f"β‘οΈ [{self.consumer_name}] Error in dynamic consume loop: {e}", exc_info=True)
+ logger.error(
+ f"β‘οΈ [{self.consumer_name}] Error in dynamic consume loop: {e}",
+ exc_info=True,
+ )
await asyncio.sleep(1)
async def process_message(self, message_id: bytes, fields: dict, stream_name: str):
@@ -295,7 +350,9 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st
# Check for end-of-session signal
if chunk_id == "END":
- logger.info(f"β‘οΈ [{self.consumer_name}] {self.provider_name}: Received END signal for session {session_id}")
+ logger.info(
+ f"β‘οΈ [{self.consumer_name}] {self.provider_name}: Received END signal for session {session_id}"
+ )
# Flush buffer for this session if it has any chunks
if session_id in self.session_buffers:
@@ -306,7 +363,9 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st
# Combine buffered chunks
combined_audio = b"".join(buffer["chunks"])
- combined_chunk_id = f"{buffer['chunk_ids'][0]}-{buffer['chunk_ids'][-1]}"
+ combined_chunk_id = (
+ f"{buffer['chunk_ids'][0]}-{buffer['chunk_ids'][-1]}"
+ )
logger.info(
f"β‘οΈ [{self.consumer_name}] {self.provider_name}: Flushing {len(buffer['chunks'])} remaining chunks "
@@ -314,7 +373,9 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st
)
# Transcribe remaining audio
- result = await self.transcribe_audio(combined_audio, buffer["sample_rate"])
+ result = await self.transcribe_audio(
+ combined_audio, buffer["sample_rate"]
+ )
# Store result
processing_time = time.time() - start_time
@@ -325,19 +386,27 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st
confidence=result.get("confidence", 0.0),
words=result.get("words", []),
segments=result.get("segments", []),
- processing_time=processing_time
+ processing_time=processing_time,
)
# ACK all buffered messages
for msg_id in buffer["message_ids"]:
- await self.redis_client.xack(stream_name, self.group_name, msg_id)
+ await self.redis_client.xack(
+ stream_name, self.group_name, msg_id
+ )
# Trim stream to remove ACKed messages (keep only last 1000 for safety)
try:
- await self.redis_client.xtrim(stream_name, maxlen=1000, approximate=True)
- logger.debug(f"π§Ή Trimmed audio stream {stream_name} to max 1000 entries")
+ await self.redis_client.xtrim(
+ stream_name, maxlen=1000, approximate=True
+ )
+ logger.debug(
+ f"π§Ή Trimmed audio stream {stream_name} to max 1000 entries"
+ )
except Exception as trim_error:
- logger.warning(f"Failed to trim stream {stream_name}: {trim_error}")
+ logger.warning(
+ f"Failed to trim stream {stream_name}: {trim_error}"
+ )
logger.info(
f"β‘οΈ [{self.consumer_name}] {self.provider_name}: Flushed buffer for session {session_id} "
@@ -358,7 +427,7 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st
"chunk_ids": [],
"sample_rate": sample_rate,
"message_ids": [],
- "audio_offset_seconds": 0.0 # Track cumulative audio duration
+ "audio_offset_seconds": 0.0, # Track cumulative audio duration
}
# Add to buffer (skip empty audio data from END signals)
@@ -382,14 +451,24 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st
# Combine buffered chunks
combined_audio = b"".join(buffer["chunks"])
- combined_chunk_id = f"{buffer['chunk_ids'][0]}-{buffer['chunk_ids'][-1]}"
+ combined_chunk_id = (
+ f"{buffer['chunk_ids'][0]}-{buffer['chunk_ids'][-1]}"
+ )
# Calculate audio duration for this chunk (16-bit PCM, 1 channel)
- audio_duration_seconds = len(combined_audio) / (sample_rate * 2) # 2 bytes per sample
+ audio_duration_seconds = len(combined_audio) / (
+ sample_rate * 2
+ ) # 2 bytes per sample
audio_offset = buffer["audio_offset_seconds"]
# Log individual chunk IDs to detect duplicates
- chunk_list = ", ".join(buffer['chunk_ids'][:5] + ['...'] + buffer['chunk_ids'][-5:]) if len(buffer['chunk_ids']) > 10 else ", ".join(buffer['chunk_ids'])
+ chunk_list = (
+ ", ".join(
+ buffer["chunk_ids"][:5] + ["..."] + buffer["chunk_ids"][-5:]
+ )
+ if len(buffer["chunk_ids"]) > 10
+ else ", ".join(buffer["chunk_ids"])
+ )
logger.info(
f"β‘οΈ [{self.consumer_name}] {self.provider_name}: Transcribing {len(buffer['chunks'])} chunks "
@@ -415,7 +494,9 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st
adjusted_word["end"] = word.get("end", 0.0) + audio_offset
adjusted_words.append(adjusted_word)
- logger.debug(f"β‘οΈ [{self.consumer_name}] Adjusted {len(adjusted_segments)} segments by +{audio_offset:.1f}s")
+ logger.debug(
+ f"β‘οΈ [{self.consumer_name}] Adjusted {len(adjusted_segments)} segments by +{audio_offset:.1f}s"
+ )
# Store result with adjusted timestamps
processing_time = time.time() - start_time
@@ -426,7 +507,7 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st
confidence=result.get("confidence", 0.0),
words=adjusted_words,
segments=adjusted_segments,
- processing_time=processing_time
+ processing_time=processing_time,
)
# Update audio offset for next chunk
@@ -438,8 +519,12 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st
# Trim stream to remove ACKed messages (keep only last 1000 for safety)
try:
- await self.redis_client.xtrim(stream_name, maxlen=1000, approximate=True)
- logger.debug(f"π§Ή Trimmed audio stream {stream_name} to max 1000 entries")
+ await self.redis_client.xtrim(
+ stream_name, maxlen=1000, approximate=True
+ )
+ logger.debug(
+ f"π§Ή Trimmed audio stream {stream_name} to max 1000 entries"
+ )
except Exception as trim_error:
logger.warning(f"Failed to trim stream {stream_name}: {trim_error}")
@@ -456,7 +541,7 @@ async def process_message(self, message_id: bytes, fields: dict, stream_name: st
except Exception as e:
logger.error(
f"β‘οΈ [{self.consumer_name}] {self.provider_name}: Failed to process chunk {fields.get(b'chunk_id', b'unknown').decode()}: {e}",
- exc_info=True
+ exc_info=True,
)
async def store_result(
@@ -467,7 +552,7 @@ async def store_result(
confidence: float,
words: list,
segments: list,
- processing_time: float
+ processing_time: float,
):
"""
Store transcription result in Redis Stream.
@@ -502,7 +587,7 @@ async def store_result(
session_results_stream,
result_data,
maxlen=1000, # Keep max 1k results per session
- approximate=True
+ approximate=True,
)
logger.debug(
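The buffering hunks above derive each flush's duration from the raw byte count, assuming 16-bit mono PCM (2 bytes per sample), and shift word/segment timestamps by the running `audio_offset_seconds` so later flushes line up with the start of the session. A short worked example of that arithmetic, using an illustrative 16 kHz sample rate:

```python
# Worked example of the 16-bit mono PCM duration/offset arithmetic used above.
SAMPLE_RATE = 16_000   # Hz (illustrative; the real rate comes from the stream fields)
BYTES_PER_SAMPLE = 2   # 16-bit PCM, single channel

chunk = b"\x00\x00" * SAMPLE_RATE * 5                      # 5 s of silence = 160,000 bytes
duration = len(chunk) / (SAMPLE_RATE * BYTES_PER_SAMPLE)   # -> 5.0 seconds
assert duration == 5.0

# Words transcribed from the second flush describe audio that starts after the
# first flush, so their timestamps are shifted by the cumulative offset.
audio_offset = duration
word = {"word": "hello", "start": 0.4, "end": 0.7}
adjusted = {**word, "start": word["start"] + audio_offset, "end": word["end"] + audio_offset}
print(adjusted)  # {'word': 'hello', 'start': 5.4, 'end': 5.7}
```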
diff --git a/backends/advanced/src/advanced_omi_backend/services/capabilities.py b/backends/advanced/src/advanced_omi_backend/services/capabilities.py
index 85837920..fc9d7de0 100644
--- a/backends/advanced/src/advanced_omi_backend/services/capabilities.py
+++ b/backends/advanced/src/advanced_omi_backend/services/capabilities.py
@@ -26,7 +26,9 @@ class TranscriptCapability(str, Enum):
WORD_TIMESTAMPS = "word_timestamps" # Word-level timing data
SEGMENTS = "segments" # Speaker segments in output
- DIARIZATION = "diarization" # Speaker labels in segments (Speaker 0, Speaker 1, etc.)
+ DIARIZATION = (
+ "diarization" # Speaker labels in segments (Speaker 0, Speaker 1, etc.)
+ )
class FeatureRequirement(str, Enum):
@@ -99,7 +101,9 @@ def check_requirements(
return True, "OK"
-def get_provider_capabilities(transcript_version: "Conversation.TranscriptVersion") -> dict:
+def get_provider_capabilities(
+ transcript_version: "Conversation.TranscriptVersion",
+) -> dict:
"""
Get provider capabilities from transcript version metadata.
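`capabilities.py` gates features on what a transcript's provider actually produced (word timestamps, segments, speaker-diarized segments). The full `check_requirements` signature is not visible in this hunk, so the snippet below only illustrates the kind of comparison it performs; the helper name and data shapes are assumptions, not the module's real API:

```python
# Illustration only: compares required capabilities against what a provider
# advertises. The real check_requirements() in capabilities.py may differ.
from enum import Enum


class TranscriptCapability(str, Enum):
    WORD_TIMESTAMPS = "word_timestamps"
    SEGMENTS = "segments"
    DIARIZATION = "diarization"


def has_required(required: set, provided: set) -> tuple:
    missing = required - provided
    if missing:
        return False, "missing: " + ", ".join(sorted(c.value for c in missing))
    return True, "OK"


provided = {TranscriptCapability.WORD_TIMESTAMPS, TranscriptCapability.SEGMENTS}
ok, reason = has_required({TranscriptCapability.DIARIZATION}, provided)
print(ok, reason)  # False missing: diarization
```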
diff --git a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/entity_extractor.py b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/entity_extractor.py
index a12cd540..2955b57f 100644
--- a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/entity_extractor.py
+++ b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/entity_extractor.py
@@ -161,35 +161,45 @@ def _parse_extraction_response(content: str) -> ExtractionResult:
for e in data.get("entities", []):
if isinstance(e, dict) and e.get("name"):
entity_type = _normalize_entity_type(e.get("type", "thing"))
- entities.append(ExtractedEntity(
- name=e["name"].strip(),
- type=entity_type,
- details=e.get("details"),
- icon=e.get("icon") or _get_default_icon(entity_type),
- when=e.get("when"),
- ))
+ entities.append(
+ ExtractedEntity(
+ name=e["name"].strip(),
+ type=entity_type,
+ details=e.get("details"),
+ icon=e.get("icon") or _get_default_icon(entity_type),
+ when=e.get("when"),
+ )
+ )
# Parse relationships
relationships = []
for r in data.get("relationships", []):
if isinstance(r, dict) and r.get("subject") and r.get("object"):
- relationships.append(ExtractedRelationship(
- subject=r["subject"].strip(),
- relation=_normalize_relation_type(r.get("relation", "related_to")),
- object=r["object"].strip(),
- ))
+ relationships.append(
+ ExtractedRelationship(
+ subject=r["subject"].strip(),
+ relation=_normalize_relation_type(
+ r.get("relation", "related_to")
+ ),
+ object=r["object"].strip(),
+ )
+ )
# Parse promises
promises = []
for p in data.get("promises", []):
if isinstance(p, dict) and p.get("action"):
- promises.append(ExtractedPromise(
- action=p["action"].strip(),
- to=p.get("to"),
- deadline=p.get("deadline"),
- ))
-
- logger.info(f"Extracted {len(entities)} entities, {len(relationships)} relationships, {len(promises)} promises")
+ promises.append(
+ ExtractedPromise(
+ action=p["action"].strip(),
+ to=p.get("to"),
+ deadline=p.get("deadline"),
+ )
+ )
+
+ logger.info(
+ f"Extracted {len(entities)} entities, {len(relationships)} relationships, {len(promises)} promises"
+ )
return ExtractionResult(
entities=entities,
relationships=relationships,
@@ -274,7 +284,9 @@ def _get_default_icon(entity_type: str) -> str:
return icons.get(entity_type, "π")
-def parse_natural_datetime(text: Optional[str], reference_date: Optional[datetime] = None) -> Optional[datetime]:
+def parse_natural_datetime(
+ text: Optional[str], reference_date: Optional[datetime] = None
+) -> Optional[datetime]:
"""Parse natural language date/time into datetime.
Args:
@@ -292,9 +304,20 @@ def parse_natural_datetime(text: Optional[str], reference_date: Optional[datetim
# Simple patterns - can be extended with dateparser library later
weekdays = {
- "monday": 0, "tuesday": 1, "wednesday": 2, "thursday": 3,
- "friday": 4, "saturday": 5, "sunday": 6,
- "mon": 0, "tue": 1, "wed": 2, "thu": 3, "fri": 4, "sat": 5, "sun": 6,
+ "monday": 0,
+ "tuesday": 1,
+ "wednesday": 2,
+ "thursday": 3,
+ "friday": 4,
+ "saturday": 5,
+ "sunday": 6,
+ "mon": 0,
+ "tue": 1,
+ "wed": 2,
+ "thu": 3,
+ "fri": 4,
+ "sat": 5,
+ "sun": 6,
}
try:
@@ -318,6 +341,7 @@ def parse_natural_datetime(text: Optional[str], reference_date: Optional[datetim
# Handle "in X days/weeks"
import re
+
match = re.search(r"in (\d+) (day|week|month)s?", text_lower)
if match:
num = int(match.group(1))
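`parse_natural_datetime` resolves weekday names against a reference date and falls back to the `in N day/week/month` regex reflowed above. A stripped-down sketch of that behaviour; the month arithmetic here is a 30-day approximation and the real helper may handle more phrases:

```python
# Stripped-down sketch of the natural-language date parsing shown above; the
# real helper may cover more phrases and use different month arithmetic.
import re
from datetime import datetime, timedelta
from typing import Optional

WEEKDAYS = {"monday": 0, "tuesday": 1, "wednesday": 2, "thursday": 3,
            "friday": 4, "saturday": 5, "sunday": 6}


def parse_simple(text: str, reference: Optional[datetime] = None) -> Optional[datetime]:
    reference = reference or datetime.now()
    text = text.lower().strip()

    if text in WEEKDAYS:
        # Next occurrence of that weekday, always at least one day ahead.
        days_ahead = (WEEKDAYS[text] - reference.weekday()) % 7 or 7
        return reference + timedelta(days=days_ahead)

    match = re.search(r"in (\d+) (day|week|month)s?", text)
    if match:
        num, unit = int(match.group(1)), match.group(2)
        days = {"day": 1, "week": 7, "month": 30}[unit]  # month ~ 30 days (assumption)
        return reference + timedelta(days=num * days)

    return None


print(parse_simple("friday", datetime(2025, 1, 1)))      # 2025-01-03 (Jan 1 2025 is a Wednesday)
print(parse_simple("in 2 weeks", datetime(2025, 1, 1)))  # 2025-01-15
```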
diff --git a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/models.py b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/models.py
index c4cf533c..0b78ba73 100644
--- a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/models.py
+++ b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/models.py
@@ -14,6 +14,7 @@
class EntityType(str, Enum):
"""Supported entity types in the knowledge graph."""
+
PERSON = "person"
PLACE = "place"
ORGANIZATION = "organization"
@@ -26,6 +27,7 @@ class EntityType(str, Enum):
class RelationshipType(str, Enum):
"""Supported relationship types between entities."""
+
MENTIONED_IN = "MENTIONED_IN"
WORKS_AT = "WORKS_AT"
LIVES_IN = "LIVES_IN"
@@ -41,6 +43,7 @@ class RelationshipType(str, Enum):
class PromiseStatus(str, Enum):
"""Status of a promise/task."""
+
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
@@ -55,6 +58,7 @@ class Entity(BaseModel):
Each entity belongs to a specific user and can have relationships with
other entities.
"""
+
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str
type: EntityType
@@ -113,6 +117,7 @@ class Relationship(BaseModel):
Relationships are edges in the graph connecting entities with
typed connections (e.g., WORKS_AT, KNOWS, MENTIONED_IN).
"""
+
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
type: RelationshipType
source_id: str # Entity ID
@@ -164,6 +169,7 @@ class Promise(BaseModel):
Promises are commitments made during conversations that can be
tracked and followed up on.
"""
+
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
user_id: str
action: str # What was promised
@@ -191,7 +197,9 @@ def to_dict(self) -> Dict[str, Any]:
"to_entity_name": self.to_entity_name,
"status": self.status,
"due_date": self.due_date.isoformat() if self.due_date else None,
- "completed_at": self.completed_at.isoformat() if self.completed_at else None,
+ "completed_at": (
+ self.completed_at.isoformat() if self.completed_at else None
+ ),
"source_conversation_id": self.source_conversation_id,
"context": self.context,
"metadata": self.metadata,
@@ -202,6 +210,7 @@ def to_dict(self) -> Dict[str, Any]:
class ExtractedEntity(BaseModel):
"""Entity as extracted by LLM before Neo4j storage."""
+
name: str
type: str # Will be validated against EntityType
details: Optional[str] = None
@@ -212,6 +221,7 @@ class ExtractedEntity(BaseModel):
class ExtractedRelationship(BaseModel):
"""Relationship as extracted by LLM."""
+
subject: str # Entity name or "speaker"
relation: str # Relationship type
object: str # Entity name
@@ -219,6 +229,7 @@ class ExtractedRelationship(BaseModel):
class ExtractedPromise(BaseModel):
"""Promise as extracted by LLM."""
+
action: str
to: Optional[str] = None # Entity name
deadline: Optional[str] = None # Natural language deadline
@@ -226,6 +237,7 @@ class ExtractedPromise(BaseModel):
class ExtractionResult(BaseModel):
"""Result of entity extraction from a conversation."""
+
entities: List[ExtractedEntity] = Field(default_factory=list)
relationships: List[ExtractedRelationship] = Field(default_factory=list)
promises: List[ExtractedPromise] = Field(default_factory=list)
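The `Extracted*` models above are plain Pydantic containers for what the LLM extractor returns before anything is written to Neo4j. A small usage sketch mirroring the fields visible in this diff (the values, and the explicit `icon`, are illustrative):

```python
# Usage sketch for the extraction models; field values are illustrative.
from advanced_omi_backend.services.knowledge_graph.models import (
    ExtractedEntity,
    ExtractedPromise,
    ExtractedRelationship,
    ExtractionResult,
)

result = ExtractionResult(
    entities=[
        ExtractedEntity(name="Acme Corp", type="organization",
                        details="Where Bob works", icon="🏢"),
    ],
    relationships=[
        ExtractedRelationship(subject="speaker", relation="WORKS_AT", object="Acme Corp"),
    ],
    promises=[
        ExtractedPromise(action="send the quarterly report", to="Bob", deadline="friday"),
    ],
)
print(len(result.entities), len(result.relationships), len(result.promises))  # 1 1 1
```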
diff --git a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py
index 5f13508d..0c7bd4a8 100644
--- a/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py
+++ b/backends/advanced/src/advanced_omi_backend/services/knowledge_graph/service.py
@@ -126,7 +126,9 @@ async def process_conversation(
)
if not extraction.entities and not extraction.promises:
- logger.debug(f"No entities extracted from conversation {conversation_id}")
+ logger.debug(
+ f"No entities extracted from conversation {conversation_id}"
+ )
return {"entities": 0, "relationships": 0, "promises": 0}
# Create conversation entity node
@@ -827,6 +829,7 @@ def _parse_metadata(self, value: Any) -> Dict[str, Any]:
if isinstance(value, str):
try:
import json
+
return json.loads(value)
except (json.JSONDecodeError, ValueError):
return {}
diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/README.md b/backends/advanced/src/advanced_omi_backend/services/memory/README.md
index ba6de6a4..49b74bd0 100644
--- a/backends/advanced/src/advanced_omi_backend/services/memory/README.md
+++ b/backends/advanced/src/advanced_omi_backend/services/memory/README.md
@@ -28,50 +28,50 @@ memory/
graph TB
%% User Interface Layer
User[User/Application] --> CompatService[compat_service.py<br/>MemoryService]
-
+
%% Compatibility Layer
CompatService --> CoreService[memory_service.py<br/>CoreMemoryService]
-
+
%% Configuration
Config[config.py<br/>MemoryConfig] --> CoreService
Config --> LLMProviders
Config --> VectorStores
-
+
%% Core Service Layer
CoreService --> Base[base.py<br/>Abstract Interfaces]
-
+
%% Base Abstractions
Base --> MemoryServiceBase[MemoryServiceBase]
- Base --> LLMProviderBase[LLMProviderBase]
+ Base --> LLMProviderBase[LLMProviderBase]
Base --> VectorStoreBase[VectorStoreBase]
Base --> MemoryEntry[MemoryEntry<br/>Data Structure]
-
+
%% Provider Implementations
subgraph LLMProviders[LLM Providers]
OpenAI[OpenAIProvider]
Ollama[OllamaProvider]
Qwen[Qwen3EmbeddingProvider]
end
-
+
subgraph VectorStores[Vector Stores]
Qdrant[QdrantVectorStore]
end
-
+
%% Inheritance relationships
LLMProviderBase -.-> OpenAI
LLMProviderBase -.-> Ollama
VectorStoreBase -.-> Qdrant
-
+
%% Core Service uses providers
CoreService --> LLMProviders
CoreService --> VectorStores
-
+
%% External Services
OpenAI --> OpenAIAPI[OpenAI API]
Ollama --> OllamaAPI[Ollama Server]
Qwen --> LocalModel[Local Qwen Model]
Qdrant --> QdrantDB[(Qdrant Database)]
-
+
%% Memory Flow
subgraph MemoryFlow[Memory Processing Flow]
Transcript[Transcript] --> Extract[Extract Memories<br/>via LLM]
@@ -80,15 +80,15 @@ graph TB
Store --> Search[Semantic Search]
Search --> Results[Memory Results]
end
-
+
CoreService --> MemoryFlow
-
+
%% Styling
classDef interface fill:#e1f5fe,stroke:#01579b,stroke-width:2px
classDef implementation fill:#f3e5f5,stroke:#4a148c,stroke-width:2px
classDef external fill:#fff3e0,stroke:#e65100,stroke-width:2px
classDef data fill:#e8f5e8,stroke:#2e7d32,stroke-width:2px
-
+
class Base,MemoryServiceBase,LLMProviderBase,VectorStoreBase interface
class CompatService,CoreService,OpenAI,Ollama,Qdrant implementation
class OpenAIAPI,OllamaAPI,LocalModel,QdrantDB external
@@ -109,7 +109,7 @@ classDiagram
+string created_at
+__post_init__()
}
-
+
class MemoryServiceBase {
<<abstract>>
[... remainder of the class diagram and the opening of the sequence diagram hunk lost to extraction ...]
participant LLM as LLM Provider (OpenAI/Ollama)
participant Vector as Vector Store (Qdrant)
participant Config as Configuration
-
+
Note over User, Config: Memory Service Initialization
User->>Compat: get_memory_service()
Compat->>Core: __init__(config)
@@ -275,23 +275,23 @@ sequenceDiagram
Vector-->>Core: ready
LLM-->>Core: ready
Core-->>Compat: initialized
-
+
Note over User, Config: Adding Memory from Transcript
User->>Compat: add_memory(transcript, ...)
Compat->>Core: add_memory(transcript, ...)
-
+
Core->>Core: _deduplicate_memories()
Core->>LLM: generate_embeddings(memory_texts)
LLM->>LLM: create vector embeddings
LLM-->>Core: List[embeddings]
-
+
alt Memory Updates Enabled
Core->>Vector: search_memories(embeddings, user_id)
Vector-->>Core: existing_memories
Core->>LLM: propose_memory_actions(old, new)
LLM->>LLM: decide ADD/UPDATE/DELETE
LLM-->>Core: actions_list
-
+
loop For each action
alt Action: ADD
Core->>Core: create MemoryEntry
@@ -306,11 +306,11 @@ sequenceDiagram
Core->>Core: create MemoryEntry objects
Core->>Vector: add_memories(entries)
end
-
+
Vector-->>Core: created_ids
Core-->>Compat: success, memory_ids
Compat-->>User: success, memory_ids
-
+
Note over User, Config: Searching Memories
User->>Compat: search_memories(query, user_id)
Compat->>Core: search_memories(query, user_id)
@@ -380,7 +380,7 @@ await memory_service.initialize()
success, memory_ids = await memory_service.add_memory(
transcript="User discussed their goals for the next quarter.",
client_id="client123",
- audio_uuid="audio456",
+ audio_uuid="audio456",
user_id="user789",
user_email="user@example.com"
)
@@ -412,7 +412,7 @@ success, memory_ids = await service.add_memory(
transcript="My favorite destination is now Tokyo instead of Paris.",
client_id="client123",
audio_uuid="audio456",
- user_id="user789",
+ user_id="user789",
user_email="user@example.com",
allow_update=True # Enable intelligent memory updates
)
@@ -427,11 +427,11 @@ class CustomLLMProvider(LLMProviderBase):
async def extract_memories(self, text: str, prompt: str) -> List[str]:
# Custom implementation
pass
-
+
async def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
- # Custom implementation
+ # Custom implementation
pass
-
+
# ... implement other abstract methods
```
@@ -464,7 +464,7 @@ MEMORY_TIMEOUT_SECONDS=30
```python
from advanced_omi_backend.memory.config import (
MemoryConfig,
- create_openai_config,
+ create_openai_config,
create_qdrant_config
)
@@ -493,7 +493,7 @@ The service supports intelligent memory updates through LLM-driven action propos
- **ADD**: Create new memories for novel information
- **UPDATE**: Modify existing memories with new details
-- **DELETE**: Remove outdated or incorrect memories
+- **DELETE**: Remove outdated or incorrect memories
- **NONE**: No action needed for redundant information
### Example Flow
@@ -594,7 +594,7 @@ The service includes comprehensive error handling:
The modular architecture makes it easy to:
1. **Add new LLM providers**: Inherit from `LLMProviderBase`
-2. **Add new vector stores**: Inherit from `VectorStoreBase`
+2. **Add new vector stores**: Inherit from `VectorStoreBase`
3. **Customize memory logic**: Override `MemoryServiceBase` methods
4. **Add new data formats**: Extend `MemoryEntry` or conversion logic
@@ -618,7 +618,7 @@ logging.getLogger("memory_service").setLevel(logging.INFO)
Log levels:
- **INFO**: Service lifecycle, major operations
-- **DEBUG**: Detailed processing information
+- **DEBUG**: Detailed processing information
- **WARNING**: Recoverable errors, fallbacks
- **ERROR**: Serious errors requiring attention
@@ -694,15 +694,15 @@ class AnthropicProvider(LLMProviderBase):
def __init__(self, config: Dict[str, Any]):
self.api_key = config["api_key"]
self.model = config.get("model", "claude-3-sonnet")
-
+
async def extract_memories(self, text: str, prompt: str) -> List[str]:
# Implement using Anthropic API
pass
-
+
async def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
# Implement using Anthropic embeddings
pass
-
+
# ... implement other abstract methods
```
@@ -781,15 +781,15 @@ class PineconeVectorStore(VectorStoreBase):
self.api_key = config["api_key"]
self.environment = config["environment"]
self.index_name = config["index_name"]
-
+
async def initialize(self) -> None:
# Initialize Pinecone client
pass
-
+
async def add_memories(self, memories: List[MemoryEntry]) -> List[str]:
# Add to Pinecone index
pass
-
+
# ... implement other abstract methods
```
@@ -831,10 +831,10 @@ class CustomMemoryService(CoreMemoryService):
async def add_memory(self, transcript: str, **kwargs):
# Pre-process transcript
processed_transcript = await self._custom_preprocessing(transcript)
-
+
# Call parent method
return await super().add_memory(processed_transcript, **kwargs)
-
+
async def _custom_preprocessing(self, transcript: str) -> str:
# Your custom logic here
return transcript
@@ -909,4 +909,4 @@ Planned improvements:
- Advanced memory summarization
- Multi-modal memory support (images, audio)
- Memory compression and archival
-- Real-time memory streaming
\ No newline at end of file
+- Real-time memory streaming
diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/base.py b/backends/advanced/src/advanced_omi_backend/services/memory/base.py
index bae18e56..3eaf7259 100644
--- a/backends/advanced/src/advanced_omi_backend/services/memory/base.py
+++ b/backends/advanced/src/advanced_omi_backend/services/memory/base.py
@@ -136,7 +136,9 @@ async def search_memories(
pass
@abstractmethod
- async def get_all_memories(self, user_id: str, limit: int = 100) -> List[MemoryEntry]:
+ async def get_all_memories(
+ self, user_id: str, limit: int = 100
+ ) -> List[MemoryEntry]:
"""Get all memories for a specific user.
Args:
@@ -265,7 +267,10 @@ async def reprocess_memory(
@abstractmethod
async def delete_memory(
- self, memory_id: str, user_id: Optional[str] = None, user_email: Optional[str] = None
+ self,
+ memory_id: str,
+ user_id: Optional[str] = None,
+ user_email: Optional[str] = None,
) -> bool:
"""Delete a specific memory by ID.
@@ -341,7 +346,10 @@ class LLMProviderBase(ABC):
@abstractmethod
async def extract_memories(
- self, text: str, prompt: str, user_id: Optional[str] = None,
+ self,
+ text: str,
+ prompt: str,
+ user_id: Optional[str] = None,
langfuse_session_id: Optional[str] = None,
) -> List[str]:
"""Extract meaningful fact memories from text using an LLM.
@@ -469,7 +477,11 @@ async def add_memories(self, memories: List[MemoryEntry]) -> List[str]:
@abstractmethod
async def search_memories(
- self, query_embedding: List[float], user_id: str, limit: int, score_threshold: float = 0.0
+ self,
+ query_embedding: List[float],
+ user_id: str,
+ limit: int,
+ score_threshold: float = 0.0,
) -> List[MemoryEntry]:
"""Search memories using vector similarity.
diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/config.py b/backends/advanced/src/advanced_omi_backend/services/memory/config.py
index 83070706..3050487a 100644
--- a/backends/advanced/src/advanced_omi_backend/services/memory/config.py
+++ b/backends/advanced/src/advanced_omi_backend/services/memory/config.py
@@ -137,7 +137,9 @@ def build_memory_config_from_env() -> MemoryConfig:
# Map legacy provider names to current names
if memory_provider in ("friend-lite", "friend_lite"):
- memory_logger.info(f"π§ Mapping legacy provider '{memory_provider}' to 'chronicle'")
+ memory_logger.info(
+ f"π§ Mapping legacy provider '{memory_provider}' to 'chronicle'"
+ )
memory_provider = "chronicle"
if memory_provider not in [p.value for p in MemoryProvider]:
@@ -178,7 +180,9 @@ def build_memory_config_from_env() -> MemoryConfig:
if not llm_def:
raise ValueError("No default LLM defined in config.yml")
model = llm_def.model_name
- embedding_model = embed_def.model_name if embed_def else "text-embedding-3-small"
+ embedding_model = (
+ embed_def.model_name if embed_def else "text-embedding-3-small"
+ )
base_url = llm_def.model_url
memory_logger.info(
f"π§ Memory config (registry): LLM={model}, Embedding={embedding_model}, Base URL={base_url}"
@@ -201,7 +205,9 @@ def build_memory_config_from_env() -> MemoryConfig:
host = str(vs_def.model_params.get("host", "qdrant"))
port = int(vs_def.model_params.get("port", 6333))
- collection_name = str(vs_def.model_params.get("collection_name", "chronicle_memories"))
+ collection_name = str(
+ vs_def.model_params.get("collection_name", "chronicle_memories")
+ )
vector_store_config = create_qdrant_config(
host=host,
port=port,
@@ -235,7 +241,9 @@ def build_memory_config_from_env() -> MemoryConfig:
)
except ImportError:
- memory_logger.warning("Config loader not available, using environment variables only")
+ memory_logger.warning(
+ "Config loader not available, using environment variables only"
+ )
raise
diff --git a/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py b/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py
index 0e704be3..54c948b2 100644
--- a/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py
+++ b/backends/advanced/src/advanced_omi_backend/services/memory/prompts.py
@@ -33,32 +33,32 @@
DEFAULT_UPDATE_MEMORY_PROMPT = f"""
You are a memory manager for a system.
-You must compare a list of **retrieved facts** with the **existing memory** (an array of `{{id, text}}` objects).
-For each memory item, decide one of four operations: **ADD**, **UPDATE**, **DELETE**, or **NONE**.
+You must compare a list of **retrieved facts** with the **existing memory** (an array of `{{id, text}}` objects).
+For each memory item, decide one of four operations: **ADD**, **UPDATE**, **DELETE**, or **NONE**.
Your output must follow the exact XML format described.
---
## Rules
-1. **ADD**:
+1. **ADD**:
- If a retrieved fact is new (no existing memory on that topic), create a new `
Audio Chunks
Sent: {recording.debugStats.chunksSent}
- Rate: {recording.debugStats.chunksSent > 0 && recording.debugStats.sessionStartTime ? + Rate: {recording.debugStats.chunksSent > 0 && recording.debugStats.sessionStartTime ? Math.round(recording.debugStats.chunksSent / ((Date.now() - recording.debugStats.sessionStartTime.getTime()) / 1000)) : 0}/s
Session
- Duration: {recording.debugStats.sessionStartTime ? + Duration: {recording.debugStats.sessionStartTime ? Math.round((Date.now() - recording.debugStats.sessionStartTime.getTime()) / 1000) + 's' : 'N/A'}
@@ -72,4 +72,4 @@ export default function DebugPanel({ recording }: DebugPanelProps) {
User: {user?.name || user?.email} @@ -66,7 +66,7 @@ export default function RecordingStatus({ recording }: RecordingStatusProps) { {/* Component Status Indicators */}
Audio Chunks
Sent: {recording.debugStats.chunksSent}
- Rate: {recording.debugStats.chunksSent > 0 && recording.debugStats.sessionStartTime ? + Rate: {recording.debugStats.chunksSent > 0 && recording.debugStats.sessionStartTime ? Math.round(recording.debugStats.chunksSent / ((Date.now() - recording.debugStats.sessionStartTime.getTime()) / 1000)) : 0}/s
Session
- Duration: {recording.debugStats.sessionStartTime ? + Duration: {recording.debugStats.sessionStartTime ? Math.round((Date.now() - recording.debugStats.sessionStartTime.getTime()) / 1000) + 's' : 'N/A'}
@@ -72,4 +72,4 @@ export default function SimpleDebugPanel({ recording }: SimpleDebugPanelProps) {