import Foundation

/// Static configuration for the Cohere Transcribe CoreML ASR model.
///
/// All values mirror the exported CoreML model's fixed dimensions. Derived
/// constants (`maxSamples`, `headDim`, `numMelBins`) are computed from their
/// source values so the numbers cannot silently drift apart.
public enum CohereAsrConfig {

    // MARK: - Audio

    /// Sample rate expected by the model (16 kHz mono).
    public static let sampleRate: Int = 16000

    /// Maximum audio duration the model accepts, in seconds.
    public static let maxAudioSeconds: Float = 30.0

    /// Maximum number of audio samples (derived: 16 kHz × 30 s = 480,000).
    public static let maxSamples: Int = sampleRate * Int(maxAudioSeconds)

    // MARK: - Model dimensions

    /// Vocabulary size.
    public static let vocabSize: Int = 16_384

    /// Encoder hidden size (Conformer blocks).
    public static let encoderHiddenSize: Int = 1280

    /// Decoder hidden size.
    public static let decoderHiddenSize: Int = 1024

    /// Number of encoder layers.
    public static let numEncoderLayers: Int = 48

    /// Number of decoder layers.
    public static let numDecoderLayers: Int = 8

    /// Number of attention heads in the decoder.
    public static let numDecoderHeads: Int = 8

    /// Attention head dimension (derived: 1024 / 8 = 128).
    public static let headDim: Int = decoderHiddenSize / numDecoderHeads

    /// Maximum sequence length for the decoder KV cache.
    public static let maxSeqLen: Int = 108

    /// Number of mel bins (kept in sync with `MelSpec.nMels`).
    public static let numMelBins: Int = MelSpec.nMels

    // MARK: - Mel spectrogram

    /// Mel spectrogram extraction parameters.
    public enum MelSpec {
        /// FFT window size.
        public static let nFFT: Int = 1024
        /// Hop length in samples (10 ms at 16 kHz).
        public static let hopLength: Int = 160
        /// Number of mel filterbank channels.
        public static let nMels: Int = 128
        /// Lowest filterbank frequency (Hz).
        public static let fMin: Float = 0.0
        /// Highest filterbank frequency (Hz) — Nyquist for 16 kHz input.
        public static let fMax: Float = 8000.0
        /// Pre-emphasis filter coefficient.
        public static let preemphasis: Float = 0.97
    }

    // MARK: - Special tokens

    /// Special token IDs used by the tokenizer and prompt builder,
    /// listed in ascending ID order.
    /// NOTE(review): IDs are taken from the model export and cannot be
    /// verified from this file alone — confirm against the vocabulary file.
    public enum SpecialTokens {
        /// Unknown token.
        public static let unkToken: Int = 0
        /// No-speech token.
        public static let noSpeechToken: Int = 1
        /// Padding token.
        public static let padToken: Int = 2
        /// End of text / end of sequence token.
        public static let eosToken: Int = 3
        /// Start of transcript token.
        public static let startToken: Int = 4
        /// Punctuation token.
        public static let pnc: Int = 5
        /// Start of context token.
        public static let startOfContext: Int = 7
        /// No inverse text normalization token.
        public static let noitn: Int = 9
        /// No timestamp token.
        public static let notimestamp: Int = 11
        /// No diarization token.
        public static let nodiarize: Int = 13
        /// Emotion-undefined token.
        public static let emoUndefined: Int = 16
        /// Word boundary marker ("▁").
        public static let wordBoundary: Int = 13764
    }

    // MARK: - Languages

    /// Supported transcription languages, with their ISO 639-1 raw values.
    public enum Language: String, CaseIterable, Sendable {
        case english = "en"
        case french = "fr"
        case german = "de"
        case spanish = "es"
        case italian = "it"
        case portuguese = "pt"
        case dutch = "nl"
        case polish = "pl"
        case greek = "el"
        case arabic = "ar"
        case japanese = "ja"
        case chinese = "zh"
        case vietnamese = "vi"
        case korean = "ko"

        /// Human-readable English name of the language.
        public var englishName: String {
            switch self {
            case .english: return "English"
            case .french: return "French"
            case .german: return "German"
            case .spanish: return "Spanish"
            case .italian: return "Italian"
            case .portuguese: return "Portuguese"
            case .dutch: return "Dutch"
            case .polish: return "Polish"
            case .greek: return "Greek"
            case .arabic: return "Arabic"
            case .japanese: return "Japanese"
            case .chinese: return "Chinese"
            case .vietnamese: return "Vietnamese"
            case .korean: return "Korean"
            }
        }

        /// Language token ID (used as start token for conditioned generation).
        /// NOTE(review): IDs come from the model export — confirm against the
        /// vocabulary file.
        public var tokenId: Int {
            switch self {
            case .english: return 62
            case .french: return 69
            case .german: return 76
            case .spanish: return 169
            case .italian: return 97
            case .portuguese: return 149
            case .dutch: return 60
            case .polish: return 148
            case .greek: return 77
            case .arabic: return 28
            case .japanese: return 98
            case .chinese: return 50
            case .vietnamese: return 194
            case .korean: return 110
            }
        }

