diff --git a/extension/audio/mel_spectrogram.py b/extension/audio/mel_spectrogram.py index 50b9ded01af..4d7180854f1 100644 --- a/extension/audio/mel_spectrogram.py +++ b/extension/audio/mel_spectrogram.py @@ -192,10 +192,15 @@ def export_processor(model=None, output_file="whisper_preprocess.pte"): if model is None: model = WhisperAudioProcessor() - audio_tensor = torch.randn(93680) + if model.streaming: + # Streaming processes small windows per step. 2 seconds gives + # comfortable headroom while keeping the memory plan tight. + max_samples = 2 * model.sampling_rate + else: + max_samples = int(model.max_audio_len * model.sampling_rate) + audio_tensor = torch.randn(min(93680, max_samples)) shapes_collection = torch.export.ShapesCollection() - max_n_chunks = int(model.max_audio_len * model.n_samples) - shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_n_chunks)} + shapes_collection[audio_tensor] = {0: Dim.DYNAMIC(max=max_samples)} with torch.no_grad(), torch.fx.experimental._config.patch( backed_size_oblivious=True ):