From 18d45a5feabee4a6452f1e6cf7b9e99e257c8803 Mon Sep 17 00:00:00 2001 From: Alex-Wengg Date: Mon, 6 Apr 2026 15:54:15 -0400 Subject: [PATCH 1/8] feat(asr): Add Cohere Transcribe support with 14 languages Add Cohere Transcribe CoreML ASR implementation supporting 14 languages: - English, French, German, Spanish, Italian, Portuguese, Dutch, Polish - Greek, Arabic, Japanese, Chinese, Korean, Vietnamese Features: - Core ASR manager with stateful decoder - Mel spectrogram preprocessing compatible with Cohere models - CLI transcription command with language selection - Benchmark command supporting LibriSpeech and FLEURS datasets - INT8 quantized models for efficient inference Usage: swift run fluidaudiocli cohere-transcribe audio.wav --language ja_jp swift run fluidaudiocli cohere-benchmark --dataset fleurs --languages en_us,fr_fr swift run fluidaudiocli download --dataset fleurs Models: FluidInference/cohere-transcribe-03-2026-coreml --- .../ASR/Cohere/CohereAsrConfig.swift | 101 ++++ .../ASR/Cohere/CohereAsrManager.swift | 299 +++++++++++ .../ASR/Cohere/CohereAsrModels.swift | 244 +++++++++ .../ASR/Cohere/CohereMelSpectrogram.swift | 284 +++++++++++ .../Commands/ASR/Cohere/CohereBenchmark.swift | 463 ++++++++++++++++++ .../ASR/Cohere/CohereTranscribeCommand.swift | 161 ++++++ Sources/FluidAudioCLI/FluidAudioCLI.swift | 10 + 7 files changed, 1562 insertions(+) create mode 100644 Sources/FluidAudio/ASR/Cohere/CohereAsrConfig.swift create mode 100644 Sources/FluidAudio/ASR/Cohere/CohereAsrManager.swift create mode 100644 Sources/FluidAudio/ASR/Cohere/CohereAsrModels.swift create mode 100644 Sources/FluidAudio/ASR/Cohere/CohereMelSpectrogram.swift create mode 100644 Sources/FluidAudioCLI/Commands/ASR/Cohere/CohereBenchmark.swift create mode 100644 Sources/FluidAudioCLI/Commands/ASR/Cohere/CohereTranscribeCommand.swift diff --git a/Sources/FluidAudio/ASR/Cohere/CohereAsrConfig.swift b/Sources/FluidAudio/ASR/Cohere/CohereAsrConfig.swift new file mode 100644 index 
import Foundation

/// Configuration for the Cohere Transcribe CoreML ASR model.
///
/// A pure namespace of compile-time constants; nothing here is instantiable.
public enum CohereAsrConfig {

    // MARK: - Audio input

    /// Sample rate the model expects (16 kHz).
    public static let sampleRate: Int = 16000

    /// Longest audio clip the model accepts, in seconds.
    public static let maxAudioSeconds: Float = 30.0

    /// Longest audio clip in samples (30 s x 16 kHz = 480,000).
    public static let maxSamples: Int = 480_000

    // MARK: - Model dimensions

    /// Tokenizer vocabulary size.
    public static let vocabSize: Int = 16_384

    /// Hidden size of the Conformer encoder blocks.
    public static let encoderHiddenSize: Int = 1280

    /// Hidden size of the decoder.
    public static let decoderHiddenSize: Int = 1024

    /// Number of encoder layers.
    public static let numEncoderLayers: Int = 48

    /// Number of decoder layers.
    public static let numDecoderLayers: Int = 8

    /// Attention heads per decoder layer.
    public static let numDecoderHeads: Int = 8

    /// Per-head dimension (decoderHiddenSize / numDecoderHeads = 128).
    public static let headDim: Int = 128

    /// Maximum sequence length for the decoder KV cache.
    public static let maxSeqLen: Int = 108

    /// Number of mel filterbank bins.
    public static let numMelBins: Int = 128

    // MARK: - Feature extraction

    /// Mel spectrogram parameters.
    public enum MelSpec {
        public static let nFFT: Int = 1024
        public static let hopLength: Int = 160
        public static let nMels: Int = 128
        public static let fMin: Float = 0.0
        public static let fMax: Float = 8000.0
        public static let preemphasis: Float = 0.97
    }

    // MARK: - Tokens

