diff --git a/Cargo.lock b/Cargo.lock index a6d5075f2d..6f304d96a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4507,10 +4507,15 @@ dependencies = [ name = "denoise" version = "0.1.0" dependencies = [ + "approx", + "audio-snapshot", "criterion", + "dasp", "data", + "hound", "onnx", "realfft", + "rodio", "serde", "thiserror 2.0.18", ] @@ -10193,15 +10198,18 @@ dependencies = [ "host", "hound", "language", + "mp3", "owhisper-client", "owhisper-interface", "ractor", "serde", "specta", + "storage", "thiserror 2.0.18", "tokio", "tokio-stream", "tracing", + "uuid", ] [[package]] @@ -10859,6 +10867,7 @@ dependencies = [ "hound", "mp3lame-encoder", "tempfile", + "thiserror 2.0.18", ] [[package]] @@ -17825,6 +17834,7 @@ dependencies = [ "data", "frontmatter", "glob", + "mp3", "predicates", "rayon", "rodio", @@ -17833,6 +17843,7 @@ dependencies = [ "serde_yaml", "specta", "specta-typescript", + "storage", "tauri", "tauri-plugin", "tauri-plugin-notify", @@ -18040,6 +18051,7 @@ dependencies = [ "serde", "specta", "specta-typescript", + "storage", "tauri", "tauri-plugin", "tauri-plugin-settings", diff --git a/apps/desktop/src/audio-player/provider.tsx b/apps/desktop/src/audio-player/provider.tsx index 596b368dea..eef92e9a3e 100644 --- a/apps/desktop/src/audio-player/provider.tsx +++ b/apps/desktop/src/audio-player/provider.tsx @@ -61,6 +61,7 @@ class TimeStore { } interface AudioPlayerContextValue { + sessionId: string; registerContainer: (el: HTMLDivElement | null) => void; wavesurfer: WaveSurfer | null; state: AudioPlayerState; @@ -266,6 +267,7 @@ export function AudioPlayerProvider({ const value = useMemo( () => ({ + sessionId, registerContainer, wavesurfer, state, @@ -280,6 +282,7 @@ export function AudioPlayerProvider({ setPlaybackRate, }), [ + sessionId, registerContainer, wavesurfer, state, diff --git a/apps/desktop/src/audio-player/timeline.tsx b/apps/desktop/src/audio-player/timeline.tsx index 0e5c91fbe1..fc16e25f4b 100644 --- a/apps/desktop/src/audio-player/timeline.tsx +++ b/apps/desktop/src/audio-player/timeline.tsx @@ -1,5 +1,15 @@ -import { Pause, Play } from "lucide-react"; +import { useQueryClient } from "@tanstack/react-query"; +import { + CheckIcon, + LoaderIcon, + Pause, + Play, + SparklesIcon, + UndoIcon, +} from "lucide-react"; import { useEffect, useRef, useState } from "react"; +import { useStore } from "zustand"; +import { denoiseStore } from "~/services/denoise"; import { cn } from "@hypr/utils"; @@ -8,31 +18,27 @@ import { useAudioPlayer, useAudioTime } from "./provider"; const PLAYBACK_RATES = [0.5, 0.75, 1, 1.25, 1.5, 1.75, 2]; export function Timeline() { - const { - registerContainer, - state, - pause, - resume, - start, - playbackRate, - setPlaybackRate, - } = useAudioPlayer(); - const time = useAudioTime(); - const [showRateMenu, setShowRateMenu] = useState(false); - const rateMenuRef = useRef(null); + const { sessionId, registerContainer } = useAudioPlayer(); - useEffect(() => { - const handleClickOutside = (e: MouseEvent) => { - if ( - rateMenuRef.current && - !rateMenuRef.current.contains(e.target as Node) - ) { - setShowRateMenu(false); - } - }; - document.addEventListener("mousedown", handleClickOutside); - return () => document.removeEventListener("mousedown", handleClickOutside); - }, []); + return ( +
+
+ + + + +
+
+
+ ); +} + +function PlayPauseButton() { + const { state, pause, resume, start } = useAudioPlayer(); const handleClick = () => { if (state === "playing") { @@ -45,81 +51,205 @@ export function Timeline() { }; return ( -
-
- + ); +} + +function TimeDisplay() { + const time = useAudioTime(); + + return ( +
+ {formatTime(time.current)}/ + {formatTime(time.total)} +
+ ); +} + +function PlaybackRateSelector() { + const { playbackRate, setPlaybackRate } = useAudioPlayer(); + const [showMenu, setShowMenu] = useState(false); + const menuRef = useRef(null); + + useEffect(() => { + const handleClickOutside = (e: MouseEvent) => { + if (menuRef.current && !menuRef.current.contains(e.target as Node)) { + setShowMenu(false); + } + }; + document.addEventListener("mousedown", handleClickOutside); + return () => document.removeEventListener("mousedown", handleClickOutside); + }, []); + + return ( +
+ + {showMenu && ( +
- {state === "playing" ? ( - - ) : ( - - )} - - -
- {formatTime(time.current)}/ - {formatTime(time.total)} -
- -
- - {showRateMenu && ( -
( + - ))} -
- )} + {rate}x + + ))}
+ )} +
+ ); +} -
+function DenoiseButton({ sessionId }: { sessionId: string }) { + const queryClient = useQueryClient(); + const job = useStore(denoiseStore, (state) => state.jobs[sessionId]); + + const invalidateAudio = () => { + void queryClient.invalidateQueries({ + queryKey: ["audio", sessionId, "url"], + }); + void queryClient.invalidateQueries({ + queryKey: ["audio", sessionId, "exist"], + }); + }; + + const handleDenoise = () => { + void denoiseStore.getState().startDenoise(sessionId); + }; + + const handleConfirm = () => { + void denoiseStore + .getState() + .confirmDenoise(sessionId) + .then(invalidateAudio); + }; + + const handleRevert = () => { + void denoiseStore.getState().revertDenoise(sessionId).then(invalidateAudio); + }; + + useEffect(() => { + if (job?.status === "completed") { + invalidateAudio(); + } + }, [job?.status, sessionId, queryClient]); + + if (job?.status === "running") { + return ( + + ); + } + + if (job?.status === "completed") { + return ( +
+ +
-
+ ); + } + + return ( + ); } diff --git a/apps/desktop/src/services/denoise/index.ts b/apps/desktop/src/services/denoise/index.ts new file mode 100644 index 0000000000..bca8092aad --- /dev/null +++ b/apps/desktop/src/services/denoise/index.ts @@ -0,0 +1,183 @@ +import { createStore } from "zustand"; + +import { commands as fsSyncCommands } from "@hypr/plugin-fs-sync"; +import { + type DenoiseEvent, + commands as listener2Commands, + events as listener2Events, +} from "@hypr/plugin-listener2"; + +type DenoiseJob = { + status: "running" | "completed" | "failed"; + progress: number; + error?: string; +}; + +type DenoiseState = { + jobs: Record; +}; + +type DenoiseActions = { + startDenoise: (sessionId: string) => Promise; + confirmDenoise: (sessionId: string) => Promise; + revertDenoise: (sessionId: string) => Promise; + getJob: (sessionId: string) => DenoiseJob | undefined; +}; + +function createDenoiseStore() { + return createStore((set, get) => ({ + jobs: {}, + + getJob: (sessionId: string) => { + return get().jobs[sessionId]; + }, + + confirmDenoise: async (sessionId: string) => { + const result = await listener2Commands.audioConfirmDenoise(sessionId); + if (result.status === "error") { + set((state) => ({ + jobs: { + ...state.jobs, + [sessionId]: { + status: "failed", + progress: 0, + error: result.error, + }, + }, + })); + return; + } + + const { [sessionId]: _, ...rest } = get().jobs; + set({ jobs: rest }); + }, + + revertDenoise: async (sessionId: string) => { + const result = await listener2Commands.audioRevertDenoise(sessionId); + if (result.status === "error") { + set((state) => ({ + jobs: { + ...state.jobs, + [sessionId]: { + status: "failed", + progress: 0, + error: result.error, + }, + }, + })); + return; + } + + const { [sessionId]: _, ...rest } = get().jobs; + set({ jobs: rest }); + }, + + startDenoise: async (sessionId: string) => { + const existing = get().jobs[sessionId]; + if (existing?.status === "running") { + return; + } + + const audioPathResult = await fsSyncCommands.audioPath(sessionId); + if (audioPathResult.status === "error") { + set((state) => ({ + jobs: { + ...state.jobs, + [sessionId]: { + status: "failed", + progress: 0, + error: audioPathResult.error, + }, + }, + })); + return; + } + + const sessionDirResult = await fsSyncCommands.sessionDir(sessionId); + if (sessionDirResult.status === "error") { + set((state) => ({ + jobs: { + ...state.jobs, + [sessionId]: { + status: "failed", + progress: 0, + error: sessionDirResult.error, + }, + }, + })); + return; + } + + const inputPath = audioPathResult.data; + const outputPath = `${sessionDirResult.data}/audio-postprocess.wav`; + + set((state) => ({ + jobs: { + ...state.jobs, + [sessionId]: { status: "running", progress: 0 }, + }, + })); + + const unlisten = await listener2Events.denoiseEvent.listen( + (event: { payload: DenoiseEvent }) => { + const data = event.payload; + + if (!("session_id" in data) || data.session_id !== sessionId) { + return; + } + + switch (data.type) { + case "denoiseProgress": + set((state) => ({ + jobs: { + ...state.jobs, + [sessionId]: { status: "running", progress: data.percentage }, + }, + })); + break; + case "denoiseCompleted": + set((state) => ({ + jobs: { + ...state.jobs, + [sessionId]: { status: "completed", progress: 100 }, + }, + })); + unlisten(); + break; + case "denoiseFailed": + set((state) => ({ + jobs: { + ...state.jobs, + [sessionId]: { + status: "failed", + progress: 0, + error: data.error, + }, + }, + })); + unlisten(); + break; + } + }, + ); + + const result = await listener2Commands.runDenoise({ + session_id: sessionId, + input_path: inputPath, + output_path: outputPath, + }); + + if (result.status === "error") { + set((state) => ({ + jobs: { + ...state.jobs, + [sessionId]: { status: "failed", progress: 0, error: result.error }, + }, + })); + unlisten(); + } + }, + })); +} + +export const denoiseStore = createDenoiseStore(); diff --git a/crates/audio-utils/src/lib.rs b/crates/audio-utils/src/lib.rs index fb4250c0da..8ce9bbfaf7 100644 --- a/crates/audio-utils/src/lib.rs +++ b/crates/audio-utils/src/lib.rs @@ -99,6 +99,31 @@ pub fn deinterleave_stereo_bytes(data: &[u8]) -> (Vec, Vec) { (ch0, ch1) } +pub fn deinterleave(samples: &[f32], channels: usize) -> Vec> { + if channels <= 1 { + return vec![samples.to_vec()]; + } + let mut output = vec![Vec::with_capacity(samples.len() / channels + 1); channels]; + for (index, sample) in samples.iter().enumerate() { + output[index % channels].push(*sample); + } + output +} + +pub fn interleave(channels: &[Vec]) -> Vec { + if channels.is_empty() { + return Vec::new(); + } + let frames = channels.iter().map(|c| c.len()).max().unwrap_or(0); + let mut output = Vec::with_capacity(frames * channels.len()); + for frame in 0..frames { + for ch in channels { + output.push(ch.get(frame).copied().unwrap_or(0.0)); + } + } + output +} + pub fn mix_sample_f32(mic: f32, speaker: f32) -> f32 { (mic + speaker).clamp(-1.0, 1.0) } diff --git a/crates/audio-utils/src/vorbis.rs b/crates/audio-utils/src/vorbis.rs index c93e6c55da..cb98b39aff 100644 --- a/crates/audio-utils/src/vorbis.rs +++ b/crates/audio-utils/src/vorbis.rs @@ -6,7 +6,7 @@ use std::path::Path; use hound::{SampleFormat, WavReader, WavSpec, WavWriter}; use vorbis_rs::{VorbisBitrateManagementStrategy, VorbisDecoder, VorbisEncoderBuilder}; -use crate::Error; +use crate::{Error, deinterleave}; pub const DEFAULT_VORBIS_QUALITY: f32 = 0.7; pub const DEFAULT_VORBIS_BLOCK_SIZE: usize = 4096; @@ -300,15 +300,3 @@ pub fn mix_down_to_mono(samples: &[f32], channels: NonZeroU8) -> Vec { } mono } - -fn deinterleave(samples: &[f32], channels: usize) -> Vec> { - if channels <= 1 { - return vec![samples.to_vec()]; - } - - let mut output = vec![Vec::with_capacity(samples.len() / channels + 1); channels]; - for (index, sample) in samples.iter().enumerate() { - output[index % channels].push(*sample); - } - output -} diff --git a/crates/denoise/.gitignore b/crates/denoise/.gitignore new file mode 100644 index 0000000000..5beee59845 --- /dev/null +++ b/crates/denoise/.gitignore @@ -0,0 +1,3 @@ +*.wav +!data/inputs/*.wav +data/outputs/*.wav diff --git a/crates/denoise/Cargo.toml b/crates/denoise/Cargo.toml index b23036336e..5057b864b8 100644 --- a/crates/denoise/Cargo.toml +++ b/crates/denoise/Cargo.toml @@ -15,8 +15,13 @@ serde = { workspace = true } thiserror = { workspace = true } [dev-dependencies] +approx = { workspace = true } criterion = { workspace = true } +dasp = { workspace = true } +hound = { workspace = true } +hypr-audio-snapshot = { workspace = true } hypr-data = { workspace = true } +rodio = { workspace = true } [[bench]] name = "denoise_bench" diff --git a/crates/denoise/data/snapshots/english_1_batch.json b/crates/denoise/data/snapshots/english_1_batch.json new file mode 100644 index 0000000000..229204c588 --- /dev/null +++ b/crates/denoise/data/snapshots/english_1_batch.json @@ -0,0 +1,10 @@ +{ + "sample_count": 2791788, + "rms_energy": 0.10244533, + "peak_amplitude": 0.95817304, + "zero_crossing_rate": 0.12256307, + "spectral_centroid": 441.8855, + "band_energy_low": 241.81445, + "band_energy_mid": 73.382286, + "band_energy_high": 8.082554 +} diff --git a/crates/denoise/data/snapshots/english_1_streaming.json b/crates/denoise/data/snapshots/english_1_streaming.json new file mode 100644 index 0000000000..1bd495f6d0 --- /dev/null +++ b/crates/denoise/data/snapshots/english_1_streaming.json @@ -0,0 +1,10 @@ +{ + "sample_count": 2791788, + "rms_energy": 0.102442056, + "peak_amplitude": 0.9581596, + "zero_crossing_rate": 0.12255663, + "spectral_centroid": 441.92853, + "band_energy_low": 241.77722, + "band_energy_mid": 73.37927, + "band_energy_high": 8.085011 +} diff --git a/crates/denoise/src/onnx/context.rs b/crates/denoise/src/onnx/context.rs new file mode 100644 index 0000000000..0b419ca95b --- /dev/null +++ b/crates/denoise/src/onnx/context.rs @@ -0,0 +1,31 @@ +use hypr_onnx::ndarray::Array3; +use realfft::{ComplexToReal, RealToComplex, num_complex::Complex}; +use std::sync::Arc; + +pub(super) struct ProcessingContext { + pub scratch: Vec>, + pub ifft_scratch: Vec>, + pub fft_buffer: Vec, + pub fft_result: Vec>, + pub estimated_block_vec: Vec, + pub in_mag: Array3, + pub estimated_block: Array3, +} + +impl ProcessingContext { + pub fn new( + block_len: usize, + fft: &Arc>, + ifft: &Arc>, + ) -> Self { + Self { + scratch: vec![Complex::new(0.0f32, 0.0f32); fft.get_scratch_len()], + ifft_scratch: vec![Complex::new(0.0f32, 0.0f32); ifft.get_scratch_len()], + fft_buffer: vec![0.0f32; block_len], + fft_result: vec![Complex::new(0.0f32, 0.0f32); block_len / 2 + 1], + estimated_block_vec: vec![0.0f32; block_len], + in_mag: Array3::::zeros((1, 1, block_len / 2 + 1)), + estimated_block: Array3::::zeros((1, 1, block_len)), + } + } +} diff --git a/crates/denoise/src/onnx/denoiser.rs b/crates/denoise/src/onnx/denoiser.rs new file mode 100644 index 0000000000..2a64a90933 --- /dev/null +++ b/crates/denoise/src/onnx/denoiser.rs @@ -0,0 +1,222 @@ +use super::{buffer::CircularBuffer, context::ProcessingContext, error::Error, model}; +use hypr_onnx::{ + ndarray::{Array3, Array4}, + ort::{session::Session, value::TensorRef}, +}; +use realfft::{ComplexToReal, RealFftPlanner, RealToComplex, num_complex::Complex}; +use std::sync::Arc; + +pub struct Denoiser { + session_1: Session, + session_2: Session, + block_len: usize, + block_shift: usize, + fft: Arc>, + ifft: Arc>, + states_1: Array4, + states_2: Array4, + in_buffer: CircularBuffer, + out_buffer: CircularBuffer, +} + +impl Denoiser { + pub fn new() -> Result { + let block_len = model::BLOCK_SIZE; + let block_shift = model::BLOCK_SHIFT; + + let mut fft_planner = RealFftPlanner::::new(); + let fft = fft_planner.plan_fft_forward(block_len); + let ifft = fft_planner.plan_fft_inverse(block_len); + + let session_1 = hypr_onnx::load_model_from_bytes(model::BYTES_1)?; + let session_2 = hypr_onnx::load_model_from_bytes(model::BYTES_2)?; + + let state_size = model::STATE_SIZE; + + Ok(Denoiser { + session_1, + session_2, + block_len, + block_shift, + fft, + ifft, + states_1: Array4::::zeros((1, 2, state_size, 2)), + states_2: Array4::::zeros((1, 2, state_size, 2)), + in_buffer: CircularBuffer::new(block_len, block_shift), + out_buffer: CircularBuffer::new(block_len, block_shift), + }) + } + + pub fn reset(&mut self) { + let state_size = model::STATE_SIZE; + self.states_1 = Array4::::zeros((1, 2, state_size, 2)); + self.states_2 = Array4::::zeros((1, 2, state_size, 2)); + self.in_buffer.clear(); + self.out_buffer.clear(); + } + + pub fn process(&mut self, input: &[f32]) -> Result, Error> { + self.reset(); + + let len_audio = input.len(); + + let padding = vec![0.0f32; self.block_len - self.block_shift]; + let mut audio = Vec::with_capacity(padding.len() * 2 + len_audio); + audio.extend(&padding); + audio.extend(input); + audio.extend(&padding); + + let result = self._process_internal(&audio, true)?; + + let start_idx = self.block_len - self.block_shift; + Ok(result[start_idx..start_idx + len_audio].to_vec()) + } + + pub fn process_streaming(&mut self, input: &[f32]) -> Result, Error> { + if input.is_empty() { + return Ok(vec![]); + } + + self._process_internal(input, false) + } + + fn _process_internal(&mut self, audio: &[f32], with_padding: bool) -> Result, Error> { + let mut out_file = vec![0.0f32; audio.len()]; + + let effective_len = if with_padding { + audio.len() - (self.block_len - self.block_shift) + } else { + audio.len() + }; + let num_blocks = effective_len / self.block_shift; + + let mut ctx = ProcessingContext::new(self.block_len, &self.fft, &self.ifft); + + for idx in 0..num_blocks { + let start = idx * self.block_shift; + let end = (start + self.block_shift).min(audio.len()); + + self.in_buffer.push_chunk(&audio[start..end]); + + // FFT + ctx.fft_buffer.copy_from_slice(self.in_buffer.data()); + self.fft.process_with_scratch( + &mut ctx.fft_buffer, + &mut ctx.fft_result, + &mut ctx.scratch, + )?; + + // Extract magnitude + for (i, &c) in ctx.fft_result.iter().enumerate() { + ctx.in_mag[[0, 0, i]] = c.norm(); + } + + // Model 1: magnitude + states → mask + new states + let out_mask = self.run_model_1(&ctx.in_mag)?; + + // Apply mask to complex spectrum + for (i, c) in ctx.fft_result.iter_mut().enumerate() { + *c *= out_mask[[0, 0, i]]; + } + + // IFFT + self.ifft.process_with_scratch( + &mut ctx.fft_result, + &mut ctx.estimated_block_vec, + &mut ctx.ifft_scratch, + )?; + + // Normalize + let norm_factor = 1.0 / self.block_len as f32; + ctx.estimated_block_vec + .iter_mut() + .for_each(|x| *x *= norm_factor); + + // Copy to Array3 for model 2 + for (i, &val) in ctx.estimated_block_vec.iter().enumerate() { + ctx.estimated_block[[0, 0, i]] = val; + } + + // Model 2: time-domain samples + states → refined samples + new states + let out_block = self.run_model_2(&ctx.estimated_block)?; + + // Overlap-add + let out_slice = out_block.as_slice().ok_or_else(|| { + Error::ShapeError(hypr_onnx::ndarray::ShapeError::from_kind( + hypr_onnx::ndarray::ErrorKind::IncompatibleLayout, + )) + })?; + self.out_buffer.shift_and_accumulate(out_slice); + + // Write to output + let out_start = idx * self.block_shift; + let out_end = (out_start + self.block_shift).min(out_file.len()); + let out_chunk_len = out_end - out_start; + if out_chunk_len > 0 { + out_file[out_start..out_end] + .copy_from_slice(&self.out_buffer.data()[..out_chunk_len]); + } + } + + self.normalize_output(&mut out_file); + Ok(out_file) + } + + fn run_model_1(&mut self, in_mag: &Array3) -> Result, Error> { + let mut outputs = self.session_1.run(hypr_onnx::ort::inputs![ + "input_2" => TensorRef::from_array_view(in_mag.view())?, + "input_3" => TensorRef::from_array_view(self.states_1.view())? + ])?; + + let out_mask = outputs + .remove("activation_2") + .ok_or_else(|| Error::MissingOutput("activation_2".to_string()))? + .try_extract_array::()? + .view() + .to_owned() + .into_shape_with_order((1, 1, model::FFT_OUT_SIZE))?; + + self.states_1 = outputs + .remove("tf_op_layer_stack_2") + .ok_or_else(|| Error::MissingOutput("tf_op_layer_stack_2".to_string()))? + .try_extract_array::()? + .view() + .to_owned() + .into_shape_with_order((1, 2, model::STATE_SIZE, 2))?; + + Ok(out_mask) + } + + fn run_model_2(&mut self, estimated_block: &Array3) -> Result, Error> { + let mut outputs = self.session_2.run(hypr_onnx::ort::inputs![ + "input_4" => TensorRef::from_array_view(estimated_block.view())?, + "input_5" => TensorRef::from_array_view(self.states_2.view())? + ])?; + + let out_block = outputs + .remove("conv1d_3") + .ok_or_else(|| Error::MissingOutput("conv1d_3".into()))? + .try_extract_array::()? + .view() + .to_owned() + .into_shape_with_order((1, 1, model::BLOCK_SIZE))?; + + self.states_2 = outputs + .remove("tf_op_layer_stack_5") + .ok_or_else(|| Error::MissingOutput("tf_op_layer_stack_5".into()))? + .try_extract_array::()? + .view() + .to_owned() + .into_shape_with_order((1, 2, model::STATE_SIZE, 2))?; + + Ok(out_block) + } + + fn normalize_output(&self, output: &mut [f32]) { + let max_val = output.iter().fold(0.0f32, |max, &x| max.max(x.abs())); + if max_val > 1.0 { + let scale = 0.99 / max_val; + output.iter_mut().for_each(|x| *x *= scale); + } + } +} diff --git a/crates/denoise/src/onnx/mod.rs b/crates/denoise/src/onnx/mod.rs index 232b22e4ef..b7821b3c80 100644 --- a/crates/denoise/src/onnx/mod.rs +++ b/crates/denoise/src/onnx/mod.rs @@ -1,256 +1,14 @@ mod buffer; +mod context; +mod denoiser; mod error; pub mod model; +pub use denoiser::Denoiser; pub use error::*; -use buffer::CircularBuffer; -use hypr_onnx::{ - ndarray::{Array3, Array4}, - ort::{session::Session, value::TensorRef}, -}; -use realfft::{ComplexToReal, RealFftPlanner, RealToComplex, num_complex::Complex}; -use std::sync::Arc; - -struct ProcessingContext { - scratch: Vec>, - ifft_scratch: Vec>, - fft_buffer: Vec, - fft_result: Vec>, - estimated_block_vec: Vec, - in_mag: Array3, - estimated_block: Array3, -} - -impl ProcessingContext { - fn new( - block_len: usize, - fft: &Arc>, - ifft: &Arc>, - ) -> Self { - Self { - scratch: vec![Complex::new(0.0f32, 0.0f32); fft.get_scratch_len()], - ifft_scratch: vec![Complex::new(0.0f32, 0.0f32); ifft.get_scratch_len()], - fft_buffer: vec![0.0f32; block_len], - fft_result: vec![Complex::new(0.0f32, 0.0f32); block_len / 2 + 1], - estimated_block_vec: vec![0.0f32; block_len], - in_mag: Array3::::zeros((1, 1, block_len / 2 + 1)), - estimated_block: Array3::::zeros((1, 1, block_len)), - } - } -} - -pub struct Denoiser { - session_1: Session, - session_2: Session, - block_len: usize, - block_shift: usize, - fft: Arc>, - ifft: Arc>, - states_1: Array4, - states_2: Array4, - in_buffer: CircularBuffer, - out_buffer: CircularBuffer, -} - -impl Denoiser { - pub fn new() -> Result { - let block_len = model::BLOCK_SIZE; - let block_shift = model::BLOCK_SHIFT; - - let mut fft_planner = RealFftPlanner::::new(); - let fft = fft_planner.plan_fft_forward(block_len); - let ifft = fft_planner.plan_fft_inverse(block_len); - - let session_1 = hypr_onnx::load_model_from_bytes(model::BYTES_1)?; - let session_2 = hypr_onnx::load_model_from_bytes(model::BYTES_2)?; - - let state_size = model::STATE_SIZE; - - Ok(Denoiser { - session_1, - session_2, - block_len, - block_shift, - fft, - ifft, - states_1: Array4::::zeros((1, 2, state_size, 2)), - states_2: Array4::::zeros((1, 2, state_size, 2)), - in_buffer: CircularBuffer::new(block_len, block_shift), - out_buffer: CircularBuffer::new(block_len, block_shift), - }) - } - - pub fn reset(&mut self) { - let state_size = model::STATE_SIZE; - self.states_1 = Array4::::zeros((1, 2, state_size, 2)); - self.states_2 = Array4::::zeros((1, 2, state_size, 2)); - self.in_buffer.clear(); - self.out_buffer.clear(); - } - - pub fn process(&mut self, input: &[f32]) -> Result, Error> { - self.reset(); - - let len_audio = input.len(); - - let padding = vec![0.0f32; self.block_len - self.block_shift]; - let mut audio = Vec::with_capacity(padding.len() * 2 + len_audio); - audio.extend(&padding); - audio.extend(input); - audio.extend(&padding); - - let result = self._process_internal(&audio, true)?; - - let start_idx = self.block_len - self.block_shift; - Ok(result[start_idx..start_idx + len_audio].to_vec()) - } - - pub fn process_streaming(&mut self, input: &[f32]) -> Result, Error> { - if input.is_empty() { - return Ok(vec![]); - } - - self._process_internal(input, false) - } - - fn _process_internal(&mut self, audio: &[f32], with_padding: bool) -> Result, Error> { - let mut out_file = vec![0.0f32; audio.len()]; - - let effective_len = if with_padding { - audio.len() - (self.block_len - self.block_shift) - } else { - audio.len() - }; - let num_blocks = effective_len / self.block_shift; - - let mut ctx = ProcessingContext::new(self.block_len, &self.fft, &self.ifft); - - for idx in 0..num_blocks { - let start = idx * self.block_shift; - let end = (start + self.block_shift).min(audio.len()); - - self.in_buffer.push_chunk(&audio[start..end]); - - // FFT - ctx.fft_buffer.copy_from_slice(self.in_buffer.data()); - self.fft.process_with_scratch( - &mut ctx.fft_buffer, - &mut ctx.fft_result, - &mut ctx.scratch, - )?; - - // Extract magnitude - for (i, &c) in ctx.fft_result.iter().enumerate() { - ctx.in_mag[[0, 0, i]] = c.norm(); - } - - // Model 1: magnitude + states → mask + new states - let out_mask = self.run_model_1(&ctx.in_mag)?; - - // Apply mask to complex spectrum - for (i, c) in ctx.fft_result.iter_mut().enumerate() { - *c *= out_mask[[0, 0, i]]; - } - - // IFFT - self.ifft.process_with_scratch( - &mut ctx.fft_result, - &mut ctx.estimated_block_vec, - &mut ctx.ifft_scratch, - )?; - - // Normalize - let norm_factor = 1.0 / self.block_len as f32; - ctx.estimated_block_vec - .iter_mut() - .for_each(|x| *x *= norm_factor); - - // Copy to Array3 for model 2 - for (i, &val) in ctx.estimated_block_vec.iter().enumerate() { - ctx.estimated_block[[0, 0, i]] = val; - } - - // Model 2: time-domain samples + states → refined samples + new states - let out_block = self.run_model_2(&ctx.estimated_block)?; - - // Overlap-add - let out_slice = out_block.as_slice().ok_or_else(|| { - Error::ShapeError(hypr_onnx::ndarray::ShapeError::from_kind( - hypr_onnx::ndarray::ErrorKind::IncompatibleLayout, - )) - })?; - self.out_buffer.shift_and_accumulate(out_slice); - - // Write to output - let out_start = idx * self.block_shift; - let out_end = (out_start + self.block_shift).min(out_file.len()); - let out_chunk_len = out_end - out_start; - if out_chunk_len > 0 { - out_file[out_start..out_end] - .copy_from_slice(&self.out_buffer.data()[..out_chunk_len]); - } - } - - self.normalize_output(&mut out_file); - Ok(out_file) - } - - fn run_model_1(&mut self, in_mag: &Array3) -> Result, Error> { - let mut outputs = self.session_1.run(hypr_onnx::ort::inputs![ - "input_2" => TensorRef::from_array_view(in_mag.view())?, - "input_3" => TensorRef::from_array_view(self.states_1.view())? - ])?; - - let out_mask = outputs - .remove("activation_2") - .ok_or_else(|| Error::MissingOutput("activation_2".to_string()))? - .try_extract_array::()? - .view() - .to_owned() - .into_shape_with_order((1, 1, model::FFT_OUT_SIZE))?; - - self.states_1 = outputs - .remove("tf_op_layer_stack_2") - .ok_or_else(|| Error::MissingOutput("tf_op_layer_stack_2".to_string()))? - .try_extract_array::()? - .view() - .to_owned() - .into_shape_with_order((1, 2, model::STATE_SIZE, 2))?; - - Ok(out_mask) - } - - fn run_model_2(&mut self, estimated_block: &Array3) -> Result, Error> { - let mut outputs = self.session_2.run(hypr_onnx::ort::inputs![ - "input_4" => TensorRef::from_array_view(estimated_block.view())?, - "input_5" => TensorRef::from_array_view(self.states_2.view())? - ])?; - - let out_block = outputs - .remove("conv1d_3") - .ok_or_else(|| Error::MissingOutput("conv1d_3".into()))? - .try_extract_array::()? - .view() - .to_owned() - .into_shape_with_order((1, 1, model::BLOCK_SIZE))?; - - self.states_2 = outputs - .remove("tf_op_layer_stack_5") - .ok_or_else(|| Error::MissingOutput("tf_op_layer_stack_5".into()))? - .try_extract_array::()? - .view() - .to_owned() - .into_shape_with_order((1, 2, model::STATE_SIZE, 2))?; - - Ok(out_block) - } - - fn normalize_output(&self, output: &mut [f32]) { - let max_val = output.iter().fold(0.0f32, |max, &x| max.max(x.abs())); - if max_val > 1.0 { - let scale = 0.99 / max_val; - output.iter_mut().for_each(|x| *x *= scale); - } - } -} +// cargo test -p denoise --features onnx +// +// Set UPDATE_SNAPSHOTS=1 to regenerate baseline snapshots. +#[cfg(test)] +mod tests; diff --git a/crates/denoise/src/onnx/tests.rs b/crates/denoise/src/onnx/tests.rs new file mode 100644 index 0000000000..118b34d646 --- /dev/null +++ b/crates/denoise/src/onnx/tests.rs @@ -0,0 +1,95 @@ +use super::denoiser::Denoiser; +use super::model::{BLOCK_SHIFT, BLOCK_SIZE}; +use approx::assert_abs_diff_eq; +use hypr_audio_snapshot::{SpectralConfig, Tolerances}; +use std::path::PathBuf; + +fn pcm_bytes_to_f32(bytes: &[u8]) -> Vec { + bytes + .chunks_exact(2) + .map(|c| i16::from_le_bytes([c[0], c[1]]) as f32 / 32768.0) + .collect() +} + +fn output_path(prefix: &str, mode: &str) -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("outputs") + .join(format!("{prefix}_{mode}.wav")) +} + +fn snapshot_path(prefix: &str, mode: &str) -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("data") + .join("snapshots") + .join(format!("{prefix}_{mode}.json")) +} + +#[test] +fn test_denoise_english_1() { + let config = SpectralConfig { + fft_size: BLOCK_SIZE, + hop_size: BLOCK_SHIFT, + sample_rate: 16000.0, + }; + let tolerances = Tolerances::default(); + + let samples = pcm_bytes_to_f32(hypr_data::english_1::AUDIO); + + let batch_result = { + let mut denoiser = Denoiser::new().unwrap(); + let result = denoiser.process(&samples).unwrap(); + assert!(result.iter().all(|&x| x.is_finite())); + result + }; + + let streaming_result = { + let mut denoiser = Denoiser::new().unwrap(); + let mut streaming_result = Vec::new(); + + let chunk_size = BLOCK_SIZE * 2; + let mut processed = 0; + + while processed < samples.len() { + let end = (processed + chunk_size).min(samples.len()); + let chunk = &samples[processed..end]; + + let chunk_result = denoiser.process_streaming(chunk).unwrap(); + streaming_result.extend(chunk_result); + + processed = end; + } + + assert!(streaming_result.iter().all(|&x| x.is_finite())); + streaming_result + }; + + let batch_snap = hypr_audio_snapshot::assert_or_update( + &batch_result, + &output_path("english_1", "batch"), + &snapshot_path("english_1", "batch"), + "english_1 batch", + &config, + &tolerances, + ); + + let streaming_snap = hypr_audio_snapshot::assert_or_update( + &streaming_result, + &output_path("english_1", "streaming"), + &snapshot_path("english_1", "streaming"), + "english_1 streaming", + &config, + &tolerances, + ); + + assert_abs_diff_eq!( + batch_snap.rms_energy, + streaming_snap.rms_energy, + epsilon = 0.05, + ); + assert_abs_diff_eq!( + batch_snap.spectral_centroid, + streaming_snap.spectral_centroid, + epsilon = 300.0, + ); +} diff --git a/crates/listener-core/src/actors/recorder/codec.rs b/crates/listener-core/src/actors/recorder/codec.rs index 7947e31e7d..5831005449 100644 --- a/crates/listener-core/src/actors/recorder/codec.rs +++ b/crates/listener-core/src/actors/recorder/codec.rs @@ -8,11 +8,11 @@ impl AudioCodec for Mp3Codec { } fn encode(&self, input: &Path, output: &Path) -> Result<(), Box> { - hypr_mp3::encode_wav(input, output) + Ok(hypr_mp3::encode_wav(input, output)?) } fn decode(&self, input: &Path, output: &Path) -> Result<(), Box> { - hypr_mp3::decode_to_wav(input, output) + Ok(hypr_mp3::decode_to_wav(input, output)?) } } diff --git a/crates/listener2-core/Cargo.toml b/crates/listener2-core/Cargo.toml index 5c82bc4057..fbda5a29db 100644 --- a/crates/listener2-core/Cargo.toml +++ b/crates/listener2-core/Cargo.toml @@ -12,6 +12,9 @@ hypr-audio-utils = { workspace = true } hypr-denoise = { workspace = true, features = ["onnx"] } hypr-host = { workspace = true } hypr-language = { workspace = true } +hypr-mp3 = { workspace = true } +hypr-storage = { workspace = true } +uuid = { workspace = true } hound = { workspace = true } diff --git a/crates/listener2-core/src/denoise.rs b/crates/listener2-core/src/denoise.rs index 03c0aa6641..9b456930f3 100644 --- a/crates/listener2-core/src/denoise.rs +++ b/crates/listener2-core/src/denoise.rs @@ -1,8 +1,13 @@ -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::Arc; +use hypr_storage::vault::audio::{ + AUDIO_MP3, AUDIO_POSTPROCESS_WAV, SESSION_ORIGINAL_AUDIO_FORMATS, +}; + use crate::DenoiseEvent; use crate::runtime::DenoiseRuntime; +use hypr_audio_utils::Source; const DENOISE_SAMPLE_RATE: u32 = 16000; const CHUNK_SIZE: usize = 16000; @@ -47,30 +52,46 @@ fn run_denoise_blocking( let source = hypr_audio_utils::source_from_path(¶ms.input_path) .map_err(|e| crate::Error::DenoiseError(e.to_string()))?; + let channels = source.channels() as usize; + let samples = hypr_audio_utils::resample_audio(source, DENOISE_SAMPLE_RATE) .map_err(|e| crate::Error::DenoiseError(e.to_string()))?; - let mut denoiser = hypr_denoise::onnx::Denoiser::new() - .map_err(|e| crate::Error::DenoiseError(e.to_string()))?; + let channel_data = hypr_audio_utils::deinterleave(&samples, channels); + + let total_chunks_per_channel = channel_data[0].len().div_ceil(CHUNK_SIZE); + let total_chunks = total_chunks_per_channel * channels; + let mut chunks_done = 0; - let total_chunks = samples.len().div_ceil(CHUNK_SIZE); - let mut output = Vec::with_capacity(samples.len()); + let mut denoised_channels: Vec> = Vec::with_capacity(channels); - for (i, chunk) in samples.chunks(CHUNK_SIZE).enumerate() { - let denoised = denoiser - .process_streaming(chunk) + for ch_samples in &channel_data { + let mut denoiser = hypr_denoise::onnx::Denoiser::new() .map_err(|e| crate::Error::DenoiseError(e.to_string()))?; - output.extend_from_slice(&denoised); - let percentage = ((i + 1) as f64 / total_chunks as f64) * 100.0; - runtime.emit(DenoiseEvent::DenoiseProgress { - session_id: params.session_id.clone(), - percentage, - }); + let mut ch_output = Vec::with_capacity(ch_samples.len()); + + for chunk in ch_samples.chunks(CHUNK_SIZE) { + let denoised = denoiser + .process_streaming(chunk) + .map_err(|e| crate::Error::DenoiseError(e.to_string()))?; + ch_output.extend_from_slice(&denoised); + + chunks_done += 1; + let percentage = (chunks_done as f64 / total_chunks as f64) * 100.0; + runtime.emit(DenoiseEvent::DenoiseProgress { + session_id: params.session_id.clone(), + percentage, + }); + } + + denoised_channels.push(ch_output); } + let output = hypr_audio_utils::interleave(&denoised_channels); + let spec = hound::WavSpec { - channels: 1, + channels: channels as u16, sample_rate: DENOISE_SAMPLE_RATE, bits_per_sample: 32, sample_format: hound::SampleFormat::Float, @@ -93,3 +114,117 @@ fn run_denoise_blocking( Ok(()) } + +pub fn confirm_denoise(runtime: &dyn DenoiseRuntime, session_id: &str) -> crate::Result { + let session_dir = resolve_session_dir(runtime, session_id)?; + let postprocess_path = session_dir.join(AUDIO_POSTPROCESS_WAV); + + if !postprocess_path.exists() { + return Err(crate::Error::DenoiseError(format!( + "{AUDIO_POSTPROCESS_WAV} not found" + ))); + } + + let target_path = session_dir.join(AUDIO_MP3); + let tmp_target_path = target_path.with_extension("mp3.tmp"); + + if tmp_target_path.exists() { + std::fs::remove_file(&tmp_target_path)?; + } + + if let Err(error) = hypr_mp3::encode_wav(&postprocess_path, &tmp_target_path) { + let _ = std::fs::remove_file(&tmp_target_path); + return Err(crate::Error::DenoiseError(error.to_string())); + } + + replace_file_atomically(&tmp_target_path, &target_path)?; + + for format in SESSION_ORIGINAL_AUDIO_FORMATS { + if format == AUDIO_MP3 { + continue; + } + let p = session_dir.join(format); + if p.exists() { + std::fs::remove_file(&p)?; + } + } + + std::fs::remove_file(&postprocess_path)?; + + Ok(target_path) +} + +pub fn revert_denoise(runtime: &dyn DenoiseRuntime, session_id: &str) -> crate::Result<()> { + let session_dir = resolve_session_dir(runtime, session_id)?; + let postprocess_path = session_dir.join(AUDIO_POSTPROCESS_WAV); + + if postprocess_path.exists() { + std::fs::remove_file(&postprocess_path)?; + } + + Ok(()) +} + +fn resolve_session_dir(runtime: &dyn DenoiseRuntime, session_id: &str) -> crate::Result { + let vault_base = runtime + .vault_base() + .map_err(|e| crate::Error::DenoiseError(e.to_string()))?; + let sessions_base = vault_base.join("sessions"); + Ok(find_session_dir(&sessions_base, session_id)) +} + +fn find_session_dir(sessions_base: &Path, session_id: &str) -> PathBuf { + find_session_dir_recursive(sessions_base, session_id) + .unwrap_or_else(|| sessions_base.join(session_id)) +} + +fn find_session_dir_recursive(dir: &Path, session_id: &str) -> Option { + let entries = std::fs::read_dir(dir).ok()?; + + for entry in entries.flatten() { + let path = entry.path(); + if !path.is_dir() { + continue; + } + + let name = path.file_name()?.to_str()?; + + if name == session_id { + return Some(path); + } + + if uuid::Uuid::try_parse(name).is_err() { + if let Some(found) = find_session_dir_recursive(&path, session_id) { + return Some(found); + } + } + } + + None +} + +fn replace_file_atomically(tmp_path: &Path, target_path: &Path) -> std::io::Result<()> { + let backup_path = target_path.with_extension("mp3.bak"); + + if backup_path.exists() { + std::fs::remove_file(&backup_path)?; + } + + let had_target = target_path.exists(); + if had_target { + std::fs::rename(target_path, &backup_path)?; + } + + if let Err(error) = std::fs::rename(tmp_path, target_path) { + if had_target { + let _ = std::fs::rename(&backup_path, target_path); + } + return Err(error); + } + + if had_target { + std::fs::remove_file(backup_path)?; + } + + Ok(()) +} diff --git a/crates/listener2-core/src/lib.rs b/crates/listener2-core/src/lib.rs index 75153f266e..230b827f04 100644 --- a/crates/listener2-core/src/lib.rs +++ b/crates/listener2-core/src/lib.rs @@ -6,7 +6,7 @@ mod runtime; mod subtitle; pub use batch::{BatchParams, BatchProvider, run_batch}; -pub use denoise::{DenoiseParams, run_denoise}; +pub use denoise::*; pub use error::*; pub use events::*; pub use runtime::*; diff --git a/crates/listener2-core/src/runtime.rs b/crates/listener2-core/src/runtime.rs index 989fb74575..e721a81325 100644 --- a/crates/listener2-core/src/runtime.rs +++ b/crates/listener2-core/src/runtime.rs @@ -5,6 +5,6 @@ pub trait BatchRuntime: Send + Sync + 'static { fn emit(&self, event: BatchEvent); } -pub trait DenoiseRuntime: Send + Sync + 'static { +pub trait DenoiseRuntime: hypr_storage::StorageRuntime { fn emit(&self, event: DenoiseEvent); } diff --git a/crates/mp3/Cargo.toml b/crates/mp3/Cargo.toml index 24ca05d95e..d155dc236d 100644 --- a/crates/mp3/Cargo.toml +++ b/crates/mp3/Cargo.toml @@ -7,6 +7,7 @@ edition = "2024" hound = { workspace = true } hypr-audio-utils = { workspace = true } mp3lame-encoder = { workspace = true } +thiserror = { workspace = true } [dev-dependencies] tempfile = { workspace = true } diff --git a/crates/mp3/src/error.rs b/crates/mp3/src/error.rs new file mode 100644 index 0000000000..3631dec2bc --- /dev/null +++ b/crates/mp3/src/error.rs @@ -0,0 +1,37 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum Error { + #[error("unsupported channel count: {0} (expected 1 or 2)")] + UnsupportedChannelCount(u16), + + #[error("unsupported float bit depth: {0}")] + UnsupportedFloatBitDepth(u16), + + #[error("unsupported integer bit depth: {0}")] + UnsupportedIntBitDepth(u16), + + #[error("failed to create LAME encoder")] + LameInit, + + #[error("LAME configuration error: {0}")] + LameConfig(String), + + #[error("LAME build error: {0}")] + LameBuild(String), + + #[error("LAME encode error: {0}")] + LameEncode(String), + + #[error("LAME flush error: {0}")] + LameFlush(String), + + #[error(transparent)] + Wav(#[from] hound::Error), + + #[error(transparent)] + Io(#[from] std::io::Error), + + #[error(transparent)] + AudioUtils(#[from] hypr_audio_utils::Error), +} diff --git a/crates/mp3/src/lib.rs b/crates/mp3/src/lib.rs index 52db2ea96e..49f9a584df 100644 --- a/crates/mp3/src/lib.rs +++ b/crates/mp3/src/lib.rs @@ -1,3 +1,7 @@ +mod error; + +pub use error::Error; + use std::path::Path; use hound::SampleFormat; @@ -5,41 +9,51 @@ use mp3lame_encoder::{Builder as LameBuilder, DualPcm, FlushNoGap, MonoPcm}; const CHUNK_FRAMES: usize = 4096; -pub fn encode_wav(wav_path: &Path, mp3_path: &Path) -> Result<(), Box> { +pub fn encode_wav(wav_path: &Path, mp3_path: &Path) -> Result<(), Error> { let mut reader = hound::WavReader::open(wav_path)?; let spec = reader.spec(); + let num_channels = match spec.channels { 1 => 1u8, 2 => 2u8, - count => { - return Err(format!("unsupported channel count: {count} (expected 1 or 2)").into()); - } + count => return Err(Error::UnsupportedChannelCount(count)), }; + let bitrate = if num_channels > 1 { + mp3lame_encoder::Bitrate::Kbps128 + } else { + mp3lame_encoder::Bitrate::Kbps64 + }; + let sample_rate = spec.sample_rate; - let mut mp3_builder = LameBuilder::new().ok_or("Failed to create LAME builder")?; - mp3_builder - .set_num_channels(num_channels) - .map_err(|e| format!("set channels error: {:?}", e))?; - mp3_builder - .set_sample_rate(sample_rate) - .map_err(|e| format!("set sample rate error: {:?}", e))?; - mp3_builder - .set_brate(mp3lame_encoder::Bitrate::Kbps128) - .map_err(|e| format!("set bitrate error: {:?}", e))?; - mp3_builder - .set_quality(mp3lame_encoder::Quality::Best) - .map_err(|e| format!("set quality error: {:?}", e))?; - let mut encoder = mp3_builder - .build() - .map_err(|e| format!("LAME build error: {:?}", e))?; + let mut encoder = { + let mut mp3_builder = LameBuilder::new().ok_or(Error::LameInit)?; + mp3_builder + .set_num_channels(num_channels) + .map_err(|e| Error::LameConfig(format!("{:?}", e)))?; + mp3_builder + .set_sample_rate(sample_rate) + .map_err(|e| Error::LameConfig(format!("{:?}", e)))?; + + mp3_builder + .set_brate(bitrate) + .map_err(|e| Error::LameConfig(format!("{:?}", e)))?; + mp3_builder + .set_quality(mp3lame_encoder::Quality::NearBest) + .map_err(|e| Error::LameConfig(format!("{:?}", e)))?; + + mp3_builder + .build() + .map_err(|e| Error::LameBuild(format!("{:?}", e)))? + }; + let mut mp3_out: Vec = Vec::new(); let bits_per_sample = spec.bits_per_sample; match spec.sample_format { SampleFormat::Float => { if bits_per_sample != 32 { - return Err(format!("unsupported float bit depth: {bits_per_sample}").into()); + return Err(Error::UnsupportedFloatBitDepth(bits_per_sample)); } if num_channels == 1 { @@ -48,7 +62,7 @@ pub fn encode_wav(wav_path: &Path, mp3_path: &Path) -> Result<(), Box(), f32_to_i16, |left, right| { @@ -56,7 +70,7 @@ pub fn encode_wav(wav_path: &Path, mp3_path: &Path) -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box return Err(format!("unsupported integer bit depth: {bits}").into()), + bits => return Err(Error::UnsupportedIntBitDepth(bits)), }, } mp3_out.reserve(mp3lame_encoder::max_required_buffer_size(0)); encoder .flush_to_vec::(&mut mp3_out) - .map_err(|e| format!("flush error: {:?}", e))?; + .map_err(|e| Error::LameFlush(format!("{:?}", e)))?; std::fs::write(mp3_path, &mp3_out)?; Ok(()) } -pub fn decode_to_wav(mp3_path: &Path, wav_path: &Path) -> Result<(), Box> { +pub fn decode_to_wav(mp3_path: &Path, wav_path: &Path) -> Result<(), Error> { use hypr_audio_utils::Source; let source = hypr_audio_utils::source_from_path(mp3_path)?; @@ -194,11 +208,11 @@ fn encode_mono_samples( samples: I, mut sample_to_i16: F, mut encode_chunk: E, -) -> Result<(), Box> +) -> Result<(), Error> where I: Iterator>, F: FnMut(S) -> i16, - E: FnMut(&[i16]) -> Result>, + E: FnMut(&[i16]) -> Result, { let mut pcm_i16 = Vec::with_capacity(CHUNK_FRAMES); for sample in samples { @@ -222,11 +236,11 @@ fn encode_stereo_samples( mut samples: I, mut sample_to_i16: F, mut encode_chunk: E, -) -> Result<(), Box> +) -> Result<(), Error> where I: Iterator>, F: FnMut(S) -> i16, - E: FnMut(&[i16], &[i16]) -> Result>, + E: FnMut(&[i16], &[i16]) -> Result, { let mut left = Vec::with_capacity(CHUNK_FRAMES); let mut right = Vec::with_capacity(CHUNK_FRAMES); @@ -280,7 +294,7 @@ mod tests { } #[test] - fn encode_mono_samples_flushes_partial_tail() -> Result<(), Box> { + fn encode_mono_samples_flushes_partial_tail() -> Result<(), Error> { let samples = (0..(CHUNK_FRAMES + 1)) .map(|n| Ok(n as i16)) .collect::>() @@ -301,7 +315,7 @@ mod tests { } #[test] - fn encode_stereo_samples_pads_missing_right_sample() -> Result<(), Box> { + fn encode_stereo_samples_pads_missing_right_sample() -> Result<(), Error> { let samples = vec![Ok(10i16), Ok(20i16), Ok(30i16)].into_iter(); let mut encoded = Vec::new(); diff --git a/crates/mp3/tests/roundtrip.rs b/crates/mp3/tests/roundtrip.rs index 4717839693..6e4fda2f05 100644 --- a/crates/mp3/tests/roundtrip.rs +++ b/crates/mp3/tests/roundtrip.rs @@ -39,27 +39,7 @@ fn write_fixture_wav(path: &Path, case: Case) -> Result, Box Result<(), Box> { - let spec = hound::WavSpec { - channels: case.channels, - sample_rate: case.sample_rate, - bits_per_sample: 16, - sample_format: hound::SampleFormat::Int, - }; - - let mut writer = hound::WavWriter::create(path, spec)?; - for frame in 0..case.frames { - for channel in 0..case.channels as usize { - let sample = fixture_sample(frame, channel); - let sample_i16 = (sample * i16::MAX as f32) as i16; - writer.write_sample(sample_i16)?; - } - } - writer.finalize()?; - Ok(()) -} - -fn write_fixture_wav_i32( +fn write_fixture_wav_int( path: &Path, case: Case, bits_per_sample: u16, @@ -71,17 +51,19 @@ fn write_fixture_wav_i32( sample_format: hound::SampleFormat::Int, }; - let mut writer = hound::WavWriter::create(path, spec)?; let max_amplitude = match bits_per_sample { + 8 => i8::MAX as f32, + 16 => i16::MAX as f32, 17..=31 => ((1i64 << (bits_per_sample - 1)) - 1) as f32, 32 => i32::MAX as f32, - bits => return Err(format!("unsupported bit depth for i32 fixture: {bits}").into()), + bits => return Err(format!("unsupported bit depth: {bits}").into()), }; + + let mut writer = hound::WavWriter::create(path, spec)?; for frame in 0..case.frames { for channel in 0..case.channels as usize { let sample = fixture_sample(frame, channel); - let sample_i32 = (sample * max_amplitude) as i32; - writer.write_sample(sample_i32)?; + writer.write_sample((sample * max_amplitude) as i32)?; } } writer.finalize()?; @@ -122,6 +104,16 @@ fn write_malformed_stereo_wav_with_odd_samples( Ok(()) } +fn assert_samples_valid(samples: &[f32]) { + for sample in samples { + assert!(sample.is_finite(), "decoded sample is not finite"); + assert!( + (-1.1..=1.1).contains(sample), + "decoded sample out of expected range: {sample}" + ); + } +} + fn read_wav(path: &Path) -> Result<(hound::WavSpec, Vec), Box> { let mut reader = hound::WavReader::open(path)?; let spec = reader.spec(); @@ -153,13 +145,7 @@ fn assert_roundtrip(case: Case) -> Result<(), Box> { "sample rate changed" ); - for sample in &decoded_samples { - assert!(sample.is_finite(), "decoded sample is not finite"); - assert!( - (-1.1..=1.1).contains(sample), - "decoded sample out of expected range: {sample}" - ); - } + assert_samples_valid(&decoded_samples); if case.frames == 0 { let max_len = 4096 * case.channels as usize; @@ -249,149 +235,38 @@ fn rejects_more_than_two_channels() -> Result<(), Box> { Ok(()) } -#[test] -fn roundtrip_pcm16_stereo_input() -> Result<(), Box> { - let tempdir = tempdir()?; - let case = Case { - channels: 2, - frames: 8_192, - sample_rate: 44_100, - }; - let wav_path = tempdir.path().join("input_i16.wav"); - let mp3_path = tempdir.path().join("encoded.mp3"); - let decoded_wav_path = tempdir.path().join("decoded.wav"); - - write_fixture_wav_i16(&wav_path, case)?; - encode_wav(&wav_path, &mp3_path)?; - decode_to_wav(&mp3_path, &decoded_wav_path)?; - - let (decoded_spec, decoded_samples) = read_wav(&decoded_wav_path)?; - assert_eq!( - decoded_spec.channels, case.channels, - "channel count changed" - ); - assert_eq!( - decoded_spec.sample_rate, case.sample_rate, - "sample rate changed" - ); - assert!( - !decoded_samples.is_empty(), - "decoded pcm16 input to empty output" - ); - for sample in &decoded_samples { - assert!(sample.is_finite(), "decoded sample is not finite"); - assert!( - (-1.1..=1.1).contains(sample), - "decoded sample out of expected range: {sample}" - ); - } - - Ok(()) -} - -#[test] -fn roundtrip_pcm8_mono_input() -> Result<(), Box> { - let tempdir = tempdir()?; - let case = Case { - channels: 1, - frames: 4_096, - sample_rate: 16_000, - }; - let wav_path = tempdir.path().join("input_i8.wav"); - let mp3_path = tempdir.path().join("encoded.mp3"); - let decoded_wav_path = tempdir.path().join("decoded.wav"); - - let spec = hound::WavSpec { - channels: case.channels, - sample_rate: case.sample_rate, - bits_per_sample: 8, - sample_format: hound::SampleFormat::Int, - }; - let mut writer = hound::WavWriter::create(&wav_path, spec)?; - for frame in 0..case.frames { - let sample = fixture_sample(frame, 0); - writer.write_sample((sample * i8::MAX as f32) as i8)?; - } - writer.finalize()?; - - encode_wav(&wav_path, &mp3_path)?; - decode_to_wav(&mp3_path, &decoded_wav_path)?; - - let (decoded_spec, decoded_samples) = read_wav(&decoded_wav_path)?; - assert_eq!(decoded_spec.channels, case.channels); - assert_eq!(decoded_spec.sample_rate, case.sample_rate); - assert!(!decoded_samples.is_empty()); - for sample in &decoded_samples { - assert!(sample.is_finite(), "decoded sample is not finite"); - assert!( - (-1.1..=1.1).contains(sample), - "decoded sample out of expected range: {sample}" - ); - } - - Ok(()) -} - -#[test] -fn roundtrip_pcm24_mono_input() -> Result<(), Box> { - let tempdir = tempdir()?; - let case = Case { - channels: 1, - frames: 8_192, - sample_rate: 22_050, +macro_rules! pcm_roundtrip_cases { + ($($name:ident => { bits: $bits:expr, channels: $channels:expr, frames: $frames:expr, sample_rate: $sample_rate:expr }),+ $(,)?) => { + $( + #[test] + fn $name() -> Result<(), Box> { + let tempdir = tempdir()?; + let case = Case { channels: $channels, frames: $frames, sample_rate: $sample_rate }; + let wav_path = tempdir.path().join("input.wav"); + let mp3_path = tempdir.path().join("encoded.mp3"); + let decoded_wav_path = tempdir.path().join("decoded.wav"); + + write_fixture_wav_int(&wav_path, case, $bits)?; + encode_wav(&wav_path, &mp3_path)?; + decode_to_wav(&mp3_path, &decoded_wav_path)?; + + let (decoded_spec, decoded_samples) = read_wav(&decoded_wav_path)?; + assert_eq!(decoded_spec.channels, case.channels, "channel count changed"); + assert_eq!(decoded_spec.sample_rate, case.sample_rate, "sample rate changed"); + assert!(!decoded_samples.is_empty(), "decoded output is empty"); + assert_samples_valid(&decoded_samples); + + Ok(()) + } + )+ }; - let wav_path = tempdir.path().join("input_i24.wav"); - let mp3_path = tempdir.path().join("encoded.mp3"); - let decoded_wav_path = tempdir.path().join("decoded.wav"); - - write_fixture_wav_i32(&wav_path, case, 24)?; - encode_wav(&wav_path, &mp3_path)?; - decode_to_wav(&mp3_path, &decoded_wav_path)?; - - let (decoded_spec, decoded_samples) = read_wav(&decoded_wav_path)?; - assert_eq!(decoded_spec.channels, case.channels); - assert_eq!(decoded_spec.sample_rate, case.sample_rate); - assert!(!decoded_samples.is_empty()); - for sample in &decoded_samples { - assert!(sample.is_finite(), "decoded sample is not finite"); - assert!( - (-1.1..=1.1).contains(sample), - "decoded sample out of expected range: {sample}" - ); - } - - Ok(()) } -#[test] -fn roundtrip_pcm32_stereo_input() -> Result<(), Box> { - let tempdir = tempdir()?; - let case = Case { - channels: 2, - frames: 6_321, - sample_rate: 48_000, - }; - let wav_path = tempdir.path().join("input_i32.wav"); - let mp3_path = tempdir.path().join("encoded.mp3"); - let decoded_wav_path = tempdir.path().join("decoded.wav"); - - write_fixture_wav_i32(&wav_path, case, 32)?; - encode_wav(&wav_path, &mp3_path)?; - decode_to_wav(&mp3_path, &decoded_wav_path)?; - - let (decoded_spec, decoded_samples) = read_wav(&decoded_wav_path)?; - assert_eq!(decoded_spec.channels, case.channels); - assert_eq!(decoded_spec.sample_rate, case.sample_rate); - assert!(!decoded_samples.is_empty()); - for sample in &decoded_samples { - assert!(sample.is_finite(), "decoded sample is not finite"); - assert!( - (-1.1..=1.1).contains(sample), - "decoded sample out of expected range: {sample}" - ); - } - - Ok(()) +pcm_roundtrip_cases! { + roundtrip_pcm8_mono => { bits: 8, channels: 1, frames: 4_096, sample_rate: 16_000 }, + roundtrip_pcm16_stereo => { bits: 16, channels: 2, frames: 8_192, sample_rate: 44_100 }, + roundtrip_pcm24_mono => { bits: 24, channels: 1, frames: 8_192, sample_rate: 22_050 }, + roundtrip_pcm32_stereo => { bits: 32, channels: 2, frames: 6_321, sample_rate: 48_000 }, } #[test] diff --git a/crates/storage/src/vault/audio.rs b/crates/storage/src/vault/audio.rs new file mode 100644 index 0000000000..f490627b51 --- /dev/null +++ b/crates/storage/src/vault/audio.rs @@ -0,0 +1,12 @@ +pub const AUDIO_POSTPROCESS_WAV: &str = "audio-postprocess.wav"; +pub const AUDIO_MP3: &str = "audio.mp3"; +pub const AUDIO_WAV: &str = "audio.wav"; +pub const AUDIO_OGG: &str = "audio.ogg"; + +/// All session audio filenames in lookup precedence order. +/// Postprocessed output takes priority over originals. +pub const SESSION_AUDIO_CANDIDATES: [&str; 4] = + [AUDIO_POSTPROCESS_WAV, AUDIO_MP3, AUDIO_WAV, AUDIO_OGG]; + +/// Original (non-postprocessed) audio formats. +pub const SESSION_ORIGINAL_AUDIO_FORMATS: [&str; 3] = [AUDIO_MP3, AUDIO_WAV, AUDIO_OGG]; diff --git a/crates/storage/src/vault/mod.rs b/crates/storage/src/vault/mod.rs index 2d9f819898..cb9105fc61 100644 --- a/crates/storage/src/vault/mod.rs +++ b/crates/storage/src/vault/mod.rs @@ -1,5 +1,7 @@ +pub mod audio; pub mod fs; pub mod path; +pub use audio::*; pub use fs::*; pub use path::*; diff --git a/plugins/fs-sync/Cargo.toml b/plugins/fs-sync/Cargo.toml index e3a48f7319..a63ed6316a 100644 --- a/plugins/fs-sync/Cargo.toml +++ b/plugins/fs-sync/Cargo.toml @@ -21,6 +21,7 @@ tokio = { workspace = true, features = ["macros"] } [dependencies] hypr-audio-utils = { workspace = true } hypr-frontmatter = { workspace = true } +hypr-storage = { workspace = true } hypr-tiptap = { workspace = true } tauri = { workspace = true, features = ["test"] } @@ -36,6 +37,7 @@ specta = { workspace = true, features = ["serde_json"] } glob = "0.3" hypr-afconvert = { workspace = true } +hypr-mp3 = { workspace = true } rayon = { workspace = true } rodio = { workspace = true, features = ["symphonia-all"] } diff --git a/plugins/fs-sync/js/bindings.gen.ts b/plugins/fs-sync/js/bindings.gen.ts index 873691e6ee..9f54d78a0f 100644 --- a/plugins/fs-sync/js/bindings.gen.ts +++ b/plugins/fs-sync/js/bindings.gen.ts @@ -211,7 +211,7 @@ export type ListFoldersResult = { folders: Partial<{ [key in string]: FolderInfo export type ParsedDocument = { frontmatter: Partial<{ [key in string]: JsonValue }>; content: string } export type ScanResult = { files: Partial<{ [key in string]: string }>; dirs: string[] } export type SessionContentData = { sessionId: string; meta: SessionMetaData | null; rawMemoTiptapJson: JsonValue | null; transcript: TranscriptData | null; notes: SessionNoteData[] } -export type SessionMetaData = { id: string; userId: string; createdAt: string | null; title: string | null; event: JsonValue | null; eventId: string | null; participants: SessionMetaParticipant[]; tags: string[] } +export type SessionMetaData = { id: string; userId: string; createdAt: string | null; title: string | null; event: JsonValue | null; eventId: string | null; participants?: SessionMetaParticipant[]; tags?: string[] } export type SessionMetaParticipant = { id: string; userId: string; sessionId: string; humanId: string; source: string } export type SessionNoteData = { id: string; sessionId: string; templateId: string | null; position: number | null; title: string | null; tiptapJson: JsonValue } export type TranscriptData = { transcripts: TranscriptEntry[] } diff --git a/plugins/fs-sync/src/audio.rs b/plugins/fs-sync/src/audio.rs index d42c58576e..b77aa3fc1d 100644 --- a/plugins/fs-sync/src/audio.rs +++ b/plugins/fs-sync/src/audio.rs @@ -7,13 +7,14 @@ use hypr_audio_utils::{ Source, VorbisEncodeSettings, encode_vorbis_mono, mix_down_to_mono, resample_audio, }; +use hypr_storage::vault::audio::SESSION_AUDIO_CANDIDATES; + use crate::error::{AudioImportError, AudioProcessingError}; const TARGET_SAMPLE_RATE_HZ: u32 = 16_000; -const AUDIO_FORMATS: [&str; 3] = ["audio.mp3", "audio.wav", "audio.ogg"]; pub fn exists(session_dir: &Path) -> std::io::Result { - AUDIO_FORMATS + SESSION_AUDIO_CANDIDATES .iter() .map(|format| session_dir.join(format)) .try_fold(false, |acc, path| { @@ -22,7 +23,7 @@ pub fn exists(session_dir: &Path) -> std::io::Result { } pub fn delete(session_dir: &Path) -> std::io::Result<()> { - for format in AUDIO_FORMATS { + for format in SESSION_AUDIO_CANDIDATES { let path = session_dir.join(format); if std::fs::exists(&path).unwrap_or(false) { std::fs::remove_file(&path)?; @@ -32,7 +33,7 @@ pub fn delete(session_dir: &Path) -> std::io::Result<()> { } pub fn path(session_dir: &Path) -> Option { - AUDIO_FORMATS + SESSION_AUDIO_CANDIDATES .iter() .map(|format| session_dir.join(format)) .find(|path| path.exists()) diff --git a/plugins/listener/src/lib.rs b/plugins/listener/src/lib.rs index c7fcfb506f..1bf733f9ff 100644 --- a/plugins/listener/src/lib.rs +++ b/plugins/listener/src/lib.rs @@ -54,7 +54,9 @@ pub fn init() -> tauri::plugin::TauriPlugin { let app_handle = app.app_handle().clone(); let runtime = Arc::new(TauriRuntime { - app: app_handle.clone(), + storage: tauri_plugin_settings::TauriStorageRuntime { + app: app_handle.clone(), + }, }); tauri::async_runtime::spawn(async move { diff --git a/plugins/listener/src/runtime.rs b/plugins/listener/src/runtime.rs index 6b64a3daaf..ad27b9eb37 100644 --- a/plugins/listener/src/runtime.rs +++ b/plugins/listener/src/runtime.rs @@ -1,26 +1,17 @@ use hypr_listener_core::ListenerRuntime; -use tauri_plugin_settings::SettingsPluginExt; use tauri_specta::Event; pub struct TauriRuntime { - pub app: tauri::AppHandle, + pub storage: tauri_plugin_settings::TauriStorageRuntime, } impl hypr_storage::StorageRuntime for TauriRuntime { fn global_base(&self) -> Result { - self.app - .settings() - .global_base() - .map(|p| p.into_std_path_buf()) - .map_err(|_| hypr_storage::Error::DataDirUnavailable) + self.storage.global_base() } fn vault_base(&self) -> Result { - self.app - .settings() - .cached_vault_base() - .map(|p| p.into_std_path_buf()) - .map_err(|_| hypr_storage::Error::DataDirUnavailable) + self.storage.vault_base() } } @@ -29,37 +20,37 @@ impl ListenerRuntime for TauriRuntime { use tauri_plugin_tray::TrayPluginExt; match &event { hypr_listener_core::SessionLifecycleEvent::Active { .. } => { - let _ = self.app.tray().set_start_disabled(true); + let _ = self.storage.app.tray().set_start_disabled(true); } hypr_listener_core::SessionLifecycleEvent::Inactive { .. } => { - let _ = self.app.tray().set_start_disabled(false); + let _ = self.storage.app.tray().set_start_disabled(false); } hypr_listener_core::SessionLifecycleEvent::Finalizing { .. } => {} } let plugin_event: crate::events::SessionLifecycleEvent = event.into(); - if let Err(error) = plugin_event.emit(&self.app) { + if let Err(error) = plugin_event.emit(&self.storage.app) { tracing::error!(?error, "failed_to_emit_lifecycle_event"); } } fn emit_progress(&self, event: hypr_listener_core::SessionProgressEvent) { let plugin_event: crate::events::SessionProgressEvent = event.into(); - if let Err(error) = plugin_event.emit(&self.app) { + if let Err(error) = plugin_event.emit(&self.storage.app) { tracing::error!(?error, "failed_to_emit_progress_event"); } } fn emit_error(&self, event: hypr_listener_core::SessionErrorEvent) { let plugin_event: crate::events::SessionErrorEvent = event.into(); - if let Err(error) = plugin_event.emit(&self.app) { + if let Err(error) = plugin_event.emit(&self.storage.app) { tracing::error!(?error, "failed_to_emit_error_event"); } } fn emit_data(&self, event: hypr_listener_core::SessionDataEvent) { let plugin_event: crate::events::SessionDataEvent = event.into(); - if let Err(error) = plugin_event.emit(&self.app) { + if let Err(error) = plugin_event.emit(&self.storage.app) { tracing::error!(?error, "failed_to_emit_data_event"); } } diff --git a/plugins/listener2/Cargo.toml b/plugins/listener2/Cargo.toml index 7b6e333db5..ae338e2047 100644 --- a/plugins/listener2/Cargo.toml +++ b/plugins/listener2/Cargo.toml @@ -16,6 +16,7 @@ specta-typescript = { workspace = true } [dependencies] hypr-language = { workspace = true } hypr-listener2-core = { workspace = true, features = ["specta"] } +hypr-storage = { workspace = true } owhisper-interface = { workspace = true } tauri-plugin-settings = { workspace = true } diff --git a/plugins/listener2/build.rs b/plugins/listener2/build.rs index c99e474e42..ddcccd0905 100644 --- a/plugins/listener2/build.rs +++ b/plugins/listener2/build.rs @@ -1,6 +1,8 @@ const COMMANDS: &[&str] = &[ "run_batch", "run_denoise", + "audio_confirm_denoise", + "audio_revert_denoise", "parse_subtitle", "export_to_vtt", "is_supported_languages_batch", diff --git a/plugins/listener2/js/bindings.gen.ts b/plugins/listener2/js/bindings.gen.ts index ecdcac0f09..bb97965f57 100644 --- a/plugins/listener2/js/bindings.gen.ts +++ b/plugins/listener2/js/bindings.gen.ts @@ -22,6 +22,22 @@ async runDenoise(params: DenoiseParams) : Promise> { else return { status: "error", error: e as any }; } }, +async audioConfirmDenoise(sessionId: string) : Promise> { + try { + return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|audio_confirm_denoise", { sessionId }) }; +} catch (e) { + if(e instanceof Error) throw e; + else return { status: "error", error: e as any }; +} +}, +async audioRevertDenoise(sessionId: string) : Promise> { + try { + return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|audio_revert_denoise", { sessionId }) }; +} catch (e) { + if(e instanceof Error) throw e; + else return { status: "error", error: e as any }; +} +}, async parseSubtitle(path: string) : Promise> { try { return { status: "ok", data: await TAURI_INVOKE("plugin:listener2|parse_subtitle", { path }) }; diff --git a/plugins/listener2/permissions/autogenerated/commands/audio_confirm_denoise.toml b/plugins/listener2/permissions/autogenerated/commands/audio_confirm_denoise.toml new file mode 100644 index 0000000000..f0f8f03017 --- /dev/null +++ b/plugins/listener2/permissions/autogenerated/commands/audio_confirm_denoise.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-audio-confirm-denoise" +description = "Enables the audio_confirm_denoise command without any pre-configured scope." +commands.allow = ["audio_confirm_denoise"] + +[[permission]] +identifier = "deny-audio-confirm-denoise" +description = "Denies the audio_confirm_denoise command without any pre-configured scope." +commands.deny = ["audio_confirm_denoise"] diff --git a/plugins/listener2/permissions/autogenerated/commands/audio_revert_denoise.toml b/plugins/listener2/permissions/autogenerated/commands/audio_revert_denoise.toml new file mode 100644 index 0000000000..8c0c5d65b8 --- /dev/null +++ b/plugins/listener2/permissions/autogenerated/commands/audio_revert_denoise.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-audio-revert-denoise" +description = "Enables the audio_revert_denoise command without any pre-configured scope." +commands.allow = ["audio_revert_denoise"] + +[[permission]] +identifier = "deny-audio-revert-denoise" +description = "Denies the audio_revert_denoise command without any pre-configured scope." +commands.deny = ["audio_revert_denoise"] diff --git a/plugins/listener2/permissions/autogenerated/reference.md b/plugins/listener2/permissions/autogenerated/reference.md index 64b5aad495..04a435b82a 100644 --- a/plugins/listener2/permissions/autogenerated/reference.md +++ b/plugins/listener2/permissions/autogenerated/reference.md @@ -6,6 +6,8 @@ Default permissions for the plugin - `allow-run-batch` - `allow-run-denoise` +- `allow-audio-confirm-denoise` +- `allow-audio-revert-denoise` - `allow-parse-subtitle` - `allow-export-to-vtt` - `allow-is-supported-languages-batch` @@ -21,6 +23,58 @@ Default permissions for the plugin + + + +`listener2:allow-audio-confirm-denoise` + + + + +Enables the audio_confirm_denoise command without any pre-configured scope. + + + + + + + +`listener2:deny-audio-confirm-denoise` + + + + +Denies the audio_confirm_denoise command without any pre-configured scope. + + + + + + + +`listener2:allow-audio-revert-denoise` + + + + +Enables the audio_revert_denoise command without any pre-configured scope. + + + + + + + +`listener2:deny-audio-revert-denoise` + + + + +Denies the audio_revert_denoise command without any pre-configured scope. + + + + diff --git a/plugins/listener2/permissions/default.toml b/plugins/listener2/permissions/default.toml index 05df082f4f..06dc737e39 100644 --- a/plugins/listener2/permissions/default.toml +++ b/plugins/listener2/permissions/default.toml @@ -3,6 +3,8 @@ description = "Default permissions for the plugin" permissions = [ "allow-run-batch", "allow-run-denoise", + "allow-audio-confirm-denoise", + "allow-audio-revert-denoise", "allow-parse-subtitle", "allow-export-to-vtt", "allow-is-supported-languages-batch", diff --git a/plugins/listener2/permissions/schemas/schema.json b/plugins/listener2/permissions/schemas/schema.json index 9ad15c3b2b..3476961d1b 100644 --- a/plugins/listener2/permissions/schemas/schema.json +++ b/plugins/listener2/permissions/schemas/schema.json @@ -294,6 +294,30 @@ "PermissionKind": { "type": "string", "oneOf": [ + { + "description": "Enables the audio_confirm_denoise command without any pre-configured scope.", + "type": "string", + "const": "allow-audio-confirm-denoise", + "markdownDescription": "Enables the audio_confirm_denoise command without any pre-configured scope." + }, + { + "description": "Denies the audio_confirm_denoise command without any pre-configured scope.", + "type": "string", + "const": "deny-audio-confirm-denoise", + "markdownDescription": "Denies the audio_confirm_denoise command without any pre-configured scope." + }, + { + "description": "Enables the audio_revert_denoise command without any pre-configured scope.", + "type": "string", + "const": "allow-audio-revert-denoise", + "markdownDescription": "Enables the audio_revert_denoise command without any pre-configured scope." + }, + { + "description": "Denies the audio_revert_denoise command without any pre-configured scope.", + "type": "string", + "const": "deny-audio-revert-denoise", + "markdownDescription": "Denies the audio_revert_denoise command without any pre-configured scope." + }, { "description": "Enables the export_to_vtt command without any pre-configured scope.", "type": "string", @@ -379,10 +403,10 @@ "markdownDescription": "Denies the suggest_providers_for_languages_batch command without any pre-configured scope." }, { - "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-run-batch`\n- `allow-run-denoise`\n- `allow-parse-subtitle`\n- `allow-export-to-vtt`\n- `allow-is-supported-languages-batch`\n- `allow-suggest-providers-for-languages-batch`\n- `allow-list-documented-language-codes-batch`", + "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-run-batch`\n- `allow-run-denoise`\n- `allow-audio-confirm-denoise`\n- `allow-audio-revert-denoise`\n- `allow-parse-subtitle`\n- `allow-export-to-vtt`\n- `allow-is-supported-languages-batch`\n- `allow-suggest-providers-for-languages-batch`\n- `allow-list-documented-language-codes-batch`", "type": "string", "const": "default", - "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-run-batch`\n- `allow-run-denoise`\n- `allow-parse-subtitle`\n- `allow-export-to-vtt`\n- `allow-is-supported-languages-batch`\n- `allow-suggest-providers-for-languages-batch`\n- `allow-list-documented-language-codes-batch`" + "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-run-batch`\n- `allow-run-denoise`\n- `allow-audio-confirm-denoise`\n- `allow-audio-revert-denoise`\n- `allow-parse-subtitle`\n- `allow-export-to-vtt`\n- `allow-is-supported-languages-batch`\n- `allow-suggest-providers-for-languages-batch`\n- `allow-list-documented-language-codes-batch`" } ] } diff --git a/plugins/listener2/src/commands.rs b/plugins/listener2/src/commands.rs index 0b24271ecb..fb539ea9ea 100644 --- a/plugins/listener2/src/commands.rs +++ b/plugins/listener2/src/commands.rs @@ -47,6 +47,30 @@ pub async fn run_denoise( .map_err(|e| e.to_string()) } +#[tauri::command] +#[specta::specta] +pub async fn audio_confirm_denoise( + app: tauri::AppHandle, + session_id: String, +) -> Result<(), String> { + app.listener2() + .confirm_denoise(&session_id) + .await + .map_err(|e| e.to_string()) +} + +#[tauri::command] +#[specta::specta] +pub async fn audio_revert_denoise( + app: tauri::AppHandle, + session_id: String, +) -> Result<(), String> { + app.listener2() + .revert_denoise(&session_id) + .await + .map_err(|e| e.to_string()) +} + #[tauri::command] #[specta::specta] pub async fn is_supported_languages_batch( diff --git a/plugins/listener2/src/ext.rs b/plugins/listener2/src/ext.rs index 523326b5db..6a9bb873e6 100644 --- a/plugins/listener2/src/ext.rs +++ b/plugins/listener2/src/ext.rs @@ -10,25 +10,51 @@ pub struct Listener2<'a, R: tauri::Runtime, M: tauri::Manager> { impl<'a, R: tauri::Runtime, M: tauri::Manager> Listener2<'a, R, M> { pub async fn run_batch(&self, params: core::BatchParams) -> Result<(), core::Error> { - let state = self.manager.state::(); - let guard = state.lock().await; - let app = guard.app.clone(); - drop(guard); + let app = self.manager.state::().inner().clone(); - let runtime = Arc::new(TauriBatchRuntime { app }); + let runtime = Arc::new(Listener2Runtime { + storage: tauri_plugin_settings::TauriStorageRuntime { app }, + }); core::run_batch(runtime, params).await } pub async fn run_denoise(&self, params: core::DenoiseParams) -> Result<(), core::Error> { - let state = self.manager.state::(); - let guard = state.lock().await; - let app = guard.app.clone(); - drop(guard); + let app = self.manager.state::().inner().clone(); - let runtime = Arc::new(TauriDenoiseRuntime { app }); + let runtime = Arc::new(Listener2Runtime { + storage: tauri_plugin_settings::TauriStorageRuntime { app }, + }); core::run_denoise(runtime, params).await } + pub async fn confirm_denoise(&self, session_id: &str) -> Result<(), core::Error> { + let app = self.manager.state::().inner().clone(); + + let runtime = Listener2Runtime { + storage: tauri_plugin_settings::TauriStorageRuntime { app }, + }; + let session_id = session_id.to_string(); + + tokio::task::spawn_blocking(move || { + core::confirm_denoise(&runtime, &session_id).map(|_| ()) + }) + .await + .map_err(|e| core::Error::DenoiseError(e.to_string()))? + } + + pub async fn revert_denoise(&self, session_id: &str) -> Result<(), core::Error> { + let app = self.manager.state::().inner().clone(); + + let runtime = Listener2Runtime { + storage: tauri_plugin_settings::TauriStorageRuntime { app }, + }; + let session_id = session_id.to_string(); + + tokio::task::spawn_blocking(move || core::revert_denoise(&runtime, &session_id)) + .await + .map_err(|e| core::Error::DenoiseError(e.to_string()))? + } + pub fn parse_subtitle(&self, path: String) -> Result { core::parse_subtitle_from_path(path) } @@ -74,24 +100,30 @@ impl> Listener2PluginExt for T { } } -struct TauriBatchRuntime { - app: tauri::AppHandle, +struct Listener2Runtime { + storage: tauri_plugin_settings::TauriStorageRuntime, } -impl core::BatchRuntime for TauriBatchRuntime { - fn emit(&self, event: core::BatchEvent) { - let tauri_event: crate::BatchEvent = event.into(); - let _ = tauri_event.emit(&self.app); +impl hypr_storage::StorageRuntime for Listener2Runtime { + fn global_base(&self) -> Result { + self.storage.global_base() + } + + fn vault_base(&self) -> Result { + self.storage.vault_base() } } -struct TauriDenoiseRuntime { - app: tauri::AppHandle, +impl core::BatchRuntime for Listener2Runtime { + fn emit(&self, event: core::BatchEvent) { + let tauri_event: crate::BatchEvent = event.into(); + let _ = tauri_event.emit(&self.storage.app); + } } -impl core::DenoiseRuntime for TauriDenoiseRuntime { +impl core::DenoiseRuntime for Listener2Runtime { fn emit(&self, event: core::DenoiseEvent) { let tauri_event: crate::DenoiseEvent = event.into(); - let _ = tauri_event.emit(&self.app); + let _ = tauri_event.emit(&self.storage.app); } } diff --git a/plugins/listener2/src/lib.rs b/plugins/listener2/src/lib.rs index 662d808297..42933dbeae 100644 --- a/plugins/listener2/src/lib.rs +++ b/plugins/listener2/src/lib.rs @@ -1,6 +1,4 @@ -use std::sync::Arc; use tauri::Manager; -use tokio::sync::Mutex; mod commands; mod events; @@ -12,12 +10,7 @@ pub use ext::*; pub use hypr_listener2_core::{BatchParams, BatchProvider, DenoiseParams, Subtitle, VttWord}; const PLUGIN_NAME: &str = "listener2"; - -pub type SharedState = Arc>; - -pub struct State { - pub app: tauri::AppHandle, -} +pub type SharedState = tauri::AppHandle; fn make_specta_builder() -> tauri_specta::Builder { tauri_specta::Builder::::new() @@ -25,6 +18,8 @@ fn make_specta_builder() -> tauri_specta::Builder { .commands(tauri_specta::collect_commands![ commands::run_batch::, commands::run_denoise::, + commands::audio_confirm_denoise::, + commands::audio_revert_denoise::, commands::parse_subtitle::, commands::export_to_vtt::, commands::is_supported_languages_batch::, @@ -42,10 +37,7 @@ pub fn init() -> tauri::plugin::TauriPlugin { .invoke_handler(specta_builder.invoke_handler()) .setup(move |app, _api| { specta_builder.mount_events(app); - - let app_handle = app.app_handle().clone(); - let state: SharedState = Arc::new(Mutex::new(State { app: app_handle })); - app.manage(state); + app.manage(app.app_handle().clone()); Ok(()) }) diff --git a/plugins/settings/src/lib.rs b/plugins/settings/src/lib.rs index 062db5a50a..cd1462c0b3 100644 --- a/plugins/settings/src/lib.rs +++ b/plugins/settings/src/lib.rs @@ -3,11 +3,13 @@ use tauri::Manager; mod commands; mod error; mod ext; +mod runtime; mod state; pub use error::{Error, Result}; pub use ext::*; pub use hypr_storage::ObsidianVault; +pub use runtime::TauriStorageRuntime; pub use state::*; const PLUGIN_NAME: &str = "settings"; diff --git a/plugins/settings/src/runtime.rs b/plugins/settings/src/runtime.rs new file mode 100644 index 0000000000..f029a5f7f1 --- /dev/null +++ b/plugins/settings/src/runtime.rs @@ -0,0 +1,25 @@ +use std::path::PathBuf; + +pub struct TauriStorageRuntime { + pub app: tauri::AppHandle, +} + +impl hypr_storage::StorageRuntime for TauriStorageRuntime { + fn global_base(&self) -> Result { + use crate::SettingsPluginExt; + self.app + .settings() + .global_base() + .map(|p| p.into_std_path_buf()) + .map_err(|_| hypr_storage::Error::DataDirUnavailable) + } + + fn vault_base(&self) -> Result { + use crate::SettingsPluginExt; + self.app + .settings() + .cached_vault_base() + .map(|p| p.into_std_path_buf()) + .map_err(|_| hypr_storage::Error::DataDirUnavailable) + } +}