Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "context-switch"
version = "1.0.1"
version = "1.1.0"
edition = "2024"
rust-version = "1.88"

Expand Down Expand Up @@ -73,6 +73,10 @@ openai-api-rs = { workspace = true }
serde_json = { workspace = true }
chrono-tz = { version = "0.10.3" }


# For recognizing audio files in azure-transcribe.
playback = { path = "services/playback" }

[workspace.dependencies]
tracing-subscriber = { version = "0.3.19" }

Expand Down
2 changes: 1 addition & 1 deletion audio-knife/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "audio-knife"
version = "1.3.1"
version = "1.4.0"
edition = "2024"

[profile.dev]
Expand Down
19 changes: 5 additions & 14 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub fn audio_msg_channel(format: AudioFormat) -> (AudioMsgProducer, AudioMsgCons
#[derive(Debug)]
pub struct AudioConsumer {
pub format: AudioFormat,
pub receiver: mpsc::Receiver<Vec<i16>>,
pub receiver: mpsc::UnboundedReceiver<Vec<i16>>,
}

impl AudioConsumer {
Expand All @@ -114,18 +114,10 @@ impl AudioConsumer {
#[derive(Debug)]
pub struct AudioProducer {
pub format: AudioFormat,
pub sender: mpsc::Sender<Vec<i16>>,
pub sender: mpsc::UnboundedSender<Vec<i16>>,
}

impl AudioProducer {
// TODO: remove this function.
pub fn produce_raw(&self, samples: Vec<i16>) -> Result<()> {
self.produce(AudioFrame {
format: self.format,
samples,
})
}

pub fn produce(&self, frame: AudioFrame) -> Result<()> {
if frame.format != self.format {
bail!(
Expand All @@ -134,15 +126,14 @@ impl AudioProducer {
frame.format
);
}
self.sender
.try_send(frame.samples)
.context("Sending samples")
self.sender.send(frame.samples).context("Sending samples")?;
Ok(())
}
}

/// Create a unidirectional audio channel.
pub fn audio_channel(format: AudioFormat) -> (AudioProducer, AudioConsumer) {
let (producer, consumer) = mpsc::channel(256);
let (producer, consumer) = mpsc::unbounded_channel();
(
AudioProducer {
format,
Expand Down
2 changes: 2 additions & 0 deletions core/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ impl AudioFormat {
time::Duration::from_secs_f64(mono_sample_count as f64 / self.sample_rate as f64)
}

// Architecture: This is now used only in the examples.
pub fn new_channel(&self) -> (AudioProducer, AudioConsumer) {
audio_channel(*self)
}

#[deprecated(note = "Removed without replacement")]
pub fn new_msg_channel(&self) -> (AudioMsgProducer, AudioMsgConsumer) {
audio_msg_channel(*self)
}
Expand Down
52 changes: 45 additions & 7 deletions examples/azure-transcribe.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,58 @@
use std::{env, time::Duration};
use std::{env, path::Path, time::Duration};

use anyhow::{Context, Result};
use anyhow::{Context, Result, bail};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use tokio::{
select,
sync::mpsc::{channel, unbounded_channel},
};

use context_switch::{InputModality, OutputModality, services::AzureTranscribe};
use context_switch::{AudioConsumer, InputModality, OutputModality, services::AzureTranscribe};
use context_switch_core::{
AudioFormat, AudioFrame, audio,
conversation::{Conversation, Input},
service::Service,
};

const LANGUAGE: &str = "de-DE";

#[tokio::main]
async fn main() -> Result<()> {
    dotenvy::dotenv_override()?;
    tracing_subscriber::fmt::init();

    // Dispatch on the command line: no argument records from the default
    // microphone, one argument transcribes the given audio file.
    let args: Vec<String> = env::args().collect();
    match args.as_slice() {
        [_program] => recognize_from_microphone().await?,
        [_program, file] => recognize_from_wav(Path::new(file)).await?,
        _ => bail!("Invalid number of arguments, expect zero or one"),
    }
    Ok(())
}

/// Transcribe a single audio file by decoding it completely and feeding all
/// frames to the recognizer at once.
async fn recognize_from_wav(file: &Path) -> Result<()> {
    // For now we always convert to 16 kHz single channel (this is what we use internally for
    // testing).
    let format = AudioFormat {
        channels: 1,
        sample_rate: 16000,
    };

    let frames = playback::audio_file_to_frames(file, format)?;
    if frames.is_empty() {
        bail!("No frames in the audio file")
    }

    // The audio channel is unbounded, so the complete file can be pushed
    // before anything consumes it.
    let (producer, input_consumer) = format.new_channel();

    for frame in frames {
        producer.produce(frame)?;
    }

    // Close the sending side so the consumer observes end-of-stream once all
    // buffered frames are drained; otherwise the recognition input would
    // never terminate.
    drop(producer);

    recognize(format, input_consumer).await
}

async fn recognize_from_microphone() -> Result<()> {
let host = cpal::default_host();
let device = host
.default_input_device()
Expand All @@ -33,7 +67,7 @@ async fn main() -> Result<()> {
let sample_rate = config.sample_rate();
let format = AudioFormat::new(channels, sample_rate.0);

let (producer, mut input_consumer) = format.new_channel();
let (producer, input_consumer) = format.new_channel();

// Create and run the input stream
let stream = device
Expand All @@ -56,19 +90,23 @@ async fn main() -> Result<()> {

stream.play().expect("Failed to play stream");

let language = "de-DE";
recognize(format, input_consumer).await
}

async fn recognize(format: AudioFormat, mut input_consumer: AudioConsumer) -> Result<()> {
// TODO: clarify how to access configurations.
let params = azure::transcribe::Params {
host: None,
region: Some(env::var("AZURE_REGION").expect("AZURE_REGION undefined")),
subscription_key: env::var("AZURE_SUBSCRIPTION_KEY")
.expect("AZURE_SUBSCRIPTION_KEY undefined"),
language: language.into(),
language: LANGUAGE.into(),
speech_gate: false,
};

let (output_producer, mut output_consumer) = unbounded_channel();
let (conv_input_producer, conv_input_consumer) = channel(32);
// For now this is more or less unbounded, because we push complete audio files for recognition.
let (conv_input_producer, conv_input_consumer) = channel(16384);

let azure = AzureTranscribe;
let mut conversation = azure.conversation(
Expand Down
19 changes: 15 additions & 4 deletions services/azure/src/transcribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use async_trait::async_trait;
use azure_speech::recognizer::{self, Event};
use futures::StreamExt;
use serde::Deserialize;
use tracing::error;
use tracing::{error, info};

use crate::Host;
use context_switch_core::{
Expand All @@ -20,6 +20,8 @@ pub struct Params {
pub region: Option<String>,
pub subscription_key: String,
pub language: String,
#[serde(default)]
pub speech_gate: bool,
}

#[derive(Debug)]
Expand Down Expand Up @@ -67,9 +69,18 @@ impl Service for AzureTranscribe {
.into_header_for_infinite_file();
stream! {
yield wav_header;
let mut speech_gate = make_speech_gate_processor_soft_rms(0.0025, 10., 300., 0.01);
while let Some(Input::Audio{ frame }) = input.recv().await {
let frame = speech_gate(&frame);
let mut speech_gate =
if params.speech_gate {
info!("Enabling speech gate");
Some(make_speech_gate_processor_soft_rms(0.0025, 10., 300., 0.01))
}
else {
None
};
while let Some(Input::Audio{ mut frame }) = input.recv().await {
if let Some(ref mut speech_gate) = speech_gate {
frame = (speech_gate)(&frame);
}
yield frame.to_le_bytes();
// <https://azure.microsoft.com/en-us/pricing/details/cognitive-services/speech-services/>
// Speech to text hours are measured as the hours of audio _sent to the service_, billed in second increments.
Expand Down
2 changes: 1 addition & 1 deletion services/playback/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl Service for Playback {
}

/// Render the file into 100ms audio frames mono.
fn audio_file_to_frames(path: &Path, format: AudioFormat) -> Result<Vec<AudioFrame>> {
pub fn audio_file_to_frames(path: &Path, format: AudioFormat) -> Result<Vec<AudioFrame>> {
check_supported_audio_type(&path.to_string_lossy(), None)?;
let file = File::open(path).inspect_err(|e| {
// We don't want to provide the resolved path to the user in an error message. Therefore we
Expand Down
Loading