        /// Build the prompt sequence for this language.
        ///
        /// Cohere models expect a specific prompt sequence:
        /// 1. Word boundary marker
        /// 2. Start of context
        /// 3. Start of transcript
        /// 4. Emotion undefined
        /// 5-6. Language token (repeated twice)
        /// 7. Punctuation
        /// 8. No inverse text normalization
        /// 9. No timestamp
        /// 10. No diarization
        public var promptSequence: [Int] {
            let langToken = tokenId
            return [
                SpecialTokens.wordBoundary,    // ▁
                SpecialTokens.startOfContext,  // <|startofcontext|>
                SpecialTokens.startToken,      // <|startoftranscript|>
                SpecialTokens.emoUndefined,    // <|emo:undefined|>
                langToken,                     // <|en|> (or other language)
                langToken,                     // <|en|> (repeated)
                SpecialTokens.pnc,             // <|pnc|>
                SpecialTokens.noitn,           // <|noitn|>
                SpecialTokens.notimestamp,     // <|notimestamp|>
                SpecialTokens.nodiarize,       // <|nodiarize|>
            ]
        }
    }
}
Decode loop with KV cache: +/// - Feed previous token + encoder_hidden_states +/// - Get logits + updated cache +/// - Sample next token +/// 3. Continue until EOS or max tokens +@available(macOS 14, iOS 17, *) +public actor CohereAsrManager { + private var models: CohereAsrModels? + private let melExtractor: CohereMelSpectrogram + + public init() { + self.melExtractor = CohereMelSpectrogram() + } + + /// Load models from the specified directory. + public func loadModels(from directory: URL, computeUnits: MLComputeUnits = .all) async throws { + models = try await CohereAsrModels.load(from: directory, computeUnits: computeUnits) + logger.info("Cohere Transcribe models loaded successfully") + } + + /// Transcribe raw audio samples. + /// + /// - Important: The cache-external decoder only works reliably for **Spanish** (18-24% WER). + /// Other languages may hallucinate and produce wrong-language output (>50% WER). + /// For multilingual ASR, use Whisper or Qwen3 models instead. + /// + /// - Parameters: + /// - audioSamples: 16kHz mono Float32 audio samples. + /// - language: Target language for transcription. Only `.spanish` is reliable. + /// - maxNewTokens: Maximum number of tokens to generate. + /// - Returns: Transcribed text. + public func transcribe( + audioSamples: [Float], + language: CohereAsrConfig.Language? = .english, + maxNewTokens: Int = 200 + ) async throws -> String { + guard let models = models else { + throw CohereAsrError.generationFailed("Models not loaded") + } + + // IMPORTANT: Cache-external decoder only works reliably for Spanish + // Other languages may hallucinate (produce wrong-language output) + // For multilingual ASR, use Whisper or Qwen3 models instead + if let lang = language, lang != .spanish { + logger.warning( + "Cache-external decoder only supports Spanish reliably. Language '\(lang.rawValue)' may produce incorrect output. Consider using Whisper or Qwen3 for multilingual ASR." 
+ ) + } + + let start = CFAbsoluteTimeGetCurrent() + + // Step 1: Extract mel spectrogram + let mel = melExtractor.compute(audio: audioSamples) + guard !mel.isEmpty else { + throw CohereAsrError.invalidInput("Audio too short to extract mel spectrogram") + } + + let nFrames = mel[0].count + + // Pad to 3500 frames (max length) + let paddedMel = padMelSpectrogram(mel, targetFrames: 3500) + + // Step 2: Encode audio + let encodeStart = CFAbsoluteTimeGetCurrent() + let encoderHidden = try await encodeAudio(paddedMel: paddedMel, featureLength: nFrames, models: models) + let encodeTime = CFAbsoluteTimeGetCurrent() - encodeStart + logger.debug("Encoder: \(String(format: "%.3f", encodeTime))s") + + // Step 3: Decode with KV cache + let decodeStart = CFAbsoluteTimeGetCurrent() + let tokens: [Int] + + // Use cache-external decoder (stateful not supported on macOS) + tokens = try await decodeCacheExternal( + encoderHidden: encoderHidden, + language: language, + maxNewTokens: maxNewTokens, + models: models + ) + let decodeTime = CFAbsoluteTimeGetCurrent() - decodeStart + logger.debug("Decoder: \(String(format: "%.3f", decodeTime))s (\(tokens.count) tokens)") + + let totalTime = CFAbsoluteTimeGetCurrent() - start + logger.info( + "Transcribed \(String(format: "%.2f", Float(audioSamples.count) / 16000.0))s audio in \(String(format: "%.3f", totalTime))s" + ) + + // Step 4: Detokenize + let text = convertTokensToText(tokens, vocabulary: models.vocabulary) + + return text + } + + // MARK: - Private Helpers + + /// Pad mel spectrogram to target number of frames. + private func padMelSpectrogram(_ mel: [[Float]], targetFrames: Int) -> [[Float]] { + let nMels = mel.count + let nFrames = mel[0].count + + guard nFrames < targetFrames else { + return mel + } + + var padded = [[Float]](repeating: [Float](repeating: 0, count: targetFrames), count: nMels) + for m in 0..