From be0b0d9211c47f412e83ecb3b5b8ae7f20550d91 Mon Sep 17 00:00:00 2001
From: arunasrivastava
Date: Sat, 9 Aug 2025 23:42:05 -0700
Subject: [PATCH 1/7] initial working decoupled transformer and feature extractor

---
 src/server.py            | 75 +++++++++++++++++++++----------
 src/timestamp_testing.py | 80 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 124 insertions(+), 31 deletions(-)
 create mode 100644 src/timestamp_testing.py

diff --git a/src/server.py b/src/server.py
index 9a45cf0..9ffb7ae 100644
--- a/src/server.py
+++ b/src/server.py
@@ -1,7 +1,8 @@
 import json
-
+import torch
 import numpy as np
+
 from flask import Flask, send_from_directory, request, jsonify
 from flask_cors import CORS, cross_origin
 from flask_sock import Sock
@@ -12,7 +13,12 @@
     top_phonetic_errors,
     pair_by_words,
 )
-from transcription import transcribe_timestamped, SAMPLE_RATE
+from transcription import (
+    extract_features_only,
+    run_transformer_on_features,
+    SAMPLE_RATE,
+    transcribe_timestamped,
+)
 from phoneme_utils import TIMESTAMPED_PHONES_T, TIMESTAMPED_PHONES_BY_WORD_T
 
 # Constants
@@ -83,44 +89,51 @@ def get_score_words_cer():
 
 @sock.route("/stream")
 def stream(ws):
-    buffer = b""  # Buffer to hold audio chunks
+    buffer = b""
+    feature_list = []
+    total_samples_processed = 0
+
+    CHUNK_SIZE_SAMPLES = 320  # 20ms at 16kHz
+    TRANSFORMER_INTERVAL = 25  # Every 500ms
 
-    full_transcription: TIMESTAMPED_PHONES_T = []
-    accumulated_duration = 0
-    combined = np.array([], dtype=np.float32)
     while True:
         try:
-            # Receive audio data from the client
             data = ws.receive()
             if data and data != "stop":
                 buffer += data
 
-            # Process when buffer has at least one chunk in it or when we are done
-            if (
-                data == "stop"
-                or len(buffer)
-                >= SAMPLE_RATE * NUM_SECONDS_PER_CHUNK * np.dtype(np.float32).itemsize
-            ):
-                audio = np.frombuffer(buffer, dtype=np.float32)
-                transcription = transcribe_timestamped(audio, accumulated_duration)
-                accumulated_duration += len(audio) / SAMPLE_RATE
-                full_transcription.extend(transcription)
-                ws.send(json.dumps(full_transcription))
-
-                if DEBUG:
-                    from scipy.io import wavfile
-
-                    wavfile.write("src/audio.wav", SAMPLE_RATE, audio)
-                    combined = np.concatenate([combined, audio])
-                    wavfile.write("src/combined.wav", SAMPLE_RATE, combined)
-
-                if data == "stop":
-                    break
-
-                buffer = b""  # Clear the buffer
+            # Process 20ms chunks
+            while len(buffer) >= CHUNK_SIZE_SAMPLES * np.dtype(np.float32).itemsize:
+                chunk_bytes = buffer[
+                    : CHUNK_SIZE_SAMPLES * np.dtype(np.float32).itemsize
+                ]
+                buffer = buffer[CHUNK_SIZE_SAMPLES * np.dtype(np.float32).itemsize :]
+
+                audio_chunk = np.frombuffer(chunk_bytes, dtype=np.float32)
+
+                features, samples = extract_features_only(audio_chunk)
+                feature_list.append(features)
+                total_samples_processed += samples
+                # Every 500ms, send COMPLETE transcription from start
+                if len(feature_list) % TRANSFORMER_INTERVAL == 0:
+                    all_features = torch.cat(feature_list, dim=1)
+                    full_transcription = run_transformer_on_features(
+                        all_features, total_samples_processed
+                    )
+                    ws.send(json.dumps(full_transcription))
+
+            if data == "stop":
+                # Final update with any remaining features
+                if feature_list:
+                    all_features = torch.cat(feature_list, dim=1)
+                    full_transcription = run_transformer_on_features(
+                        all_features, total_samples_processed
+                    )
+                    ws.send(json.dumps(full_transcription))
+                break
+
         except Exception as e:
             print(f"Error: {e}")
-            print(f"Line: {e.__traceback__.tb_lineno if e.__traceback__ else -1}")
             break
diff --git a/src/timestamp_testing.py b/src/timestamp_testing.py
new file mode 100644
index 0000000..ab799f4
--- /dev/null
+++ b/src/timestamp_testing.py
@@ -0,0 +1,80 @@
+import sys
+import os
+from scipy.io import wavfile
+import numpy as np
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from src.feedback import pair_by_words, score_words_wfed, score_words_cer, user_phonetic_errors
+from src.phoneme_utils import weighted_needleman_wunsch, map_target_data, validate_target_data
+import panphon
+import panphon.distance
+
+panphon_dist = panphon.distance.Distance()
+
+user_speech_timestamps = [
+    ("j", 0.501336302895323, 0.521380846325167),
+    ("u", 0.5414253897550111, 0.5614699331848553),
+    ("j", 0.601336302895323, 0.621380846325167),
+    ("u", 0.6414253897550111, 0.6614699331848553),
+    ("ɡ", 0.8218262806236081, 0.8418708240534521),
+    ("ɡ", 0.8218262806236081, 0.8418708240534521),
+    ("ɑ", 0.8619153674832962, 0.8819599109131403),
+    ("t", 1.042316258351893, 1.0623608017817372),
+    ("ə", 1.0824053452115814, 1.1024498886414253),
+    ("s", 1.22271714922049, 1.2628062360801782),
+    ("t", 1.3028953229398665, 1.3229398663697105),
+    ("eɪ", 1.3429844097995547, 1.3630289532293987),
+    ("ʌ", 1.5835189309576838, 1.6035634743875278),
+    ("w", 1.7639198218262806, 1.7839643652561248),
+    ("ɜ˞", 1.7839643652561248, 1.8040089086859687),
+    ("d", 2.024498886414254, 2.044543429844098),
+    ("ɔ", 2.1648106904231628, 2.1848552338530065),
+    ("l", 2.305122494432071, 2.3251670378619154),
+    ("ð", 2.385300668151448, 2.405345211581292),
+    ("ə", 2.405345211581292, 2.425389755011136),
+    # ("t", 2.505567928730512, 2.5256124721603563),
+    # ("aɪ", 2.605790645879733, 2.625835189309577),
+    # ("m", 2.906458797327394, 2.9265033407572383),
+]
+actor_speech_timestamps = [
+    ("j", 0.24160443037974685, 0.2617381329113924),
+    ("u", 0.2617381329113924, 0.281871835443038),
+    ("ɡ", 0.34227294303797473, 0.36240664556962027),
+    ("ɑ", 0.38254034810126586, 0.40267405063291145),
+    ("t", 0.4429414556962025, 0.4630751582278481),
+    ("ʌ", 0.4832088607594937, 0.5033425632911394),
+    ("s", 0.6845458860759495, 0.7046795886075949),
+    ("t", 0.7650806962025317, 0.7852143987341773),
+    ("eɪ", 0.8053481012658229, 0.8254818037974684),
+    ("æ", 1.0872199367088609, 1.1073536392405066),
+    ("n", 1.22815585443038, 1.2482895569620254),
+    ("l", 1.369091772151899, 1.3892254746835444),
+    ("ɜ˞", 1.4093591772151899, 1.4294928797468356),
+    ("t", 1.6106962025316458, 1.630829905063291),
+    ("ɔ", 2.0536376582278484, 2.073771360759494),
+    ("l", 2.1945735759493674, 2.214707278481013),
+    ("ð", 2.27510838607595, 2.295242088607595),
+    ("ʌ", 2.295242088607595, 2.3153757911392407),
+    ("t", 2.416044303797469, 2.436178006329114),
+    ("aɪ", 2.4764454113924055, 2.4965791139240507),
+    ("m", 2.7985846518987345, 2.8187183544303798),
+]
+actor_speech_by_words = [
+    ["you", ['j', 'u']],
+    ["gotta", ['ɡ', 'ɑ', 't', "ʌ"]],
+    ["stay", ['s', 't', 'eɪ']],
+    ["alert", ["æ", 'n', 'l', 'ɜ˞', 't']],
+    ["all", ['ɔ', 'l']],
+    ["the", ["ð", "ʌ"]],
+    ["time", ['t', 'aɪ', 'm']]
+    ]
+
+
+user_speech_transcribed = ["".join(s[0]) for s in user_speech_timestamps]
+actor_speech_transcribed = ["".join(s[0]) for s in actor_speech_timestamps]
+
+mapped_user_speech_timestamps, mapped_actor_speech_by_words = map_target_data(user_speech_timestamps, actor_speech_by_words)
+print(mapped_user_speech_timestamps)
+print(mapped_actor_speech_by_words)
+print(validate_target_data(actor_speech_timestamps, actor_speech_by_words))
\ No newline at end of file
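For context, the /stream endpoint introduced in patch 1 consumes raw little-endian float32 PCM as binary WebSocket frames plus a literal "stop" text message. A minimal sketch of a compatible test client; the ws://localhost:5000/stream URL/port and the websocket-client dependency are assumptions, not part of this series:

```python
import json
import numpy as np
from websocket import create_connection  # pip install websocket-client

SAMPLE_RATE = 16_000

def stream_audio(audio: np.ndarray, url: str = "ws://localhost:5000/stream"):
    ws = create_connection(url)
    try:
        # Send raw float32 PCM in arbitrary frame sizes; the server re-slices
        # its byte buffer into fixed 20 ms chunks on its side.
        step = SAMPLE_RATE // 10  # 100 ms per frame
        for i in range(0, len(audio), step):
            ws.send_binary(audio[i : i + step].astype(np.float32).tobytes())
        ws.send("stop")  # sentinel recognized by stream()
        final = None
        try:
            while True:  # every message is the COMPLETE transcription so far
                final = json.loads(ws.recv())
        except Exception:
            pass  # server closed the socket after its final update
        return final
    finally:
        ws.close()
```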
From a659b8a2d021fd5c5879e8a71802db4409944dec Mon Sep 17 00:00:00 2001
From: arunasrivastava
Date: Sun, 10 Aug 2025 13:52:15 -0700
Subject: [PATCH 2/7] update stream code. TODO: check transformer interval is a good threshold

---
 src/server.py        | 11 +++----
 src/transcription.py | 69 ++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 70 insertions(+), 10 deletions(-)

diff --git a/src/server.py b/src/server.py
index 9ffb7ae..e967ce9 100644
--- a/src/server.py
+++ b/src/server.py
@@ -1,7 +1,6 @@
 import json
-import torch
 import numpy as np
-
+import torch
 from flask import Flask, send_from_directory, request, jsonify
 from flask_cors import CORS, cross_origin
@@ -23,7 +22,8 @@
 
 # Constants
 DEBUG = False
-NUM_SECONDS_PER_CHUNK = 0.5
+CHUNK_SIZE_SAMPLES = 320  # 20ms at 16kHz
+TRANSFORMER_INTERVAL = 30
 
 # Initialize Flask app
 app = Flask(__name__)
@@ -93,9 +93,6 @@ def stream(ws):
     feature_list = []
     total_samples_processed = 0
 
-    CHUNK_SIZE_SAMPLES = 320  # 20ms at 16kHz
-    TRANSFORMER_INTERVAL = 25  # Every 500ms
-
     while True:
         try:
             data = ws.receive()
@@ -114,7 +111,7 @@ def stream(ws):
                 features, samples = extract_features_only(audio_chunk)
                 feature_list.append(features)
                 total_samples_processed += samples
-                # Every 500ms, send COMPLETE transcription from start
+                # accumulate features for 600ms (30 sets of 20ms), then send COMPLETE transcription from start
                 if len(feature_list) % TRANSFORMER_INTERVAL == 0:
                     all_features = torch.cat(feature_list, dim=1)
                     full_transcription = run_transformer_on_features(

diff --git a/src/transcription.py b/src/transcription.py
index e5e920b..775ab40 100644
--- a/src/transcription.py
+++ b/src/transcription.py
@@ -1,15 +1,78 @@
 import torch
 import numpy as np
-from transformers import AutoProcessor, AutoModelForCTC
+from transformers import (
+    AutoProcessor,
+    AutoModelForCTC,
+    Wav2Vec2Processor,
+    Wav2Vec2ForCTC,
+)
 from phoneme_utils import TIMESTAMPED_PHONES_T
 
 SAMPLE_RATE = 16_000
 
 # Load Wav2Vec2 model
 model_id = "KoelLabs/xlsr-english-01"
-processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForCTC.from_pretrained(model_id)
+processor: Wav2Vec2Processor = AutoProcessor.from_pretrained(model_id)
+model: Wav2Vec2ForCTC = AutoModelForCTC.from_pretrained(model_id)
 assert processor.feature_extractor.sampling_rate == SAMPLE_RATE
+MIN_LEN_SAMPLES = (
+    400  # computed from model.config.conv_kernel and model.config.conv_stride
+)
+
+
+def extract_features_only(audio: np.ndarray):
+    """Extract CNN features and project to encoder hidden size (transformer-ready)."""
+    # True raw sample count before any padding
+    raw_sample_count = int(np.asarray(audio).shape[-1])
+    if raw_sample_count < MIN_LEN_SAMPLES:
+        audio = np.pad(audio, (0, MIN_LEN_SAMPLES - raw_sample_count), mode="constant")
+    inputs = processor(
+        audio,
+        sampling_rate=SAMPLE_RATE,
+        return_tensors="pt",
+        padding=False,
+    )
+    input_values = inputs.input_values.type(torch.float32).to(model.device)
+    with torch.no_grad():
+        conv_feats = model.wav2vec2.feature_extractor(input_values)  # (B, C, T')
+        conv_feats_t = conv_feats.transpose(1, 2)  # (B, T', C)
+        # Project to hidden size for transformer; also returns normalized conv features
+        features, normed_conv_feats = model.wav2vec2.feature_projection(conv_feats_t)
+    # Return transformer-ready features and original (unpadded) input length in samples
+    return features, raw_sample_count
+
+
+def run_transformer_on_features(features, total_audio_samples, time_offset=0.0):
+    """Run transformer and decode"""
+    # slowest step
+    with torch.no_grad():
+        encoder_outputs = model.wav2vec2.encoder(features)
+        logits = model.lm_head(encoder_outputs[0])
+
+    predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
+    # Use original audio length in samples to compute duration
+    duration_sec = total_audio_samples / processor.feature_extractor.sampling_rate
+    ids_w_time = [
+        (time_offset + i / len(predicted_ids) * duration_sec, _id)
+        for i, _id in enumerate(predicted_ids)
+    ]
+    current_phoneme_id = processor.tokenizer.pad_token_id
+    current_start_time = 0
+    phonemes_with_time = []
+    for timestamp, _id in ids_w_time:
+        if current_phoneme_id != _id:
+            if current_phoneme_id != processor.tokenizer.pad_token_id:
+                phonemes_with_time.append(
+                    (
+                        processor.decode(current_phoneme_id),
+                        current_start_time,
+                        timestamp,
+                    )
+                )
+
+            current_start_time = timestamp
+            current_phoneme_id = _id
+    return phonemes_with_time
 
 
 def transcribe_timestamped(audio: np.ndarray, time_offset=0.0) -> TIMESTAMPED_PHONES_T:
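The decoding loop added here collapses per-frame argmax ids CTC-style: consecutive repeats of an id merge into one (phone, start, end) span, and the pad (blank) id terminates a span. A toy, self-contained illustration of that collapse; the vocabulary and pad id are invented, and the real code spreads the true accumulated duration uniformly over frames rather than assuming a fixed 20 ms frame:

```python
PAD = 0
VOCAB = {1: "h", 2: "ʌ", 3: "l", 4: "oʊ"}

def collapse(ids, frame_sec=0.02):
    spans = []
    cur, start = PAD, 0.0
    for i, _id in enumerate(ids):
        t = round(i * frame_sec, 2)
        if _id != cur:
            if cur != PAD:  # close the previous non-blank run
                spans.append((VOCAB[cur], start, t))
            cur, start = _id, t
    if cur != PAD:  # flush a phone that runs to the very end
        spans.append((VOCAB[cur], start, round(len(ids) * frame_sec, 2)))
    return spans

print(collapse([0, 1, 1, 0, 2, 2, 2, 0, 0, 3, 4, 4]))
# -> [('h', 0.02, 0.06), ('ʌ', 0.08, 0.14), ('l', 0.18, 0.2), ('oʊ', 0.2, 0.24)]
```

One difference worth noting: unlike this toy, run_transformer_on_features never flushes a trailing run, so a phone still sounding when the ids end appears to be dropped; that is usually harmless because utterances end in blank frames.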
From aa3592b14d47dee8ea6299c6148a6b8c2ff07319 Mon Sep 17 00:00:00 2001
From: arunasrivastava
Date: Sun, 10 Aug 2025 13:54:35 -0700
Subject: [PATCH 3/7] remove old transcription version

---
 src/transcription.py | 43 +------------------------------------------
 1 file changed, 1 insertion(+), 42 deletions(-)

diff --git a/src/transcription.py b/src/transcription.py
index 775ab40..95acc1b 100644
--- a/src/transcription.py
+++ b/src/transcription.py
@@ -72,45 +72,4 @@ def run_transformer_on_features(features, total_audio_samples, time_offset=0.0):
 
             current_start_time = timestamp
             current_phoneme_id = _id
-    return phonemes_with_time
-
-
-def transcribe_timestamped(audio: np.ndarray, time_offset=0.0) -> TIMESTAMPED_PHONES_T:
-    input_values = (
-        processor(
-            audio,
-            sampling_rate=processor.feature_extractor.sampling_rate,
-            return_tensors="pt",
-            padding=True,
-        )
-        .input_values.type(torch.float32)
-        .to(model.device)
-    )
-    with torch.no_grad():
-        logits = model(input_values).logits
-
-    predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
-    duration_sec = input_values.shape[1] / processor.feature_extractor.sampling_rate
-
-    ids_w_time = [
-        (time_offset + i / len(predicted_ids) * duration_sec, _id)
-        for i, _id in enumerate(predicted_ids)
-    ]
-
-    current_phoneme_id = processor.tokenizer.pad_token_id
-    current_start_time = 0
-    phonemes_with_time = []
-    for time, _id in ids_w_time:
-        if current_phoneme_id != _id:
-            if current_phoneme_id != processor.tokenizer.pad_token_id:
-                phonemes_with_time.append(
-                    (
-                        processor.decode(current_phoneme_id),
-                        current_start_time,
-                        time,
-                    )
-                )
-            current_start_time = time
-            current_phoneme_id = _id
-
-    return phonemes_with_time
+    return phonemes_with_time
\ No newline at end of file
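With the old one-shot transcribe_timestamped removed, the same behavior is recoverable by composing the two new functions. A sketch of the equivalent call; it assumes it is run from src/ so the bare `transcription` import resolves (as in server.py) and that the model checkpoint downloads on module load:

```python
import numpy as np
from transcription import extract_features_only, run_transformer_on_features

def transcribe_once(audio: np.ndarray, time_offset: float = 0.0):
    # Feature extraction and transformer decoding, fused back into one shot.
    features, n_samples = extract_features_only(audio)
    return run_transformer_on_features(features, n_samples, time_offset)

if __name__ == "__main__":
    silence = np.zeros(16_000, dtype=np.float32)  # one second at 16 kHz
    print(transcribe_once(silence))  # expected: few or no phonemes
```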
("j", 0.501336302895323, 0.521380846325167), - ("u", 0.5414253897550111, 0.5614699331848553), - ("j", 0.601336302895323, 0.621380846325167), - ("u", 0.6414253897550111, 0.6614699331848553), - ("ɡ", 0.8218262806236081, 0.8418708240534521), - ("ɡ", 0.8218262806236081, 0.8418708240534521), - ("ɑ", 0.8619153674832962, 0.8819599109131403), - ("t", 1.042316258351893, 1.0623608017817372), - ("ə", 1.0824053452115814, 1.1024498886414253), - ("s", 1.22271714922049, 1.2628062360801782), - ("t", 1.3028953229398665, 1.3229398663697105), - ("eɪ", 1.3429844097995547, 1.3630289532293987), - ("ʌ", 1.5835189309576838, 1.6035634743875278), - ("w", 1.7639198218262806, 1.7839643652561248), - ("ɜ˞", 1.7839643652561248, 1.8040089086859687), - ("d", 2.024498886414254, 2.044543429844098), - ("ɔ", 2.1648106904231628, 2.1848552338530065), - ("l", 2.305122494432071, 2.3251670378619154), - ("ð", 2.385300668151448, 2.405345211581292), - ("ə", 2.405345211581292, 2.425389755011136), - # ("t", 2.505567928730512, 2.5256124721603563), - # ("aɪ", 2.605790645879733, 2.625835189309577), - # ("m", 2.906458797327394, 2.9265033407572383), -] -actor_speech_timestamps = [ - ("j", 0.24160443037974685, 0.2617381329113924), - ("u", 0.2617381329113924, 0.281871835443038), - ("ɡ", 0.34227294303797473, 0.36240664556962027), - ("ɑ", 0.38254034810126586, 0.40267405063291145), - ("t", 0.4429414556962025, 0.4630751582278481), - ("ʌ", 0.4832088607594937, 0.5033425632911394), - ("s", 0.6845458860759495, 0.7046795886075949), - ("t", 0.7650806962025317, 0.7852143987341773), - ("eɪ", 0.8053481012658229, 0.8254818037974684), - ("æ", 1.0872199367088609, 1.1073536392405066), - ("n", 1.22815585443038, 1.2482895569620254), - ("l", 1.369091772151899, 1.3892254746835444), - ("ɜ˞", 1.4093591772151899, 1.4294928797468356), - ("t", 1.6106962025316458, 1.630829905063291), - ("ɔ", 2.0536376582278484, 2.073771360759494), - ("l", 2.1945735759493674, 2.214707278481013), - ("ð", 2.27510838607595, 2.295242088607595), - ("ʌ", 2.295242088607595, 2.3153757911392407), - ("t", 2.416044303797469, 2.436178006329114), - ("aɪ", 2.4764454113924055, 2.4965791139240507), - ("m", 2.7985846518987345, 2.8187183544303798), -] -actor_speech_by_words = [ - ["you", ['j', 'u']], - ["gotta", ['ɡ', 'ɑ', 't', "ʌ"]], - ["stay", ['s', 't', 'eɪ']], - ["alert", ["æ", 'n','l', 'ɜ˞', 't']], - ["all", ['ɔ', 'l']], - ["the", ["ð", "ʌ"]], - ["time", ['t', 'aɪ', 'm']] - ] - - -user_speech_transcribed = ["".join(s[0]) for s in user_speech_timestamps] -actor_speech_transcribed = ["".join(s[0]) for s in actor_speech_timestamps] - -mapped_user_speech_timestamps, mapped_actor_speech_by_words = map_target_data(user_speech_timestamps, actor_speech_by_words) -print(mapped_user_speech_timestamps) -print(mapped_actor_speech_by_words) -print(validate_target_data(actor_speech_timestamps, actor_speech_by_words)) \ No newline at end of file From 801f74bff018d27ff3f32f7dead4d79a35aa45a0 Mon Sep 17 00:00:00 2001 From: arunasrivastava Date: Sun, 10 Aug 2025 20:22:39 -0700 Subject: [PATCH 5/7] remove hardcoded values. 
From 801f74bff018d27ff3f32f7dead4d79a35aa45a0 Mon Sep 17 00:00:00 2001
From: arunasrivastava
Date: Sun, 10 Aug 2025 20:22:39 -0700
Subject: [PATCH 5/7] remove hardcoded values, add TIMESTAMPED_PHONES_T type annotations

---
 src/server.py        | 12 +++++-------
 src/transcription.py | 30 +++++++++++++++++++++++-------
 2 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/src/server.py b/src/server.py
index e967ce9..22b40e5 100644
--- a/src/server.py
+++ b/src/server.py
@@ -16,13 +16,13 @@
     extract_features_only,
     run_transformer_on_features,
     SAMPLE_RATE,
-    transcribe_timestamped,
+    RECEPTIVE_FIELD_SIZE,
+    STRIDE_SIZE,
 )
 from phoneme_utils import TIMESTAMPED_PHONES_T, TIMESTAMPED_PHONES_BY_WORD_T
 
 # Constants
 DEBUG = False
-CHUNK_SIZE_SAMPLES = 320  # 20ms at 16kHz
 TRANSFORMER_INTERVAL = 30
 
 # Initialize Flask app
@@ -100,11 +100,9 @@ def stream(ws):
                 buffer += data
 
             # Process 20ms chunks
-            while len(buffer) >= CHUNK_SIZE_SAMPLES * np.dtype(np.float32).itemsize:
-                chunk_bytes = buffer[
-                    : CHUNK_SIZE_SAMPLES * np.dtype(np.float32).itemsize
-                ]
-                buffer = buffer[CHUNK_SIZE_SAMPLES * np.dtype(np.float32).itemsize :]
+            while len(buffer) >= STRIDE_SIZE * np.dtype(np.float32).itemsize:
+                chunk_bytes = buffer[: STRIDE_SIZE * np.dtype(np.float32).itemsize]
+                buffer = buffer[STRIDE_SIZE * np.dtype(np.float32).itemsize :]
 
                 audio_chunk = np.frombuffer(chunk_bytes, dtype=np.float32)
 
diff --git a/src/transcription.py b/src/transcription.py
index 95acc1b..b5e4de4 100644
--- a/src/transcription.py
+++ b/src/transcription.py
@@ -15,17 +15,31 @@
 processor: Wav2Vec2Processor = AutoProcessor.from_pretrained(model_id)
 model: Wav2Vec2ForCTC = AutoModelForCTC.from_pretrained(model_id)
 assert processor.feature_extractor.sampling_rate == SAMPLE_RATE
-MIN_LEN_SAMPLES = (
-    400  # computed from model.config.conv_kernel and model.config.conv_stride
-)
+
+
+def _calculate_cnn_window(model: Wav2Vec2ForCTC):
+    receptive_field = 1
+    stride = 1
+    for conv_layer in model.wav2vec2.feature_extractor.conv_layers:
+        assert hasattr(conv_layer, "conv")
+        conv = conv_layer.conv
+        assert isinstance(conv, torch.nn.Conv1d)
+        receptive_field += (conv.kernel_size[0] - 1) * stride
+        stride *= conv.stride[0]
+    return receptive_field, stride
+
+
+RECEPTIVE_FIELD_SIZE, STRIDE_SIZE = _calculate_cnn_window(model)
 
 
 def extract_features_only(audio: np.ndarray):
     """Extract CNN features and project to encoder hidden size (transformer-ready)."""
     # True raw sample count before any padding
     raw_sample_count = int(np.asarray(audio).shape[-1])
-    if raw_sample_count < MIN_LEN_SAMPLES:
-        audio = np.pad(audio, (0, MIN_LEN_SAMPLES - raw_sample_count), mode="constant")
+    if raw_sample_count < RECEPTIVE_FIELD_SIZE:
+        audio = np.pad(
+            audio, (0, RECEPTIVE_FIELD_SIZE - raw_sample_count), mode="constant"
+        )
     inputs = processor(
         audio,
         sampling_rate=SAMPLE_RATE,
@@ -42,7 +56,9 @@ def extract_features_only(audio: np.ndarray):
     return features, raw_sample_count
 
 
-def run_transformer_on_features(features, total_audio_samples, time_offset=0.0):
+def run_transformer_on_features(
+    features: torch.Tensor, total_audio_samples: int, time_offset: float = 0.0
+) -> TIMESTAMPED_PHONES_T:
     """Run transformer and decode"""
     # slowest step
     with torch.no_grad():
@@ -72,4 +88,4 @@ def run_transformer_on_features(
 
             current_start_time = timestamp
             current_phoneme_id = _id
-    return phonemes_with_time
\ No newline at end of file
+    return phonemes_with_time
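_calculate_cnn_window walks the conv stack, growing the receptive field by (kernel - 1) input samples per layer at the current cumulative stride. A standalone check of that arithmetic for the usual wav2vec2/XLSR feature extractor (the kernel and stride lists below are the standard config values, stated here as an assumption about this checkpoint):

```python
KERNELS = [10, 3, 3, 3, 3, 2, 2]
STRIDES = [5, 2, 2, 2, 2, 2, 2]

receptive_field, stride = 1, 1
for k, s in zip(KERNELS, STRIDES):
    receptive_field += (k - 1) * stride  # each layer widens the window by (k-1) hops
    stride *= s                          # and multiplies the hop size

print(receptive_field, stride)  # 400 320  => 25 ms window, 20 ms hop at 16 kHz
```

Those are exactly the values patches 1 and 2 hardcoded as MIN_LEN_SAMPLES = 400 and CHUNK_SIZE_SAMPLES = 320; deriving them from the loaded model keeps the server in sync if the checkpoint's conv config ever changes.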
From a39c1259383e7089544b9fe2a27adad9ee98a820 Mon Sep 17 00:00:00 2001
From: arunasrivastava
Date: Sun, 10 Aug 2025 20:38:17 -0700
Subject: [PATCH 6/7] add padding arg to processor

---
 src/transcription.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/transcription.py b/src/transcription.py
index b5e4de4..ea7190e 100644
--- a/src/transcription.py
+++ b/src/transcription.py
@@ -36,15 +36,13 @@ def extract_features_only(audio: np.ndarray):
     """Extract CNN features and project to encoder hidden size (transformer-ready)."""
     # True raw sample count before any padding
     raw_sample_count = int(np.asarray(audio).shape[-1])
-    if raw_sample_count < RECEPTIVE_FIELD_SIZE:
-        audio = np.pad(
-            audio, (0, RECEPTIVE_FIELD_SIZE - raw_sample_count), mode="constant"
-        )
+
     inputs = processor(
         audio,
         sampling_rate=SAMPLE_RATE,
         return_tensors="pt",
-        padding=False,
+        padding="max_length",
+        max_length=RECEPTIVE_FIELD_SIZE,
     )
     input_values = inputs.input_values.type(torch.float32).to(model.device)
     with torch.no_grad():

From 09a1a69ba97a1470dbf089819e3bd7c99158160e Mon Sep 17 00:00:00 2001
From: arunasrivastava
Date: Tue, 12 Aug 2025 11:32:22 -0700
Subject: [PATCH 7/7] remove unused imports

---
 src/server.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/server.py b/src/server.py
index 22b40e5..ef7aa13 100644
--- a/src/server.py
+++ b/src/server.py
@@ -15,8 +15,6 @@
 from transcription import (
     extract_features_only,
     run_transformer_on_features,
-    SAMPLE_RATE,
-    RECEPTIVE_FIELD_SIZE,
-    STRIDE_SIZE,
+    STRIDE_SIZE,
 )
 from phoneme_utils import TIMESTAMPED_PHONES_T, TIMESTAMPED_PHONES_BY_WORD_T
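With padding="max_length" and max_length=RECEPTIVE_FIELD_SIZE, every 20 ms (STRIDE_SIZE-sample) chunk is zero-padded to one full receptive field and therefore yields exactly one feature frame. A sketch of the frame-count arithmetic, assuming the 400/320 values derived above; it also shows the small boundary effect of chunked versus one-shot extraction:

```python
def num_conv_frames(n_samples: int, receptive_field: int = 400, stride: int = 320) -> int:
    if n_samples < receptive_field:
        n_samples = receptive_field  # processor pads short inputs up to max_length
    return (n_samples - receptive_field) // stride + 1

print(num_conv_frames(320))    # 1: each padded 20 ms chunk -> one frame
print(num_conv_frames(16000))  # 49: one second fed in a single piece
print(sum(num_conv_frames(320) for _ in range(50)))  # 50: the same second, chunked
```

Chunked extraction produces one extra frame per second, and each padded chunk sees zeros where neighboring audio would be; this is the trade-off of caching conv features per chunk, and it is also why run_transformer_on_features derives timestamps from the true accumulated sample count rather than from the feature length.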