diff --git a/apps/eval-cli/src/main.rs b/apps/eval-cli/src/main.rs index 9071db0129..d0321f67fd 100644 --- a/apps/eval-cli/src/main.rs +++ b/apps/eval-cli/src/main.rs @@ -8,7 +8,7 @@ mod report; mod submissions; use hypr_eval::{ - DEFAULT_MODELS, EvalResult, Executor, ExecutorProgress, OpenRouterClient, parse_config, + parse_config, EvalResult, Executor, ExecutorProgress, OpenRouterClient, DEFAULT_MODELS, }; use report::{render_json, render_results}; use submissions::{all_cases, filter_cases}; @@ -179,7 +179,7 @@ fn list_cases() { fn generate_completion(shell: Shell) { use clap::CommandFactory; - use clap_complete::{Shell as ClapShell, generate}; + use clap_complete::{generate, Shell as ClapShell}; let mut cmd = Cli::command(); let shell = match shell { diff --git a/apps/eval-cli/src/report.rs b/apps/eval-cli/src/report.rs index f892169293..a82d5d3471 100644 --- a/apps/eval-cli/src/report.rs +++ b/apps/eval-cli/src/report.rs @@ -1,4 +1,4 @@ -use comfy_table::{Cell, Color, ContentArrangement, Table, presets::UTF8_FULL_CONDENSED}; +use comfy_table::{presets::UTF8_FULL_CONDENSED, Cell, Color, ContentArrangement, Table}; use hypr_eval::EvalResult; diff --git a/apps/eval-cli/src/submissions.rs b/apps/eval-cli/src/submissions.rs index a9fcc8251a..7e059af967 100644 --- a/apps/eval-cli/src/submissions.rs +++ b/apps/eval-cli/src/submissions.rs @@ -1,6 +1,6 @@ use hypr_eval::{ - ChatMessage, CheckResult, EvalCase, GraderSpec, RubricSpec, find_headings, find_lists, grade, - is_non_empty, + find_headings, find_lists, grade, is_non_empty, ChatMessage, CheckResult, EvalCase, GraderSpec, + RubricSpec, }; use hypr_template_eval::{MdgenSystem, Template}; diff --git a/crates/audio-device/src/macos.rs b/crates/audio-device/src/macos.rs index 445ebb4945..af014c82ed 100644 --- a/crates/audio-device/src/macos.rs +++ b/crates/audio-device/src/macos.rs @@ -59,7 +59,7 @@ impl MacOSBackend { .unwrap_or(TransportType::Unknown); let is_default = default_device_id - .map(|id| device.0.0 == id) + .map(|id| device.0 .0 == id) .unwrap_or(false); let mut audio_device = AudioDevice { @@ -102,7 +102,11 @@ impl MacOSBackend { }) }); - if detected { Some(true) } else { None } + if detected { + Some(true) + } else { + None + } } fn is_external_from_device(device: Option) -> bool { @@ -120,8 +124,8 @@ impl AudioDeviceBackend for MacOSBackend { let ca_devices = ca::System::devices().map_err(|e| Error::EnumerationFailed(format!("{:?}", e)))?; - let default_input_id = ca::System::default_input_device().ok().map(|d| d.0.0); - let default_output_id = ca::System::default_output_device().ok().map(|d| d.0.0); + let default_input_id = ca::System::default_input_device().ok().map(|d| d.0 .0); + let default_output_id = ca::System::default_output_device().ok().map(|d| d.0 .0); let mut devices = Vec::new(); @@ -161,7 +165,7 @@ impl AudioDeviceBackend for MacOSBackend { Ok(Self::create_audio_device( &ca_device, AudioDirection::Input, - Some(ca_device.0.0), + Some(ca_device.0 .0), )) } @@ -178,7 +182,7 @@ impl AudioDeviceBackend for MacOSBackend { Ok(Self::create_audio_device( &ca_device, AudioDirection::Output, - Some(ca_device.0.0), + Some(ca_device.0 .0), )) } diff --git a/crates/audio-device/src/windows.rs b/crates/audio-device/src/windows.rs index 3d7b95fc08..74b6416b22 100644 --- a/crates/audio-device/src/windows.rs +++ b/crates/audio-device/src/windows.rs @@ -1,17 +1,17 @@ use crate::{AudioDevice, AudioDeviceBackend, AudioDirection, DeviceId, Error, TransportType}; use std::ffi::OsString; use std::os::windows::ffi::OsStringExt; +use windows::core::{Interface, GUID, PCWSTR, PWSTR}; use windows::Win32::Devices::FunctionDiscovery::PKEY_Device_FriendlyName; use windows::Win32::Media::Audio::Endpoints::IAudioEndpointVolume; use windows::Win32::Media::Audio::{ - DEVICE_STATE_ACTIVE, IMMDevice, IMMDeviceEnumerator, MMDeviceEnumerator, eAll, eCapture, - eConsole, eRender, + eAll, eCapture, eConsole, eRender, IMMDevice, IMMDeviceEnumerator, MMDeviceEnumerator, + DEVICE_STATE_ACTIVE, }; use windows::Win32::System::Com::{ - CLSCTX_ALL, COINIT_MULTITHREADED, CoCreateInstance, CoInitializeEx, CoUninitialize, STGM_READ, + CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_ALL, COINIT_MULTITHREADED, STGM_READ, }; use windows::Win32::UI::Shell::PropertiesSystem::IPropertyStore; -use windows::core::{GUID, Interface, PCWSTR, PWSTR}; pub struct WindowsBackend; diff --git a/crates/eval/src/format.rs b/crates/eval/src/format.rs index cfdce32c07..e37b0dc1fa 100644 --- a/crates/eval/src/format.rs +++ b/crates/eval/src/format.rs @@ -1,5 +1,5 @@ use markdown::mdast::{Heading, List, ListItem, Node}; -use markdown::{ParseOptions, to_mdast}; +use markdown::{to_mdast, ParseOptions}; #[derive(Debug, Clone)] pub struct CheckResult { diff --git a/crates/eval/src/lib.rs b/crates/eval/src/lib.rs index 9fe42390a7..e7158bf18d 100644 --- a/crates/eval/src/lib.rs +++ b/crates/eval/src/lib.rs @@ -54,23 +54,23 @@ pub use testing::*; // Re-export core types at root for convenience pub use client::{ - ChatCompleter, ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ClientError, - GraderResponse, OpenRouterClient, Usage, UsageResolver, generate_chat_multi_with_generation_id, - generate_chat_with_generation_id, generate_structured_grader_response, - generate_structured_grader_response_multi, generate_text_multi_with_generation_id, - generate_text_with_generation_id, + generate_chat_multi_with_generation_id, generate_chat_with_generation_id, + generate_structured_grader_response, generate_structured_grader_response_multi, + generate_text_multi_with_generation_id, generate_text_with_generation_id, ChatCompleter, + ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ClientError, GraderResponse, + OpenRouterClient, Usage, UsageResolver, }; -pub use config::{Config, parse_config}; +pub use config::{parse_config, Config}; pub use constants::*; pub use format::{ - CheckResult, GradeResult, Rule, count_list_items_in_section, extract_text, find_headings, - find_list_items, find_lists, first_inline_child, grade, split_by_headings, + count_list_items_in_section, extract_text, find_headings, find_list_items, find_lists, + first_inline_child, grade, split_by_headings, CheckResult, GradeResult, Rule, }; pub use models::{fetch_openrouter_models, filter_models}; -pub use rubric::{Score, grade_with_func, grade_with_llm, is_non_empty}; +pub use rubric::{grade_with_func, grade_with_llm, is_non_empty, Score}; pub use stats::{ - AggregatedGraderResponse, ConfidenceInterval, PassStats, aggregate_grader_responses, - calc_pass_stats, + aggregate_grader_responses, calc_pass_stats, AggregatedGraderResponse, ConfidenceInterval, + PassStats, }; pub use submission::{ EvalCase, EvalResult, Executor, ExecutorProgress, ExecutorProgressCallback, GraderSpec, diff --git a/crates/eval/src/rubric.rs b/crates/eval/src/rubric.rs index 6b5ba4764b..34f6769798 100644 --- a/crates/eval/src/rubric.rs +++ b/crates/eval/src/rubric.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use crate::{ - ChatCompleter, ConfidenceInterval, aggregate_grader_responses, - generate_structured_grader_response, generate_structured_grader_response_multi, + aggregate_grader_responses, generate_structured_grader_response, + generate_structured_grader_response_multi, ChatCompleter, ConfidenceInterval, }; /// Score represents the result of evaluating output against a rubric. diff --git a/crates/eval/src/submission.rs b/crates/eval/src/submission.rs index f8c1def52e..1f60e9cabb 100644 --- a/crates/eval/src/submission.rs +++ b/crates/eval/src/submission.rs @@ -5,7 +5,7 @@ use std::sync::{Arc, Mutex}; use rayon::prelude::*; use crate::constants::DEFAULT_GRADER_MODEL; -use crate::{ChatCompleter, ChatMessage, Score, Usage, parse_config}; +use crate::{parse_config, ChatCompleter, ChatMessage, Score, Usage}; pub type ValidatorFn = fn(&str) -> (bool, String); pub type ValidatorFnWithMeta = fn(&str, &HashMap) -> (bool, String); diff --git a/crates/owhisper-client/src/adapter/openai/live.rs b/crates/owhisper-client/src/adapter/openai/live.rs index 064b745910..5a50fd453e 100644 --- a/crates/owhisper-client/src/adapter/openai/live.rs +++ b/crates/owhisper-client/src/adapter/openai/live.rs @@ -29,8 +29,18 @@ impl RealtimeSttAdapter for OpenAIAdapter { false } - fn build_ws_url(&self, api_base: &str, _params: &ListenParams, _channels: u8) -> url::Url { - let (mut url, existing_params) = Self::build_ws_url_from_base(api_base); + fn build_ws_url(&self, api_base: &str, params: &ListenParams, _channels: u8) -> url::Url { + // Detect Azure from the base URL and store flag for initial_message + if let Ok(parsed) = api_base.parse::() { + if let Some(host) = parsed.host_str() { + if Self::is_azure_host(host) { + self.set_azure(true); + } + } + } + + let model = params.model.as_deref(); + let (mut url, existing_params) = Self::build_ws_url_from_base_with_model(api_base, model); if !existing_params.is_empty() { let mut query_pairs = url.query_pairs_mut(); @@ -78,6 +88,11 @@ impl RealtimeSttAdapter for OpenAIAdapter { None => default, }; + // Use the Azure flag set during build_ws_url (detected from api_base URL) + if self.is_azure() { + return self.build_azure_initial_message(model, language); + } + let session_config = SessionUpdateEvent { event_type: "session.update".to_string(), session: SessionConfig { @@ -211,6 +226,76 @@ impl RealtimeSttAdapter for OpenAIAdapter { } } +impl OpenAIAdapter { + /// Build Azure OpenAI-specific initial message + /// Azure uses a different session update format: transcription_session.update + fn build_azure_initial_message( + &self, + model: &str, + language: Option, + ) -> Option { + let session_update = AzureTranscriptionSessionUpdate { + event_type: "transcription_session.update".to_string(), + session: AzureSessionConfig { + input_audio_format: "pcm16".to_string(), + input_audio_transcription: AzureTranscriptionConfig { + model: model.to_string(), + prompt: None, + language, + }, + turn_detection: Some(AzureTurnDetection { + detection_type: VAD_DETECTION_TYPE.to_string(), + threshold: Some(VAD_THRESHOLD), + prefix_padding_ms: Some(VAD_PREFIX_PADDING_MS), + silence_duration_ms: Some(VAD_SILENCE_DURATION_MS), + }), + }, + }; + + let json = serde_json::to_string(&session_update).ok()?; + tracing::debug!(payload = %json, "azure_openai_session_update_payload"); + Some(Message::Text(json.into())) + } +} + +// Azure OpenAI specific session message types + +#[derive(Debug, Serialize)] +struct AzureTranscriptionSessionUpdate { + #[serde(rename = "type")] + event_type: String, + session: AzureSessionConfig, +} + +#[derive(Debug, Serialize)] +struct AzureSessionConfig { + input_audio_format: String, + input_audio_transcription: AzureTranscriptionConfig, + #[serde(skip_serializing_if = "Option::is_none")] + turn_detection: Option, +} + +#[derive(Debug, Serialize)] +struct AzureTranscriptionConfig { + model: String, + #[serde(skip_serializing_if = "Option::is_none")] + prompt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + language: Option, +} + +#[derive(Debug, Serialize)] +struct AzureTurnDetection { + #[serde(rename = "type")] + detection_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + threshold: Option, + #[serde(skip_serializing_if = "Option::is_none")] + prefix_padding_ms: Option, + #[serde(skip_serializing_if = "Option::is_none")] + silence_duration_ms: Option, +} + #[derive(Debug, Serialize)] struct SessionUpdateEvent { #[serde(rename = "type")] diff --git a/crates/owhisper-client/src/adapter/openai/mod.rs b/crates/owhisper-client/src/adapter/openai/mod.rs index 7dd67158f1..bc078806fd 100644 --- a/crates/owhisper-client/src/adapter/openai/mod.rs +++ b/crates/owhisper-client/src/adapter/openai/mod.rs @@ -5,8 +5,33 @@ use crate::providers::Provider; use super::{LanguageQuality, LanguageSupport}; -#[derive(Clone, Default)] -pub struct OpenAIAdapter; +const AZURE_API_VERSION: &str = "2025-04-01-preview"; + +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; + +#[derive(Clone)] +pub struct OpenAIAdapter { + is_azure: Arc, +} + +impl Default for OpenAIAdapter { + fn default() -> Self { + Self { + is_azure: Arc::new(AtomicBool::new(false)), + } + } +} + +impl OpenAIAdapter { + pub fn set_azure(&self, value: bool) { + self.is_azure.store(value, Ordering::SeqCst); + } + + pub fn is_azure(&self) -> bool { + self.is_azure.load(Ordering::SeqCst) + } +} impl OpenAIAdapter { pub fn language_support_live(_languages: &[hypr_language::Language]) -> LanguageSupport { @@ -27,7 +52,18 @@ impl OpenAIAdapter { Self::language_support_batch(languages).is_supported() } + pub fn is_azure_host(host: &str) -> bool { + host.ends_with(".openai.azure.com") + } + pub(crate) fn build_ws_url_from_base(api_base: &str) -> (url::Url, Vec<(String, String)>) { + Self::build_ws_url_from_base_with_model(api_base, None) + } + + pub(crate) fn build_ws_url_from_base_with_model( + api_base: &str, + model: Option<&str>, + ) -> (url::Url, Vec<(String, String)>) { if api_base.is_empty() { return ( Provider::OpenAI @@ -43,15 +79,20 @@ impl OpenAIAdapter { } let parsed: url::Url = api_base.parse().expect("invalid_api_base"); + let host = parsed + .host_str() + .unwrap_or(Provider::OpenAI.default_ws_host()); + + if Self::is_azure_host(host) { + return Self::build_azure_ws_url(&parsed, host, model); + } + let mut existing_params = super::extract_query_params(&parsed); if !existing_params.iter().any(|(k, _)| k == "intent") { existing_params.push(("intent".to_string(), "transcription".to_string())); } - let host = parsed - .host_str() - .unwrap_or(Provider::OpenAI.default_ws_host()); let mut url: url::Url = format!("wss://{}{}", host, Provider::OpenAI.ws_path()) .parse() .expect("invalid_ws_url"); @@ -60,6 +101,28 @@ impl OpenAIAdapter { (url, existing_params) } + + fn build_azure_ws_url( + parsed: &url::Url, + host: &str, + _model: Option<&str>, + ) -> (url::Url, Vec<(String, String)>) { + // For Azure transcription Realtime API: + // - deployment/model should NOT be in URL (causes 400 error) + // - deployment must be sent in transcription_session.update message + // - Only api-version and intent go in the URL + + let url: url::Url = format!("wss://{}/openai/realtime", host) + .parse() + .expect("invalid_azure_ws_url"); + + let params = vec![ + ("api-version".to_string(), AZURE_API_VERSION.to_string()), + ("intent".to_string(), "transcription".to_string()), + ]; + + (url, params) + } } #[cfg(test)] @@ -98,4 +161,65 @@ mod tests { assert!(Provider::OpenAI.is_host("openai.com")); assert!(!Provider::OpenAI.is_host("api.deepgram.com")); } + + #[test] + fn test_is_azure_host() { + assert!(OpenAIAdapter::is_azure_host("my-resource.openai.azure.com")); + assert!(OpenAIAdapter::is_azure_host("eastus.openai.azure.com")); + assert!(!OpenAIAdapter::is_azure_host("api.openai.com")); + assert!(!OpenAIAdapter::is_azure_host("openai.com")); + assert!(!OpenAIAdapter::is_azure_host("azure.com")); + } + + #[test] + fn test_build_ws_url_azure() { + // Azure transcription: deployment should NOT be in URL + // Only api-version and intent should be in params + let (url, params) = OpenAIAdapter::build_ws_url_from_base_with_model( + "https://my-resource.openai.azure.com", + Some("gpt-4o-realtime-preview"), + ); + assert_eq!( + url.as_str(), + "wss://my-resource.openai.azure.com/openai/realtime" + ); + assert!( + params + .iter() + .any(|(k, v)| k == "api-version" && v == "2025-04-01-preview") + ); + assert!( + params + .iter() + .any(|(k, v)| k == "intent" && v == "transcription") + ); + // deployment should NOT be in URL params + assert!(!params.iter().any(|(k, _)| k == "deployment")); + } + + #[test] + fn test_build_ws_url_azure_deployment_not_in_url() { + // Even if deployment is in input URL, it should not appear in output params + // (deployment goes in session message, not URL) + let (url, params) = OpenAIAdapter::build_ws_url_from_base_with_model( + "https://my-resource.openai.azure.com?deployment=my-deployment", + None, + ); + assert_eq!( + url.as_str(), + "wss://my-resource.openai.azure.com/openai/realtime" + ); + // deployment should NOT be in URL params + assert!(!params.iter().any(|(k, _)| k == "deployment")); + assert!( + params + .iter() + .any(|(k, v)| k == "api-version" && v == "2025-04-01-preview") + ); + assert!( + params + .iter() + .any(|(k, v)| k == "intent" && v == "transcription") + ); + } } diff --git a/crates/owhisper-client/src/lib.rs b/crates/owhisper-client/src/lib.rs index 88c38b9e6e..e3f3680970 100644 --- a/crates/owhisper-client/src/lib.rs +++ b/crates/owhisper-client/src/lib.rs @@ -15,8 +15,6 @@ pub use providers::{Auth, Provider, is_meta_model}; use std::marker::PhantomData; -#[cfg(feature = "argmax")] -pub use adapter::StreamingBatchConfig; pub use adapter::deepgram::DeepgramModel; pub use adapter::{ AdapterKind, ArgmaxAdapter, AssemblyAIAdapter, BatchSttAdapter, CactusAdapter, CallbackResult, @@ -26,7 +24,15 @@ pub use adapter::{ documented_language_codes_batch, documented_language_codes_live, is_hyprnote_proxy, is_local_host, normalize_languages, }; -pub use adapter::{StreamingBatchEvent, StreamingBatchStream}; + +fn is_azure_openai(api_base: &str) -> bool { + url::Url::parse(api_base) + .ok() + .and_then(|u| u.host_str().map(OpenAIAdapter::is_azure_host)) + .unwrap_or(false) +} +#[cfg(feature = "argmax")] +pub use adapter::{StreamingBatchConfig, StreamingBatchEvent, StreamingBatchStream}; pub use batch::{BatchClient, BatchClientBuilder}; pub use error::Error; @@ -117,6 +123,10 @@ impl ListenClientBuilder { for (name, value) in &self.extra_headers { request = request.with_header(name, value); } + } else if is_azure_openai(original_api_base) { + if let Some(api_key) = self.api_key.as_deref() { + request = request.with_header("api-key", api_key); + } } else if let Some((header_name, header_value)) = adapter.build_auth_header(self.api_key.as_deref()) { diff --git a/crates/owhisper-client/src/providers.rs b/crates/owhisper-client/src/providers.rs index 3f3581c0ba..60866b285d 100644 --- a/crates/owhisper-client/src/providers.rs +++ b/crates/owhisper-client/src/providers.rs @@ -236,7 +236,14 @@ impl Provider { pub fn is_host(&self, host: &str) -> bool { let domain = self.domain(); - host == domain || host.ends_with(&format!(".{}", domain)) + let matches_domain = host == domain || host.ends_with(&format!(".{}", domain)); + + // Also match Azure OpenAI (*.openai.azure.com) as OpenAI provider + if *self == Self::OpenAI && host.ends_with(".openai.azure.com") { + return true; + } + + matches_domain } pub fn matches_url(&self, base_url: &str) -> bool { diff --git a/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs b/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs index 890005d1aa..35b3cfc6f6 100644 --- a/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs +++ b/crates/transcribe-proxy/src/routes/streaming/hyprnote.rs @@ -42,7 +42,7 @@ fn build_upstream_url_with_adapter( Provider::AssemblyAI => AssemblyAIAdapter.build_ws_url(api_base, params, channels), Provider::Soniox => SonioxAdapter.build_ws_url(api_base, params, channels), Provider::Fireworks => FireworksAdapter.build_ws_url(api_base, params, channels), - Provider::OpenAI => OpenAIAdapter.build_ws_url(api_base, params, channels), + Provider::OpenAI => OpenAIAdapter::default().build_ws_url(api_base, params, channels), Provider::Gladia => GladiaAdapter.build_ws_url(api_base, params, channels), Provider::ElevenLabs => ElevenLabsAdapter.build_ws_url(api_base, params, channels), Provider::DashScope => DashScopeAdapter.build_ws_url(api_base, params, channels), @@ -61,7 +61,7 @@ fn build_initial_message_with_adapter( Provider::AssemblyAI => AssemblyAIAdapter.initial_message(api_key, params, channels), Provider::Soniox => SonioxAdapter.initial_message(api_key, params, channels), Provider::Fireworks => FireworksAdapter.initial_message(api_key, params, channels), - Provider::OpenAI => OpenAIAdapter.initial_message(api_key, params, channels), + Provider::OpenAI => OpenAIAdapter::default().initial_message(api_key, params, channels), Provider::Gladia => GladiaAdapter.initial_message(api_key, params, channels), Provider::ElevenLabs => ElevenLabsAdapter.initial_message(api_key, params, channels), Provider::DashScope => DashScopeAdapter.initial_message(api_key, params, channels), @@ -84,7 +84,7 @@ fn build_response_transformer( Provider::AssemblyAI => AssemblyAIAdapter.parse_response(raw), Provider::Soniox => SonioxAdapter.parse_response(raw), Provider::Fireworks => FireworksAdapter.parse_response(raw), - Provider::OpenAI => OpenAIAdapter.parse_response(raw), + Provider::OpenAI => OpenAIAdapter::default().parse_response(raw), Provider::Gladia => GladiaAdapter.parse_response(raw), Provider::ElevenLabs => ElevenLabsAdapter.parse_response(raw), Provider::DashScope => DashScopeAdapter.parse_response(raw), diff --git a/plugins/fs2/src/error.rs b/plugins/fs2/src/error.rs index 546936d975..5e1995422e 100644 --- a/plugins/fs2/src/error.rs +++ b/plugins/fs2/src/error.rs @@ -1,4 +1,4 @@ -use serde::{Serialize, ser::Serializer}; +use serde::{ser::Serializer, Serialize}; pub type Result = std::result::Result;