diff --git a/src/server.py b/src/server.py
index 9a45cf0..ef7aa13 100644
--- a/src/server.py
+++ b/src/server.py
@@ -1,6 +1,6 @@
 import json
-
 import numpy as np
+import torch
 from flask import Flask, send_from_directory, request, jsonify
 from flask_cors import CORS, cross_origin
 
@@ -12,12 +12,16 @@
     top_phonetic_errors,
     pair_by_words,
 )
-from transcription import transcribe_timestamped, SAMPLE_RATE
+from transcription import (
+    extract_features_only,
+    run_transformer_on_features,
+    STRIDE_SIZE,
+)
 from phoneme_utils import TIMESTAMPED_PHONES_T, TIMESTAMPED_PHONES_BY_WORD_T
 
 # Constants
 DEBUG = False
-NUM_SECONDS_PER_CHUNK = 0.5
+TRANSFORMER_INTERVAL = 30
 
 # Initialize Flask app
 app = Flask(__name__)
@@ -83,44 +87,46 @@ def get_score_words_cer():
 
 @sock.route("/stream")
 def stream(ws):
-    buffer = b""  # Buffer to hold audio chunks
+    buffer = b""
+    feature_list = []
+    total_samples_processed = 0
 
-    full_transcription: TIMESTAMPED_PHONES_T = []
-    accumulated_duration = 0
-    combined = np.array([], dtype=np.float32)
     while True:
         try:
-            # Receive audio data from the client
            data = ws.receive()
             if data and data != "stop":
                 buffer += data
 
-            # Process when buffer has at least one chunk in it or when we are done
-            if (
-                data == "stop"
-                or len(buffer)
-                >= SAMPLE_RATE * NUM_SECONDS_PER_CHUNK * np.dtype(np.float32).itemsize
-            ):
-                audio = np.frombuffer(buffer, dtype=np.float32)
-                transcription = transcribe_timestamped(audio, accumulated_duration)
-                accumulated_duration += len(audio) / SAMPLE_RATE
-                full_transcription.extend(transcription)
-                ws.send(json.dumps(full_transcription))
-
-                if DEBUG:
-                    from scipy.io import wavfile
-
-                    wavfile.write("src/audio.wav", SAMPLE_RATE, audio)
-                    combined = np.concatenate([combined, audio])
-                    wavfile.write("src/combined.wav", SAMPLE_RATE, combined)
-
-                if data == "stop":
-                    break
-
-                buffer = b""  # Clear the buffer
+            # Process the buffer in 20 ms chunks (STRIDE_SIZE float32 samples each)
+            while len(buffer) >= STRIDE_SIZE * np.dtype(np.float32).itemsize:
+                chunk_bytes = buffer[: STRIDE_SIZE * np.dtype(np.float32).itemsize]
+                buffer = buffer[STRIDE_SIZE * np.dtype(np.float32).itemsize :]
+
+                audio_chunk = np.frombuffer(chunk_bytes, dtype=np.float32)
+
+                features, samples = extract_features_only(audio_chunk)
+                feature_list.append(features)
+                total_samples_processed += samples
+                # Accumulate TRANSFORMER_INTERVAL chunks (30 x 20 ms = 600 ms), then send the COMPLETE transcription from the start
+                if len(feature_list) % TRANSFORMER_INTERVAL == 0:
+                    all_features = torch.cat(feature_list, dim=1)
+                    full_transcription = run_transformer_on_features(
+                        all_features, total_samples_processed
+                    )
+                    ws.send(json.dumps(full_transcription))
+
+            if data == "stop":
+                # Final update with any remaining features
+                if feature_list:
+                    all_features = torch.cat(feature_list, dim=1)
+                    full_transcription = run_transformer_on_features(
+                        all_features, total_samples_processed
+                    )
+                    ws.send(json.dumps(full_transcription))
+                break
+
         except Exception as e:
             print(f"Error: {e}")
-            print(f"Line: {e.__traceback__.tb_lineno if e.__traceback__ else -1}")
             break
 
 
diff --git a/src/transcription.py b/src/transcription.py
index e5e920b..ea7190e 100644
--- a/src/transcription.py
+++ b/src/transcription.py
@@ -1,53 +1,89 @@
 import torch
 import numpy as np
-from transformers import AutoProcessor, AutoModelForCTC
+from transformers import (
+    AutoProcessor,
+    AutoModelForCTC,
+    Wav2Vec2Processor,
+    Wav2Vec2ForCTC,
+)
 from phoneme_utils import TIMESTAMPED_PHONES_T
 
 SAMPLE_RATE = 16_000
 
 # Load Wav2Vec2 model
 model_id = "KoelLabs/xlsr-english-01"
-processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForCTC.from_pretrained(model_id)
+processor: Wav2Vec2Processor = AutoProcessor.from_pretrained(model_id)
+model: Wav2Vec2ForCTC = AutoModelForCTC.from_pretrained(model_id)
 assert processor.feature_extractor.sampling_rate == SAMPLE_RATE
 
 
-def transcribe_timestamped(audio: np.ndarray, time_offset=0.0) -> TIMESTAMPED_PHONES_T:
-    input_values = (
-        processor(
-            audio,
-            sampling_rate=processor.feature_extractor.sampling_rate,
-            return_tensors="pt",
-            padding=True,
-        )
-        .input_values.type(torch.float32)
-        .to(model.device)
+def _calculate_cnn_window(model: Wav2Vec2ForCTC):
+    receptive_field = 1
+    stride = 1
+    for conv_layer in model.wav2vec2.feature_extractor.conv_layers:
+        assert hasattr(conv_layer, "conv")
+        conv = conv_layer.conv
+        assert isinstance(conv, torch.nn.Conv1d)
+        receptive_field += (conv.kernel_size[0] - 1) * stride
+        stride *= conv.stride[0]
+    return receptive_field, stride
+
+
+RECEPTIVE_FIELD_SIZE, STRIDE_SIZE = _calculate_cnn_window(model)
+
+
+def extract_features_only(audio: np.ndarray):
+    """Extract CNN features and project to encoder hidden size (transformer-ready)."""
+    # True raw sample count before any padding
+    raw_sample_count = int(np.asarray(audio).shape[-1])
+
+    inputs = processor(
+        audio,
+        sampling_rate=SAMPLE_RATE,
+        return_tensors="pt",
+        padding="max_length",
+        max_length=RECEPTIVE_FIELD_SIZE,
     )
+    input_values = inputs.input_values.type(torch.float32).to(model.device)
 
     with torch.no_grad():
-        logits = model(input_values).logits
+        conv_feats = model.wav2vec2.feature_extractor(input_values)  # (B, C, T')
+        conv_feats_t = conv_feats.transpose(1, 2)  # (B, T', C)
+        # Project to hidden size for transformer; also returns normalized conv features
+        features, normed_conv_feats = model.wav2vec2.feature_projection(conv_feats_t)
+    # Return transformer-ready features and original (unpadded) input length in samples
+    return features, raw_sample_count
 
-    predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
-    duration_sec = input_values.shape[1] / processor.feature_extractor.sampling_rate
 
+def run_transformer_on_features(
+    features: torch.Tensor, total_audio_samples: int, time_offset: float = 0.0
+) -> TIMESTAMPED_PHONES_T:
+    """Run transformer and decode"""
+    # slowest step
+    with torch.no_grad():
+        encoder_outputs = model.wav2vec2.encoder(features)
+        logits = model.lm_head(encoder_outputs[0])
+
+    predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
+    # Use original audio length in samples to compute duration
+    duration_sec = total_audio_samples / processor.feature_extractor.sampling_rate
     ids_w_time = [
         (time_offset + i / len(predicted_ids) * duration_sec, _id)
        for i, _id in enumerate(predicted_ids)
     ]
-
     current_phoneme_id = processor.tokenizer.pad_token_id
     current_start_time = 0
     phonemes_with_time = []
-    for time, _id in ids_w_time:
+    for timestamp, _id in ids_w_time:
         if current_phoneme_id != _id:
             if current_phoneme_id != processor.tokenizer.pad_token_id:
                 phonemes_with_time.append(
                     (
                         processor.decode(current_phoneme_id),
                         current_start_time,
-                        time,
+                        timestamp,
                     )
                 )
-            current_start_time = time
-            current_phoneme_id = _id
+            current_start_time = timestamp
+            current_phoneme_id = _id
 
     return phonemes_with_time
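
Usage sketch (not part of the diff): a minimal, hedged example of driving the new two-stage API from src/transcription.py directly, without the Flask/WebSocket layer, assuming the module is importable as `transcription`. The silent test signal, the loop, and the local TRANSFORMER_INTERVAL constant are illustrative stand-ins for the client stream and the server-side constant; the imported names come from the diff above.

# Sketch only: exercises extract_features_only / run_transformer_on_features,
# mirroring the 20 ms chunking done in src/server.py.
import numpy as np
import torch

from transcription import (
    SAMPLE_RATE,
    STRIDE_SIZE,
    extract_features_only,
    run_transformer_on_features,
)

TRANSFORMER_INTERVAL = 30  # same value as src/server.py: 30 x 20 ms = 600 ms per decode

# One second of silence stands in for audio streamed by a client (float32, 16 kHz).
audio = np.zeros(SAMPLE_RATE, dtype=np.float32)

feature_list = []
total_samples = 0
for start in range(0, len(audio) - STRIDE_SIZE + 1, STRIDE_SIZE):
    chunk = audio[start : start + STRIDE_SIZE]
    features, samples = extract_features_only(chunk)  # one encoder frame per 20 ms chunk
    feature_list.append(features)
    total_samples += samples

    # Periodically re-decode everything accumulated so far, as the server does.
    if len(feature_list) % TRANSFORMER_INTERVAL == 0:
        transcription = run_transformer_on_features(
            torch.cat(feature_list, dim=1), total_samples
        )
        print(transcription)

# Final decode over all accumulated features.
if feature_list:
    print(run_transformer_on_features(torch.cat(feature_list, dim=1), total_samples))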