diff --git a/Cargo.toml b/Cargo.toml
index d0f6f82..eee3d0a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "context-switch"
-version = "1.0.1"
+version = "1.1.0"
 edition = "2024"
 rust-version = "1.88"
 
@@ -73,6 +73,10 @@
 openai-api-rs = { workspace = true }
 serde_json = { workspace = true }
 chrono-tz = { version = "0.10.3" }
+
+# For recognizing audio files in the azure-transcribe example.
+playback = { path = "services/playback" }
+
 [workspace.dependencies]
 tracing-subscriber = { version = "0.3.19" }
 
diff --git a/audio-knife/Cargo.toml b/audio-knife/Cargo.toml
index 5e15a07..88f729c 100644
--- a/audio-knife/Cargo.toml
+++ b/audio-knife/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "audio-knife"
-version = "1.3.1"
+version = "1.4.0"
 edition = "2024"
 
 [profile.dev]
diff --git a/core/src/lib.rs b/core/src/lib.rs
index 99f6c95..f2d4d27 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -94,7 +94,7 @@ pub fn audio_msg_channel(format: AudioFormat) -> (AudioMsgProducer, AudioMsgConsumer)
 #[derive(Debug)]
 pub struct AudioConsumer {
     pub format: AudioFormat,
-    pub receiver: mpsc::Receiver<Vec<i16>>,
+    pub receiver: mpsc::UnboundedReceiver<Vec<i16>>,
 }
 
 impl AudioConsumer {
@@ -114,18 +114,10 @@
 #[derive(Debug)]
 pub struct AudioProducer {
     pub format: AudioFormat,
-    pub sender: mpsc::Sender<Vec<i16>>,
+    pub sender: mpsc::UnboundedSender<Vec<i16>>,
 }
 
 impl AudioProducer {
-    // TODO: remove this function.
-    pub fn produce_raw(&self, samples: Vec<i16>) -> Result<()> {
-        self.produce(AudioFrame {
-            format: self.format,
-            samples,
-        })
-    }
-
     pub fn produce(&self, frame: AudioFrame) -> Result<()> {
         if frame.format != self.format {
             bail!(
@@ -134,15 +126,14 @@
                 frame.format
             );
         }
-        self.sender
-            .try_send(frame.samples)
-            .context("Sending samples")
+        self.sender.send(frame.samples).context("Sending samples")?;
+        Ok(())
     }
 }
 
 /// Create a unidirectional audio channel.
 pub fn audio_channel(format: AudioFormat) -> (AudioProducer, AudioConsumer) {
-    let (producer, consumer) = mpsc::channel(256);
+    let (producer, consumer) = mpsc::unbounded_channel();
     (
         AudioProducer {
             format,
diff --git a/core/src/protocol.rs b/core/src/protocol.rs
index c38535d..fbbc3de 100644
--- a/core/src/protocol.rs
+++ b/core/src/protocol.rs
@@ -31,10 +31,12 @@ impl AudioFormat
         time::Duration::from_secs_f64(mono_sample_count as f64 / self.sample_rate as f64)
     }
 
+    // Architecture: This is now only used in the examples.
     pub fn new_channel(&self) -> (AudioProducer, AudioConsumer) {
         audio_channel(*self)
     }
 
+    #[deprecated(note = "Will be removed without replacement")]
     pub fn new_msg_channel(&self) -> (AudioMsgProducer, AudioMsgConsumer) {
         audio_msg_channel(*self)
     }
diff --git a/examples/azure-transcribe.rs b/examples/azure-transcribe.rs
index d8636ba..1c99191 100644
--- a/examples/azure-transcribe.rs
+++ b/examples/azure-transcribe.rs
@@ -1,24 +1,58 @@
-use std::{env, time::Duration};
+use std::{env, path::Path, time::Duration};
 
-use anyhow::{Context, Result};
+use anyhow::{Context, Result, bail};
 use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
 use tokio::{
     select,
     sync::mpsc::{channel, unbounded_channel},
 };
 
-use context_switch::{InputModality, OutputModality, services::AzureTranscribe};
+use context_switch::{AudioConsumer, InputModality, OutputModality, services::AzureTranscribe};
 use context_switch_core::{
     AudioFormat, AudioFrame, audio,
     conversation::{Conversation, Input},
     service::Service,
 };
 
+const LANGUAGE: &str = "de-DE";
+
 #[tokio::main]
 async fn main() -> Result<()> {
     dotenvy::dotenv_override()?;
     tracing_subscriber::fmt::init();
 
+    let mut args = env::args();
+    match args.len() {
+        1 => recognize_from_microphone().await?,
+        2 => recognize_from_wav(Path::new(&args.nth(1).unwrap())).await?,
+        _ => bail!("Invalid number of arguments, expected zero or one"),
+    }
+    Ok(())
+}
+
+async fn recognize_from_wav(file: &Path) -> Result<()> {
+    // For now we always convert to 16 kHz single channel (this is what we use internally for
+    // testing).
+    let format = AudioFormat {
+        channels: 1,
+        sample_rate: 16000,
+    };
+
+    let frames = playback::audio_file_to_frames(file, format)?;
+    if frames.is_empty() {
+        bail!("No frames in the audio file")
+    }
+
+    let (producer, input_consumer) = format.new_channel();
+
+    for frame in frames {
+        producer.produce(frame)?;
+    }
+
+    recognize(format, input_consumer).await
+}
+
+async fn recognize_from_microphone() -> Result<()> {
     let host = cpal::default_host();
     let device = host
         .default_input_device()
@@ -33,7 +67,7 @@
     let sample_rate = config.sample_rate();
 
     let format = AudioFormat::new(channels, sample_rate.0);
-    let (producer, mut input_consumer) = format.new_channel();
+    let (producer, input_consumer) = format.new_channel();
 
     // Create and run the input stream
     let stream = device
@@ -56,19 +90,23 @@
     stream.play().expect("Failed to play stream");
 
-    let language = "de-DE";
-
+    recognize(format, input_consumer).await
+}
+
+async fn recognize(format: AudioFormat, mut input_consumer: AudioConsumer) -> Result<()> {
     // TODO: clarify how to access configurations.
     let params = azure::transcribe::Params {
         host: None,
         region: Some(env::var("AZURE_REGION").expect("AZURE_REGION undefined")),
         subscription_key: env::var("AZURE_SUBSCRIPTION_KEY")
             .expect("AZURE_SUBSCRIPTION_KEY undefined"),
-        language: language.into(),
+        language: LANGUAGE.into(),
+        speech_gate: false,
     };
 
     let (output_producer, mut output_consumer) = unbounded_channel();
-    let (conv_input_producer, conv_input_consumer) = channel(32);
+    // For now this is effectively unbounded, because we push complete audio files for recognition.
+    let (conv_input_producer, conv_input_consumer) = channel(16384);
 
     let azure = AzureTranscribe;
     let mut conversation = azure.conversation(
diff --git a/services/azure/src/transcribe.rs b/services/azure/src/transcribe.rs
index 4f968d6..84c160f 100644
--- a/services/azure/src/transcribe.rs
+++ b/services/azure/src/transcribe.rs
@@ -4,7 +4,7 @@
 use async_trait::async_trait;
 use azure_speech::recognizer::{self, Event};
 use futures::StreamExt;
 use serde::Deserialize;
-use tracing::error;
+use tracing::{error, info};
 use crate::Host;
 use context_switch_core::{
@@ -20,6 +20,8 @@ pub struct Params {
     pub region: Option<String>,
     pub subscription_key: String,
     pub language: String,
+    #[serde(default)]
+    pub speech_gate: bool,
 }
 
 #[derive(Debug)]
@@ -67,9 +69,18 @@
             .into_header_for_infinite_file();
         stream! {
             yield wav_header;
-            let mut speech_gate = make_speech_gate_processor_soft_rms(0.0025, 10., 300., 0.01);
-            while let Some(Input::Audio{ frame }) = input.recv().await {
-                let frame = speech_gate(&frame);
+            let mut speech_gate =
+                if params.speech_gate {
+                    info!("Enabling speech gate");
+                    Some(make_speech_gate_processor_soft_rms(0.0025, 10., 300., 0.01))
+                }
+                else {
+                    None
+                };
+            while let Some(Input::Audio{ mut frame }) = input.recv().await {
+                if let Some(ref mut speech_gate) = speech_gate {
+                    frame = (speech_gate)(&frame);
+                }
                 yield frame.to_le_bytes();
                 //
                 // Speech to text hours are measured as the hours of audio _sent to the service_, billed in second increments.
diff --git a/services/playback/src/lib.rs b/services/playback/src/lib.rs
index 283bba3..90dc0f9 100644
--- a/services/playback/src/lib.rs
+++ b/services/playback/src/lib.rs
@@ -159,7 +159,7 @@ impl Service for Playback {
 }
 
 /// Render the file into 100ms mono audio frames.
-fn audio_file_to_frames(path: &Path, format: AudioFormat) -> Result<Vec<AudioFrame>> {
+pub fn audio_file_to_frames(path: &Path, format: AudioFormat) -> Result<Vec<AudioFrame>> {
     check_supported_audio_type(&path.to_string_lossy(), None)?;
     let file = File::open(path).inspect_err(|e| {
         // We don't want to provide the resolved path to the user in an error message. Therefore we