## Summary

Added Flash Attention and batched segment decoding to mlx-whisper, achieving a 9.5x speedup on Apple Silicon.
## Changes

### 1. Flash Attention (`whisper.py`)

- Replaced manual QKV attention with `mx.fast.scaled_dot_product_attention`
- Conditional path: uses Flash Attention by default, falls back to standard attention when `word_timestamps=True` (DTW alignment needs the QK weights)
- Proper mask handling for autoregressive decoding with the KV cache
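The conditional path can be illustrated with a small NumPy stand-in for the fused kernel (function names here are illustrative, not the actual code in the branch). The point is that the fused `mx.fast.scaled_dot_product_attention` never materializes the QK matrix, so a fallback that does is required whenever DTW alignment needs it:

```python
import numpy as np

def standard_attention(q, k, v, mask=None):
    """Reference attention that also returns the QK logits (needed for DTW alignment)."""
    scale = q.shape[-1] ** -0.5
    qk = (q * scale) @ k.swapaxes(-1, -2)  # (B, heads, L_q, L_k)
    if mask is not None:
        qk = qk + mask
    w = np.exp(qk - qk.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v, qk

def attention(q, k, v, mask=None, word_timestamps=False):
    """Conditional path: fused kernel by default, fallback when QK weights are needed."""
    if word_timestamps:
        # Fallback: materialize the QK matrix so DTW alignment can use it.
        return standard_attention(q, k, v, mask)
    # Default path: in mlx this would be
    # mx.fast.scaled_dot_product_attention(q, k, v, scale=..., mask=mask),
    # which computes the same output without ever materializing QK.
    out, _ = standard_attention(q, k, v, mask)
    return out, None

# Both paths produce the same output; only the fallback exposes QK.
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1, 2, 4, 8)) for _ in range(3))
fast_out, no_qk = attention(q, k, v)
slow_out, qk = attention(q, k, v, word_timestamps=True)
assert np.allclose(fast_out, slow_out) and no_qk is None and qk is not None
```

Not materializing the `(L_q, L_k)` matrix is where the fused kernel's memory and speed win comes from, which is also exactly why it cannot hand the weights back for alignment.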
### 2. Batched decoding (`transcribe.py`)

- New `batch_size` parameter in `transcribe()` (default=1, fully backward-compatible)
- Pre-slices audio into fixed 30s chunks, stacks them into a batch tensor of shape `(N, 3000, n_mels)`, and decodes all chunks simultaneously
- Per-segment temperature fallback for quality control
- `batch_size=1` produces identical output to the current code
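The fixed-stride slicing can be sketched in NumPy, assuming Whisper's standard 16 kHz sample rate and 30s window (the function name is illustrative, not the branch's actual API):

```python
import numpy as np

SAMPLE_RATE = 16_000
CHUNK_SAMPLES = 30 * SAMPLE_RATE  # fixed 30s stride, no dynamic seeking

def chunk_audio(audio: np.ndarray) -> np.ndarray:
    """Slice audio into fixed 30s chunks, zero-padding the last one,
    and stack into a (N, CHUNK_SAMPLES) batch."""
    n_chunks = max(1, -(-len(audio) // CHUNK_SAMPLES))  # ceiling division
    padded = np.zeros(n_chunks * CHUNK_SAMPLES, dtype=audio.dtype)
    padded[: len(audio)] = audio
    return padded.reshape(n_chunks, CHUNK_SAMPLES)

# 95s of audio -> 4 chunks of 30s each (the last padded with 25s of silence).
audio = np.ones(95 * SAMPLE_RATE, dtype=np.float32)
batch = chunk_audio(audio)
assert batch.shape == (4, CHUNK_SAMPLES)
```

After the mel filterbank (hop length 160, i.e. 100 frames per second), each 30s chunk becomes 3000 frames, giving the `(N, 3000, n_mels)` tensor that is decoded in a single batched pass.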
Zero new dependencies. No breaking changes.
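The per-segment temperature fallback follows the same quality gates as upstream Whisper (retry at a higher temperature when the compression ratio or average log-probability looks bad); the batched twist is that only the failing segments are re-decoded. A minimal sketch, with `decode_batch` as a stand-in for the real model call and thresholds mirroring Whisper's defaults:

```python
# Sketch of per-segment temperature fallback in a batched decoder.
# Thresholds mirror upstream Whisper's defaults.
TEMPERATURES = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)

def needs_retry(result: dict) -> bool:
    return result["compression_ratio"] > 2.4 or result["avg_logprob"] < -1.0

def decode_with_fallback(segments, decode_batch):
    # First pass: decode the whole batch greedily at temperature 0.
    results = decode_batch(segments, temperature=TEMPERATURES[0])
    # Fallback passes: only the segments that failed quality checks are
    # re-decoded, each pass at a progressively higher temperature.
    for t in TEMPERATURES[1:]:
        bad = [i for i, r in enumerate(results) if needs_retry(r)]
        if not bad:
            break
        retried = decode_batch([segments[i] for i in bad], temperature=t)
        for i, r in zip(bad, retried):
            results[i] = r
    return results

# Toy decoder: "hard" fails the quality checks at temperature 0, passes above it.
def toy_decode(segs, temperature):
    return [
        {"text": s,
         "compression_ratio": 3.0 if (s == "hard" and temperature == 0.0) else 1.2,
         "avg_logprob": -0.4}
        for s in segs
    ]

results = decode_with_fallback(["easy", "hard"], toy_decode)
```

Retrying only the failing segments keeps the common case (everything passes at temperature 0) as a single batched decode.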
## Benchmarks (M2 8GB, whisper-small, 5 min Russian audio)

| Mode | Time | Realtime Factor | Speedup |
|---|---|---|---|
| Sequential (`batch_size=1`) | 9.4s | 4.8x RT | 1x |
| Batched (`batch_size=12`) | 6.6s | 44.8x RT | 9.5x |
For a 15-hour video: ~20 minutes instead of ~3 hours.
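The 15-hour projection follows directly from the realtime factors above:

```python
# Projecting the 15-hour example from the measured realtime factors.
audio_minutes = 15 * 60  # 900 minutes of audio

sequential_minutes = audio_minutes / 4.8   # ~187.5 min, i.e. ~3 hours
batched_minutes = audio_minutes / 44.8     # ~20 min
```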
## Code

Full implementation with benchmarks: https://github.com/ilyasmukiev/mlx-whisper-pr

- Branch `flash-attention-batch`: minimal changes (Flash Attention + `batch_size` parameter only)
- Branch `full-batching-vad-diarize`: adds optional VAD (Silero) and speaker diarization

Standalone package: https://github.com/ilyasmukiev/mlx-whisper-fast
## Notes

- Could not create a PR directly because `gh repo fork` returns HTTP 502 (repo too large?)
- Happy to submit a proper PR once the fork works
- The batched path uses fixed-stride chunking (no dynamic seeking), a deliberate trade-off for parallelism; WhisperX and lightning-whisper-mlx take the same approach
- Related: discussion #1275 ("Unexpected Processing Times for Short vs. Long Audio Files with MLX-Whisper"), where batching was acknowledged as possible but not implemented