diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.cpp index bf34b68a2..ff71d2b53 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.cpp @@ -175,7 +175,7 @@ void DurationPredictor::scaleDurations(Tensor &durations, size_t nTokens, shrinking ? std::ceil(scaled) - scaled : scaled - std::floor(scaled); durationsPtr[i] = static_cast(shrinking ? std::ceil(scaled) - : std::floor(scaled)); + : std::floor(scaled)); scaledSum += durationsPtr[i]; // Keeps the entries sorted by the remainders @@ -193,4 +193,4 @@ void DurationPredictor::scaleDurations(Tensor &durations, size_t nTokens, } } -} // namespace rnexecutorch::models::text_to_speech::kokoro \ No newline at end of file +} // namespace rnexecutorch::models::text_to_speech::kokoro diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.cpp index 52da0fc46..d7da13b94 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.cpp @@ -39,38 +39,39 @@ Kokoro::Kokoro(const std::string &lang, const std::string &taggerDataSource, } void Kokoro::loadVoice(const std::string &voiceSource) { - constexpr size_t rows = static_cast<size_t>(constants::kMaxInputTokens); - constexpr size_t cols = static_cast<size_t>(constants::kVoiceRefSize); // 256 - const size_t expectedCount = rows * cols; - const std::streamsize expectedBytes = - static_cast<std::streamsize>(expectedCount * sizeof(float)); + constexpr size_t cols = static_cast<size_t>(constants::kVoiceRefSize); + 
constexpr size_t bytesPerRow = cols * sizeof(float); std::ifstream in(voiceSource, std::ios::binary); if (!in) { throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, - "[Kokoro::loadSingleVoice]: cannot open file: " + + "[Kokoro::loadVoice]: cannot open file: " + voiceSource); } - // Check the file size + // Determine number of rows from file size in.seekg(0, std::ios::end); - const std::streamsize fileSize = in.tellg(); + const auto fileSize = static_cast<size_t>(in.tellg()); in.seekg(0, std::ios::beg); + + if (fileSize < bytesPerRow) { throw RnExecutorchError( RnExecutorchErrorCode::FileReadFailed, - "[Kokoro::loadSingleVoice]: file too small: expected at least " + - std::to_string(expectedBytes) + " bytes, got " + + "[Kokoro::loadVoice]: file too small: need at least " + + std::to_string(bytesPerRow) + " bytes for one row, got " + std::to_string(fileSize)); } - // Read [rows, 1, cols] as contiguous floats directly into voice_ - // ([rows][cols]) - if (!in.read(reinterpret_cast<char *>(voice_.data()->data()), - expectedBytes)) { + const size_t rows = fileSize / bytesPerRow; + const auto readBytes = static_cast<std::streamsize>(rows * bytesPerRow); + + // Resize voice vector to hold all rows from the file + voice_.resize(rows); + + if (!in.read(reinterpret_cast<char *>(voice_.data()->data()), readBytes)) { throw RnExecutorchError( RnExecutorchErrorCode::FileReadFailed, - "[Kokoro::loadSingleVoice]: failed to read voice weights"); + "[Kokoro::loadVoice]: failed to read voice weights"); } } @@ -92,7 +93,6 @@ Kokoro::generateFromPhonemesImpl(const std::u32string &phonemes, float speed) { size_t pauseMs = params::kPauseValues.contains(lastPhoneme) ? 
params::kPauseValues.at(lastPhoneme) : params::kDefaultPause; - // Add audio part and silence pause to the main audio vector audio.insert(audio.end(), std::make_move_iterator(audioPart.begin()), std::make_move_iterator(audioPart.end())); @@ -108,10 +108,11 @@ void Kokoro::streamFromPhonemesImpl( std::shared_ptr callback) { auto nativeCallback = [this, callback](const std::vector<float> &audioVec) { if (this->isStreaming_) { - this->callInvoker_->invokeAsync([callback, audioVec](jsi::Runtime &rt) { callback->call(rt, rnexecutorch::jsi_conversion::getJsiValue(audioVec, rt)); }); + this->callInvoker_->invokeAsync( + [callback, audioVec = std::move(audioVec)](jsi::Runtime &rt) { + callback->call( + rt, rnexecutorch::jsi_conversion::getJsiValue(audioVec, rt)); + }); } }; @@ -149,7 +150,7 @@ void Kokoro::streamFromPhonemesImpl( audioPart.size() + pauseMs * constants::kSamplesPerMilisecond, 0.F); // Push the audio right away to the JS side - nativeCallback(audioPart); + nativeCallback(std::move(audioPart)); } isStreaming_ = false; @@ -219,7 +220,8 @@ std::vector<float> Kokoro::synthesize(const std::u32string &phonemes, const auto tokens = utils::tokenize(phonemes, {noTokens}); // Select the appropriate voice vector - size_t voiceID = std::min(phonemes.size() - 1, noTokens); + size_t voiceID = std::min({phonemes.size() - 1, noTokens - 1, + voice_.size() - 1}); auto &voice = voice_[voiceID]; // Initialize text mask @@ -254,9 +256,7 @@ std::vector<float> Kokoro::synthesize(const std::u32string &phonemes, auto croppedAudio = utils::stripAudio(audio, paddingMs * constants::kSamplesPerMilisecond); - std::vector<float> result(croppedAudio.begin(), croppedAudio.end()); - - return result; + return {croppedAudio.begin(), croppedAudio.end()}; } std::size_t Kokoro::getMemoryLowerBound() const noexcept { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.h index 
d7a4c2ae6..47fdab769 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.h @@ -1,6 +1,7 @@ #pragma once #include +#include <atomic> #include #include #include @@ -75,19 +76,16 @@ class Kokoro { DurationPredictor durationPredictor_; Synthesizer synthesizer_; - // Voice array - // There is a separate voice vector for each of the possible numbers of input - // tokens. - std::array<std::array<float, constants::kVoiceRefSize>, - constants::kMaxInputTokens> - voice_; + // Voice array — dynamically sized to match the voice file. + // Each row is a style vector for a given input token count. + std::vector<std::array<float, constants::kVoiceRefSize>> voice_; // Extra control variables - bool isStreaming_ = false; + std::atomic<bool> isStreaming_{false}; }; } // namespace models::text_to_speech::kokoro REGISTER_CONSTRUCTOR(models::text_to_speech::kokoro::Kokoro, std::string, std::string, std::string, std::string, std::string, std::string, std::shared_ptr); -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.cpp index 121337937..fd69c43ee 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.cpp @@ -13,18 +13,34 @@ Synthesizer::Synthesizer(const std::string &modelSource, const Context &modelContext, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker), context_(modelContext) { - const auto inputTensors = getAllInputShapes("forward"); + // Discover all forward methods (forward, forward_8, forward_32, etc.) 
+ auto availableMethods = module_->method_names(); + if (availableMethods.ok()) { + const auto &names = *availableMethods; + for (const auto &name : names) { + if (name.rfind("forward", 0) == 0) { + const auto inputTensors = getAllInputShapes(name); + CHECK_SIZE(inputTensors, 5); + CHECK_SIZE(inputTensors[0], 2); + CHECK_SIZE(inputTensors[1], 2); + CHECK_SIZE(inputTensors[2], 1); + size_t inputSize = inputTensors[0][1]; + forwardMethods_.emplace_back(name, inputSize); + } + } + std::stable_sort(forwardMethods_.begin(), forwardMethods_.end(), + [](const auto &a, const auto &b) { return a.second < b.second; }); + } - // Perform checks to validate model's compatibility with native code - CHECK_SIZE(inputTensors, 5); - CHECK_SIZE( - inputTensors[0], - 2); // input tokens must be of shape {1, T}, where T is number of tokens - CHECK_SIZE( - inputTensors[1], - 2); // text mask must be of shape {1, T}, where T is number of tokens - CHECK_SIZE(inputTensors[2], - 1); // indices must be of shape {D}, where D is a maximum duration + // Fallback: if no methods discovered, validate "forward" directly + if (forwardMethods_.empty()) { + const auto inputTensors = getAllInputShapes("forward"); + CHECK_SIZE(inputTensors, 5); + CHECK_SIZE(inputTensors[0], 2); + CHECK_SIZE(inputTensors[1], 2); + CHECK_SIZE(inputTensors[2], 1); + forwardMethods_.emplace_back("forward", inputTensors[0][1]); + } } Result<std::vector<float>> Synthesizer::generate(std::span<Token> tokens, @@ -54,14 +70,19 @@ Result<std::vector<float>> Synthesizer::generate(std::span<Token> tokens, auto voiceRefTensor = make_tensor_ptr({1, constants::kVoiceRefSize}, ref_s.data(), ScalarType::Float); - // Execute the appropriate "forward_xyz" method, based on given method name - auto results = forward( + // Select appropriate forward method based on token count + auto it = std::ranges::find_if(forwardMethods_, + [noTokens](const auto &entry) { return static_cast<size_t>(entry.second) >= noTokens; }); + std::string selectedMethod = (it != forwardMethods_.end()) ? 
it->first : forwardMethods_.back().first; + + // Execute the selected forward method + auto results = execute(selectedMethod, {tokensTensor, textMaskTensor, indicesTensor, durTensor, voiceRefTensor}); if (!results.ok()) { throw RnExecutorchError( RnExecutorchErrorCode::InvalidModelOutput, - "[Kokoro::Synthesizer] Failed to execute method forward" + "[Kokoro::Synthesizer] Failed to execute method " + selectedMethod + ", error: " + std::to_string(static_cast(results.error()))); } @@ -72,13 +93,12 @@ Result<std::vector<float>> Synthesizer::generate(std::span<Token> tokens, size_t Synthesizer::getTokensLimit() const { - // Returns tokens input (shape {1, T}) second dim - return getInputShape("forward", 0)[1]; + return forwardMethods_.empty() ? 0 : forwardMethods_.back().second; } size_t Synthesizer::getDurationLimit() const { - // Returns indices vector first dim (shape {D}) - return getInputShape("forward", 2)[0]; + if (forwardMethods_.empty()) return 0; + return getInputShape(forwardMethods_.back().first, 2)[0]; } } // namespace rnexecutorch::models::text_to_speech::kokoro \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.h index 2c6f47d0a..bfbbd0263 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.h @@ -50,6 +50,8 @@ class Synthesizer : public BaseModel { size_t getDurationLimit() const; private: + // Forward methods discovered at construction (e.g. forward_8, forward_64, forward_128) + std::vector<std::pair<std::string, size_t>> forwardMethods_; // Shared model context // A const reference to singleton in Kokoro. 
const Context &context_; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.cpp index 18261e3f7..a77e40a93 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.cpp @@ -55,8 +55,8 @@ std::span<float> stripAudio(std::span<float> audio, size_t margin) { auto lbound = findAudioBound(audio); auto rbound = findAudioBound(audio); - lbound = std::max(lbound - margin, size_t(0)); - rbound = std::min(rbound + margin, audio.size() - 1); + lbound = lbound > margin ? lbound - margin : 0; + rbound = std::min(rbound + margin, audio.size() > 0 ? audio.size() - 1 : 0); return audio.subspan(lbound, rbound >= lbound ? rbound - lbound + 1 : 0); } @@ -85,7 +85,7 @@ std::vector<Token> tokenize(const std::u32string &phonemes, ? constants::kVocab.at(p) : constants::kInvalidToken; }); - auto validSeqEnd = std::partition( + auto validSeqEnd = std::stable_partition( tokens.begin() + 1, tokens.begin() + effNoTokens + 1, [](Token t) -> bool { return t != constants::kInvalidToken; }); std::fill(validSeqEnd, tokens.begin() + effNoTokens + 1, constants::kInvalidToken); return tokens; } -} // namespace rnexecutorch::models::text_to_speech::kokoro::utils \ No newline at end of file +} // namespace rnexecutorch::models::text_to_speech::kokoro::utils