    /// IDs of the model's special tokens.
    public enum SpecialTokens {
        /// Unknown token.
        public static let unkToken: Int = 0
        /// No-speech token.
        public static let noSpeechToken: Int = 1
        /// Padding token.
        public static let padToken: Int = 2
        /// End-of-text / end-of-sequence token.
        public static let eosToken: Int = 3
        /// Start-of-transcript token.
        public static let startToken: Int = 4
    }

    // MARK: - Languages

    /// Languages supported by the model; raw values are ISO 639-1 codes.
    /// Case declaration order is significant: it defines `allCases` order.
    public enum Language: String, CaseIterable {
        case english = "en"
        case french = "fr"
        case german = "de"
        case spanish = "es"
        case italian = "it"
        case portuguese = "pt"
        case dutch = "nl"
        case polish = "pl"
        case greek = "el"
        case arabic = "ar"
        case japanese = "ja"
        case chinese = "zh"
        case vietnamese = "vi"
        case korean = "ko"

        /// Human-readable English name of the language.
        public var englishName: String {
            // Alphabetical for easy scanning; exhaustive, so no default.
            switch self {
            case .arabic: return "Arabic"
            case .chinese: return "Chinese"
            case .dutch: return "Dutch"
            case .english: return "English"
            case .french: return "French"
            case .german: return "German"
            case .greek: return "Greek"
            case .italian: return "Italian"
            case .japanese: return "Japanese"
            case .korean: return "Korean"
            case .polish: return "Polish"
            case .portuguese: return "Portuguese"
            case .spanish: return "Spanish"
            case .vietnamese: return "Vietnamese"
            }
        }
    }
}
+ public func loadModels(from directory: URL, computeUnits: MLComputeUnits = .all) async throws { + models = try await CohereAsrModels.load(from: directory, computeUnits: computeUnits) + logger.info("Cohere Transcribe models loaded successfully") + } + + /// Transcribe raw audio samples. + /// + /// - Parameters: + /// - audioSamples: 16kHz mono Float32 audio samples. + /// - maxNewTokens: Maximum number of tokens to generate. + /// - Returns: Transcribed text. + public func transcribe( + audioSamples: [Float], + maxNewTokens: Int = 200 + ) async throws -> String { + guard let models = models else { + throw CohereAsrError.generationFailed("Models not loaded") + } + + let start = CFAbsoluteTimeGetCurrent() + + // Step 1: Extract mel spectrogram + let mel = melExtractor.compute(audio: audioSamples) + guard !mel.isEmpty else { + throw CohereAsrError.invalidInput("Audio too short to extract mel spectrogram") + } + + let nFrames = mel[0].count + + // Pad to 3001 frames (max length) + let paddedMel = padMelSpectrogram(mel, targetFrames: 3001) + + // Step 2: Encode audio + let encodeStart = CFAbsoluteTimeGetCurrent() + let encoderHidden = try await encodeAudio(paddedMel: paddedMel, featureLength: nFrames, models: models) + let encodeTime = CFAbsoluteTimeGetCurrent() - encodeStart + logger.debug("Encoder: \(String(format: "%.3f", encodeTime))s") + + // Step 3: Decode with KV cache + let decodeStart = CFAbsoluteTimeGetCurrent() + let tokens = try await decode( + encoderHidden: encoderHidden, + maxNewTokens: maxNewTokens, + models: models + ) + let decodeTime = CFAbsoluteTimeGetCurrent() - decodeStart + logger.debug("Decoder: \(String(format: "%.3f", decodeTime))s (\(tokens.count) tokens)") + + let totalTime = CFAbsoluteTimeGetCurrent() - start + logger.info( + "Transcribed \(String(format: "%.2f", Float(audioSamples.count) / 16000.0))s audio in \(String(format: "%.3f", totalTime))s" + ) + + // Step 4: Detokenize + let text = convertTokensToText(tokens, vocabulary: 
models.vocabulary) + + return text + } + + // MARK: - Private Helpers + + /// Pad mel spectrogram to target number of frames. + private func padMelSpectrogram(_ mel: [[Float]], targetFrames: Int) -> [[Float]] { + let nMels = mel.count + let nFrames = mel[0].count + + guard nFrames < targetFrames else { + return mel + } + + var padded = [[Float]](repeating: [Float](repeating: 0, count: targetFrames), count: nMels) + for m in 0..