diff --git a/crates/listener-core/src/actors/recorder/mod.rs b/crates/listener-core/src/actors/recorder/mod.rs index d40412762f..30dc08473b 100644 --- a/crates/listener-core/src/actors/recorder/mod.rs +++ b/crates/listener-core/src/actors/recorder/mod.rs @@ -25,6 +25,7 @@ pub enum RecMsg { pub struct RecArgs { pub app_dir: PathBuf, pub session_id: String, + pub done_tx: Option>, } pub struct RecState { @@ -34,6 +35,7 @@ pub struct RecState { wav_path: PathBuf, last_flush: Instant, is_stereo: bool, + done_tx: Option>, } pub struct RecorderActor { @@ -134,6 +136,7 @@ impl Actor for RecorderActor { wav_path, last_flush: Instant::now(), is_stereo, + done_tx: args.done_tx, }) } @@ -207,6 +210,9 @@ impl Actor for RecorderActor { } } + if let Some(tx) = st.done_tx.take() { + let _ = tx.send(()); + } Ok(()) } } diff --git a/crates/listener-core/src/actors/session/supervisor.rs b/crates/listener-core/src/actors/session/supervisor.rs index b49eab930f..cf749b1f92 100644 --- a/crates/listener-core/src/actors/session/supervisor.rs +++ b/crates/listener-core/src/actors/session/supervisor.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use hypr_supervisor::{RestartBudget, RestartTracker, RetryStrategy, spawn_with_retry}; use ractor::concurrency::Duration; use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, SupervisionEvent}; @@ -34,6 +36,7 @@ pub struct SessionState { source_cell: Option, listener_cell: Option, recorder_cell: Option, + recorder_done: Option>, source_restarts: RestartTracker, recorder_restarts: RestartTracker, shutting_down: bool, @@ -74,20 +77,22 @@ impl Actor for SessionActor { ) .await?; - let recorder_cell = if ctx.params.record_enabled { + let (recorder_cell, recorder_done) = if ctx.params.record_enabled { + let (done_tx, done_rx) = tokio::sync::oneshot::channel(); let (recorder_ref, _): (ActorRef, _) = Actor::spawn_linked( Some(RecorderActor::name()), RecorderActor::new(), RecArgs { app_dir: ctx.app_dir.clone(), session_id: ctx.params.session_id.clone(), + done_tx: Some(done_tx), }, myself.get_cell(), ) .await?; - Some(recorder_ref.get_cell()) + (Some(recorder_ref.get_cell()), Some(done_rx)) } else { - None + (None, None) }; Ok(SessionState { @@ -95,6 +100,7 @@ impl Actor for SessionActor { source_cell: Some(source_ref.get_cell()), listener_cell: None, recorder_cell, + recorder_done, source_restarts: RestartTracker::new(), recorder_restarts: RestartTracker::new(), shutting_down: false, @@ -170,8 +176,9 @@ impl Actor for SessionActor { state.shutting_down = true; if let Some(cell) = state.recorder_cell.take() { + let done = state.recorder_done.take(); cell.stop(Some("session_stop".to_string())); - lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await; + wait_for_recorder_done(done).await; } if let Some(cell) = state.source_cell.take() { @@ -367,11 +374,14 @@ async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionSta let sup = supervisor_cell; let app_dir = state.ctx.app_dir.clone(); let session_id = state.ctx.params.session_id.clone(); + let (done_tx, done_rx) = tokio::sync::oneshot::channel(); + let done_tx = Arc::new(std::sync::Mutex::new(Some(done_tx))); let cell = spawn_with_retry(&RETRY_STRATEGY, || { let sup = sup.clone(); let app_dir = app_dir.clone(); let session_id = session_id.clone(); + let done_tx = done_tx.lock().unwrap().take(); async move { let (r, _): (ActorRef, _) = Actor::spawn_linked( Some(RecorderActor::name()), @@ -379,6 +389,7 @@ async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionSta RecArgs { app_dir, session_id, + done_tx, }, sup, ) @@ -391,6 +402,7 @@ async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionSta match cell { Some(c) => { state.recorder_cell = Some(c); + state.recorder_done = Some(done_rx); true } None => false, @@ -407,12 +419,24 @@ async fn meltdown(myself: ActorRef, state: &mut SessionState) { cell.stop(Some("meltdown".to_string())); } if let Some(cell) = state.recorder_cell.take() { + let done = state.recorder_done.take(); cell.stop(Some("meltdown".to_string())); - lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await; + wait_for_recorder_done(done).await; } myself.stop(Some("restart_limit_exceeded".to_string())); } +async fn wait_for_recorder_done(done: Option>) { + match done { + Some(rx) => { + tokio::time::timeout(Duration::from_secs(30), rx).await.ok(); + } + None => { + lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await; + } + } +} + fn classify_connection_failure(base_url: &str) -> String { if base_url.contains("localhost") || base_url.contains("127.0.0.1") { "Local transcription server is not running".to_string() diff --git a/crates/owhisper-client/src/providers.rs b/crates/owhisper-client/src/providers.rs index 3f3581c0ba..37e86d3344 100644 --- a/crates/owhisper-client/src/providers.rs +++ b/crates/owhisper-client/src/providers.rs @@ -357,6 +357,42 @@ impl Provider { } } + pub fn translate_control_message( + &self, + msg: &owhisper_interface::ControlMessage, + ) -> Option { + use crate::adapter::RealtimeSttAdapter; + use hypr_ws_client::client::Message; + use owhisper_interface::ControlMessage; + + fn extract_text(msg: Message) -> Option { + match msg { + Message::Text(t) => Some(t.to_string()), + _ => None, + } + } + + fn from_adapter(adapter: &impl RealtimeSttAdapter, msg: &ControlMessage) -> Option { + match msg { + ControlMessage::KeepAlive => adapter.keep_alive_message().and_then(extract_text), + ControlMessage::Finalize => extract_text(adapter.finalize_message()), + ControlMessage::CloseStream => None, + } + } + + match self { + Self::Deepgram => from_adapter(&crate::adapter::DeepgramAdapter, msg), + Self::AssemblyAI => from_adapter(&crate::adapter::AssemblyAIAdapter, msg), + Self::Soniox => from_adapter(&crate::adapter::SonioxAdapter, msg), + Self::Fireworks => from_adapter(&crate::adapter::FireworksAdapter, msg), + Self::OpenAI => from_adapter(&crate::adapter::OpenAIAdapter, msg), + Self::Gladia => from_adapter(&crate::adapter::GladiaAdapter, msg), + Self::ElevenLabs => from_adapter(&crate::adapter::ElevenLabsAdapter, msg), + Self::DashScope => from_adapter(&crate::adapter::DashScopeAdapter, msg), + Self::Mistral => from_adapter(&crate::adapter::MistralAdapter::default(), msg), + } + } + pub fn detect_error(&self, data: &[u8]) -> Option { match self { Self::Deepgram => deepgram::error::detect_error(data), diff --git a/crates/transcribe-proxy/src/relay/builder.rs b/crates/transcribe-proxy/src/relay/builder.rs index 8e439a67fd..f624e9d712 100644 --- a/crates/transcribe-proxy/src/relay/builder.rs +++ b/crates/transcribe-proxy/src/relay/builder.rs @@ -6,7 +6,10 @@ use owhisper_client::Auth; pub use tokio_tungstenite::tungstenite::ClientRequestBuilder; use super::handler::WebSocketProxy; -use super::types::{FirstMessageTransformer, InitialMessage, OnCloseCallback, ResponseTransformer}; +use super::types::{ + ClientMessageFilter, FirstMessageTransformer, InitialMessage, OnCloseCallback, + ResponseTransformer, +}; use crate::config::DEFAULT_CONNECT_TIMEOUT_MS; use crate::provider_selector::SelectedProvider; @@ -34,6 +37,7 @@ pub struct WebSocketProxyBuilder { response_transformer: Option, connect_timeout: Duration, on_close: Option, + client_message_filter: Option, } impl Default for WebSocketProxyBuilder { @@ -46,6 +50,7 @@ impl Default for WebSocketProxyBuilder { response_transformer: None, connect_timeout: Duration::from_millis(DEFAULT_CONNECT_TIMEOUT_MS), on_close: None, + client_message_filter: None, } } } @@ -60,6 +65,7 @@ impl WebSocketProxyBuilder { response_transformer: self.response_transformer, connect_timeout: self.connect_timeout, on_close: self.on_close, + client_message_filter: self.client_message_filter, } } @@ -71,6 +77,7 @@ impl WebSocketProxyBuilder { response_transformer: Option, connect_timeout: Duration, on_close: Option, + client_message_filter: Option, ) -> WebSocketProxy { let control_message_types = if control_message_types.is_empty() { None @@ -86,6 +93,7 @@ impl WebSocketProxyBuilder { response_transformer, connect_timeout, on_close, + client_message_filter, ) } @@ -131,6 +139,11 @@ impl WebSocketProxyBuilder { })); self } + + pub fn client_message_filter(mut self, filter: ClientMessageFilter) -> Self { + self.client_message_filter = Some(filter); + self + } } impl WebSocketProxyBuilder { @@ -194,6 +207,7 @@ impl WebSocketProxyBuilder { self.response_transformer, self.connect_timeout, self.on_close, + self.client_message_filter, )) } } diff --git a/crates/transcribe-proxy/src/relay/channel_split.rs b/crates/transcribe-proxy/src/relay/channel_split.rs index 1b5f2878f8..093ccd145e 100644 --- a/crates/transcribe-proxy/src/relay/channel_split.rs +++ b/crates/transcribe-proxy/src/relay/channel_split.rs @@ -14,7 +14,9 @@ use tokio_tungstenite::{ use owhisper_client::Provider; -use super::types::{InitialMessage, OnCloseCallback, ResponseTransformer, convert}; +use super::types::{ + ClientMessageFilter, InitialMessage, OnCloseCallback, ResponseTransformer, convert, +}; const SAMPLE_BYTES: usize = 2; const FRAME_BYTES: usize = SAMPLE_BYTES * 2; @@ -71,6 +73,7 @@ pub struct ChannelSplitProxy { response_transformer: Option, connect_timeout: Duration, on_close: Option, + client_message_filter: Option, } impl ChannelSplitProxy { @@ -106,9 +109,15 @@ impl ChannelSplitProxy { response_transformer, connect_timeout, on_close, + client_message_filter: None, } } + pub fn with_client_message_filter(mut self, filter: ClientMessageFilter) -> Self { + self.client_message_filter = Some(filter); + self + } + async fn connect_upstream( request: &ClientRequestBuilder, timeout: Duration, @@ -155,6 +164,7 @@ impl ChannelSplitProxy { spk_upstream, self.initial_message.clone(), self.response_transformer.clone(), + self.client_message_filter.clone(), ) .await; @@ -177,6 +187,7 @@ impl ChannelSplitProxy { spk_upstream: WebSocketStream>, initial_message: Option, response_transformer: Option, + client_message_filter: Option, ) { let (mut mic_tx, mut mic_rx) = mic_upstream.split(); let (mut spk_tx, mut spk_rx) = spk_upstream.split(); @@ -218,7 +229,15 @@ impl ChannelSplitProxy { } } Message::Text(text) => { - let tung = TungsteniteMessage::Text(text.to_string().into()); + let text_str = text.to_string(); + let forwarded = match client_message_filter.as_ref() { + Some(filter) => match filter(text_str) { + Some(s) => s, + None => continue, + }, + None => text_str, + }; + let tung = TungsteniteMessage::Text(forwarded.into()); if mic_tx.send(tung.clone()).await.is_err() || spk_tx.send(tung).await.is_err() { diff --git a/crates/transcribe-proxy/src/relay/handler.rs b/crates/transcribe-proxy/src/relay/handler.rs index 93de3eb5a6..26915dabf1 100644 --- a/crates/transcribe-proxy/src/relay/handler.rs +++ b/crates/transcribe-proxy/src/relay/handler.rs @@ -17,9 +17,9 @@ use owhisper_client::Provider; use super::builder::WebSocketProxyBuilder; use super::pending::{FlushError, PendingState, QueuedPayload}; use super::types::{ - ClientReceiver, ClientSender, ControlMessageTypes, DEFAULT_CLOSE_CODE, FirstMessageTransformer, - InitialMessage, OnCloseCallback, ResponseTransformer, UpstreamReceiver, UpstreamSender, - convert, is_control_message, + ClientMessageFilter, ClientReceiver, ClientSender, ControlMessageTypes, DEFAULT_CLOSE_CODE, + FirstMessageTransformer, InitialMessage, OnCloseCallback, ResponseTransformer, + UpstreamReceiver, UpstreamSender, convert, is_control_message, }; #[derive(Clone)] @@ -31,6 +31,7 @@ pub struct WebSocketProxy { response_transformer: Option, connect_timeout: Duration, on_close: Option, + client_message_filter: Option, } impl WebSocketProxy { @@ -42,6 +43,7 @@ impl WebSocketProxy { response_transformer: Option, connect_timeout: Duration, on_close: Option, + client_message_filter: Option, ) -> Self { Self { upstream_request, @@ -51,6 +53,7 @@ impl WebSocketProxy { response_transformer, connect_timeout, on_close, + client_message_filter, } } @@ -89,6 +92,7 @@ impl WebSocketProxy { self.initial_message.clone(), self.response_transformer.clone(), self.on_close.clone(), + self.client_message_filter.clone(), ) .await; @@ -121,6 +125,7 @@ impl WebSocketProxy { initial_message: Option, response_transformer: Option, on_close: Option, + client_message_filter: Option, ) { let start_time = Instant::now(); @@ -138,6 +143,7 @@ impl WebSocketProxy { control_message_types, transform_first_message, initial_message, + client_message_filter, ); let upstream_to_client = Self::run_upstream_to_client( @@ -211,6 +217,7 @@ impl WebSocketProxy { control_types: Option, mut first_msg_transformer: Option, initial_message: Option, + client_message_filter: Option, ) { let mut pending = PendingState::default(); @@ -264,6 +271,15 @@ impl WebSocketProxy { Some(t) => t(text_owned), None => text_owned, }; + + let text_str = match client_message_filter.as_ref() { + Some(filter) => match filter(text_str) { + Some(s) => s, + None => continue, + }, + None => text_str, + }; + let data = text_str.into_bytes(); if Self::process_data_message(&mut pending, data, true, &control_types, &shutdown_tx, &mut upstream_sender).await { diff --git a/crates/transcribe-proxy/src/relay/mod.rs b/crates/transcribe-proxy/src/relay/mod.rs index eef2fcd942..ed37e5a9ff 100644 --- a/crates/transcribe-proxy/src/relay/mod.rs +++ b/crates/transcribe-proxy/src/relay/mod.rs @@ -8,5 +8,5 @@ mod upstream_error; pub use builder::ClientRequestBuilder; pub use channel_split::ChannelSplitProxy; pub use handler::WebSocketProxy; -pub use types::{InitialMessage, OnCloseCallback, ResponseTransformer}; +pub use types::{ClientMessageFilter, InitialMessage, OnCloseCallback, ResponseTransformer}; pub use upstream_error::{UpstreamError, detect_upstream_error}; diff --git a/crates/transcribe-proxy/src/relay/types.rs b/crates/transcribe-proxy/src/relay/types.rs index 87e73a6873..516174c0ac 100644 --- a/crates/transcribe-proxy/src/relay/types.rs +++ b/crates/transcribe-proxy/src/relay/types.rs @@ -17,6 +17,8 @@ pub type FirstMessageTransformer = Arc String + Send + Sync>; pub type InitialMessage = Arc; pub type ResponseTransformer = Arc Option + Send + Sync>; +pub type ClientMessageFilter = Arc Option + Send + Sync>; + pub type UpstreamSender = SplitSink< WebSocketStream>, tokio_tungstenite::tungstenite::Message, diff --git a/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs b/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs index 890005d1aa..3a7605e413 100644 --- a/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs +++ b/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs @@ -10,7 +10,7 @@ use owhisper_interface::ListenParams; use crate::config::SttProxyConfig; use crate::provider_selector::SelectedProvider; use crate::query_params::{QueryParams, QueryValue}; -use crate::relay::{ChannelSplitProxy, WebSocketProxy}; +use crate::relay::{ChannelSplitProxy, ClientMessageFilter, WebSocketProxy}; use crate::routes::AppState; use crate::routes::model_resolution::resolve_model; @@ -103,6 +103,16 @@ fn build_response_transformer( } } +fn build_client_message_filter(provider: Provider) -> ClientMessageFilter { + Arc::new(move |text: String| { + let msg = match serde_json::from_str::(&text) { + Ok(msg) => msg, + Err(_) => return Some(text), + }; + provider.translate_control_message(&msg) + }) +} + pub enum StreamingProxy { Single(WebSocketProxy), ChannelSplit(ChannelSplitProxy), @@ -146,11 +156,13 @@ fn build_proxy_with_adapter( channels, ); + let filter = build_client_message_filter(provider); let mut builder = WebSocketProxy::builder() .upstream_url(upstream_url.as_str()) .connect_timeout(config.connect_timeout) .control_message_types(provider.control_message_types()) .response_transformer(build_response_transformer(provider)) + .client_message_filter(filter) .apply_auth(selected); if let Some(msg) = initial_message { @@ -192,13 +204,17 @@ fn build_channel_split_proxy( Some(Arc::new(build_response_transformer(provider))); let on_close = build_on_close_callback(config, provider, &analytics_ctx); - Ok(StreamingProxy::ChannelSplit(ChannelSplitProxy::new( - request, - initial_msg, - response_transformer, - config.connect_timeout, - on_close, - ))) + let filter = build_client_message_filter(provider); + Ok(StreamingProxy::ChannelSplit( + ChannelSplitProxy::new( + request, + initial_msg, + response_transformer, + config.connect_timeout, + on_close, + ) + .with_client_message_filter(filter), + )) } fn build_session_channel_split_proxy( @@ -224,6 +240,7 @@ fn build_session_channel_split_proxy( Some(Arc::new(build_response_transformer(provider))); let on_close = build_on_close_callback(config, provider, &analytics_ctx); + let filter = build_client_message_filter(provider); Ok(StreamingProxy::ChannelSplit( ChannelSplitProxy::with_split_requests( mic_request, @@ -232,7 +249,8 @@ fn build_session_channel_split_proxy( response_transformer, config.connect_timeout, on_close, - ), + ) + .with_client_message_filter(filter), )) } @@ -243,11 +261,13 @@ fn build_proxy_with_url_and_transformer( analytics_ctx: AnalyticsContext, ) -> Result { let provider = selected.provider(); + let filter = build_client_message_filter(provider); let builder = WebSocketProxy::builder() .upstream_url(upstream_url) .connect_timeout(config.connect_timeout) .control_message_types(provider.control_message_types()) .response_transformer(build_response_transformer(provider)) + .client_message_filter(filter) .apply_auth(selected); let proxy = finalize_proxy_builder!(builder, provider, config, analytics_ctx)?; @@ -490,6 +510,51 @@ mod tests { assert!(result.is_none()); } + #[test] + fn test_client_message_filter_deepgram_identity() { + let filter = build_client_message_filter(Provider::Deepgram); + assert_eq!( + filter(r#"{"type":"KeepAlive"}"#.to_string()), + Some(r#"{"type":"KeepAlive"}"#.to_string()) + ); + assert_eq!(filter(r#"{"type":"CloseStream"}"#.to_string()), None); + assert_eq!( + filter(r#"{"type":"Finalize"}"#.to_string()), + Some(r#"{"type":"Finalize"}"#.to_string()) + ); + } + + #[test] + fn test_client_message_filter_soniox_translates_control_messages() { + let filter = build_client_message_filter(Provider::Soniox); + + assert_eq!(filter(r#"{"type":"CloseStream"}"#.to_string()), None); + assert_eq!( + filter(r#"{"type":"KeepAlive"}"#.to_string()), + Some(r#"{"type":"keepalive"}"#.to_string()) + ); + assert_eq!( + filter(r#"{"type":"Finalize"}"#.to_string()), + Some(r#"{"type":"finalize"}"#.to_string()) + ); + } + + #[test] + fn test_client_message_filter_assemblyai_translates_finalize() { + let filter = build_client_message_filter(Provider::AssemblyAI); + assert_eq!(filter(r#"{"type":"KeepAlive"}"#.to_string()), None); + assert_eq!( + filter(r#"{"type":"Finalize"}"#.to_string()), + Some(r#"{"type":"Terminate"}"#.to_string()) + ); + } + + #[test] + fn test_client_message_filter_non_json_passthrough() { + let filter = build_client_message_filter(Provider::Soniox); + assert_eq!(filter("not json".to_string()), Some("not json".to_string())); + } + #[test] fn test_resolve_model_clears_meta_model_for_soniox() { let mut params = ListenParams {