Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ void DurationPredictor::scaleDurations(Tensor &durations, size_t nTokens,
shrinking ? std::ceil(scaled) - scaled : scaled - std::floor(scaled);

durationsPtr[i] = static_cast<int64_t>(shrinking ? std::ceil(scaled)
: std::floor(scaled));
: std::floor(scaled));
scaledSum += durationsPtr[i];

// Keeps the entries sorted by the remainders
Expand All @@ -193,4 +193,4 @@ void DurationPredictor::scaleDurations(Tensor &durations, size_t nTokens,
}
}

} // namespace rnexecutorch::models::text_to_speech::kokoro
} // namespace rnexecutorch::models::text_to_speech::kokoro
Original file line number Diff line number Diff line change
Expand Up @@ -39,38 +39,39 @@ Kokoro::Kokoro(const std::string &lang, const std::string &taggerDataSource,
}

void Kokoro::loadVoice(const std::string &voiceSource) {
  // Loads a voice reference file: a flat binary of float rows, one style
  // vector of kVoiceRefSize floats per possible input token count. The number
  // of rows is derived from the file size, so files of any row count load.
  // Throws RnExecutorchError(FileReadFailed) on any I/O failure.
  constexpr size_t cols = static_cast<size_t>(constants::kVoiceRefSize);
  constexpr size_t bytesPerRow = cols * sizeof(float);

  std::ifstream in(voiceSource, std::ios::binary);
  if (!in) {
    throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed,
                            "[Kokoro::loadVoice]: cannot open file: " +
                                voiceSource);
  }

  // Determine number of rows from file size
  in.seekg(0, std::ios::end);
  const std::streamoff endPos = in.tellg();
  in.seekg(0, std::ios::beg);

  // tellg() returns -1 on failure; casting that to size_t would yield a huge
  // row count and a pathological resize below, so check before converting.
  if (endPos < 0) {
    throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed,
                            "[Kokoro::loadVoice]: cannot determine size of "
                            "file: " +
                                voiceSource);
  }
  const auto fileSize = static_cast<size_t>(endPos);

  if (fileSize < bytesPerRow) {
    throw RnExecutorchError(
        RnExecutorchErrorCode::FileReadFailed,
        "[Kokoro::loadVoice]: file too small: need at least " +
            std::to_string(bytesPerRow) + " bytes for one row, got " +
            std::to_string(fileSize));
  }

  const size_t rows = fileSize / bytesPerRow;
  const auto readBytes = static_cast<std::streamsize>(rows * bytesPerRow);

  // Resize voice vector to hold all rows from the file; any trailing partial
  // row is ignored.
  voice_.resize(rows);

  // std::array<float, cols> is contiguous, so the vector of arrays can be
  // filled with a single read into the first array's storage.
  if (!in.read(reinterpret_cast<char *>(voice_.data()->data()), readBytes)) {
    throw RnExecutorchError(
        RnExecutorchErrorCode::FileReadFailed,
        "[Kokoro::loadVoice]: failed to read voice weights");
  }
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We never really loaded the entire voice array, because we don't process inputs longer than 128 tokens — and the rest of the voice data, which we omit, is used specifically for longer inputs.

I guess reading the entire file is also OK, although it's a little bit of memory waste since we don't use the remainder of the voice data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@IgorSwat This change fixed a bunch of the voice files I was trying to use (the kokoro 82m voice files in their repo, not the .bin files y'all provided). I don't think the input length impacts how many rows of the voice weights are needed but I can look into it more here.


Expand All @@ -92,7 +93,6 @@ Kokoro::generateFromPhonemesImpl(const std::u32string &phonemes, float speed) {
size_t pauseMs = params::kPauseValues.contains(lastPhoneme)
? params::kPauseValues.at(lastPhoneme)
: params::kDefaultPause;

// Add audio part and silence pause to the main audio vector
audio.insert(audio.end(), std::make_move_iterator(audioPart.begin()),
std::make_move_iterator(audioPart.end()));
Expand All @@ -108,10 +108,11 @@ void Kokoro::streamFromPhonemesImpl(
std::shared_ptr<jsi::Function> callback) {
auto nativeCallback = [this, callback](const std::vector<float> &audioVec) {
if (this->isStreaming_) {
this->callInvoker_->invokeAsync([callback, audioVec](jsi::Runtime &rt) {
callback->call(rt,
rnexecutorch::jsi_conversion::getJsiValue(audioVec, rt));
});
this->callInvoker_->invokeAsync(
[callback, audioVec = std::move(audioVec)](jsi::Runtime &rt) {
callback->call(
rt, rnexecutorch::jsi_conversion::getJsiValue(audioVec, rt));
});
}
};

Expand Down Expand Up @@ -149,7 +150,7 @@ void Kokoro::streamFromPhonemesImpl(
audioPart.size() + pauseMs * constants::kSamplesPerMilisecond, 0.F);

// Push the audio right away to the JS side
nativeCallback(audioPart);
nativeCallback(std::move(audioPart));
}

isStreaming_ = false;
Expand Down Expand Up @@ -219,7 +220,8 @@ std::vector<float> Kokoro::synthesize(const std::u32string &phonemes,
const auto tokens = utils::tokenize(phonemes, {noTokens});

// Select the appropriate voice vector
size_t voiceID = std::min(phonemes.size() - 1, noTokens);
size_t voiceID = std::min({phonemes.size() - 1, noTokens - 1,
voice_.size() - 1});
auto &voice = voice_[voiceID];

// Initialize text mask
Expand Down Expand Up @@ -254,9 +256,7 @@ std::vector<float> Kokoro::synthesize(const std::u32string &phonemes,
auto croppedAudio =
utils::stripAudio(audio, paddingMs * constants::kSamplesPerMilisecond);

std::vector<float> result(croppedAudio.begin(), croppedAudio.end());

return result;
return {croppedAudio.begin(), croppedAudio.end()};
}

std::size_t Kokoro::getMemoryLowerBound() const noexcept {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <array>
#include <atomic>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -75,19 +76,16 @@ class Kokoro {
DurationPredictor durationPredictor_;
Synthesizer synthesizer_;

// Voice array
// There is a separate voice vector for each of the possible numbers of input
// tokens.
std::array<std::array<float, constants::kVoiceRefSize>,
constants::kMaxInputTokens>
voice_;
// Voice array — dynamically sized to match the voice file.
// Each row is a style vector for a given input token count.
std::vector<std::array<float, constants::kVoiceRefSize>> voice_;

// Extra control variables
bool isStreaming_ = false;
std::atomic<bool> isStreaming_{false};
};
} // namespace models::text_to_speech::kokoro

REGISTER_CONSTRUCTOR(models::text_to_speech::kokoro::Kokoro, std::string,
std::string, std::string, std::string, std::string,
std::string, std::shared_ptr<react::CallInvoker>);
} // namespace rnexecutorch
} // namespace rnexecutorch
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,34 @@ Synthesizer::Synthesizer(const std::string &modelSource,
const Context &modelContext,
std::shared_ptr<react::CallInvoker> callInvoker)
: BaseModel(modelSource, callInvoker), context_(modelContext) {
const auto inputTensors = getAllInputShapes("forward");
// Discover all forward methods (forward, forward_8, forward_32, etc.)
auto availableMethods = module_->method_names();
if (availableMethods.ok()) {
const auto &names = *availableMethods;
for (const auto &name : names) {
if (name.rfind("forward", 0) == 0) {
const auto inputTensors = getAllInputShapes(name);
CHECK_SIZE(inputTensors, 5);
CHECK_SIZE(inputTensors[0], 2);
CHECK_SIZE(inputTensors[1], 2);
CHECK_SIZE(inputTensors[2], 1);
size_t inputSize = inputTensors[0][1];
forwardMethods_.emplace_back(name, inputSize);
}
}
std::stable_sort(forwardMethods_.begin(), forwardMethods_.end(),
[](const auto &a, const auto &b) { return a.second < b.second; });
}

// Perform checks to validate model's compatibility with native code
CHECK_SIZE(inputTensors, 5);
CHECK_SIZE(
inputTensors[0],
2); // input tokens must be of shape {1, T}, where T is number of tokens
CHECK_SIZE(
inputTensors[1],
2); // text mask must be of shape {1, T}, where T is number of tokens
CHECK_SIZE(inputTensors[2],
1); // indices must be of shape {D}, where D is a maximum duration
// Fallback: if no methods discovered, validate "forward" directly
if (forwardMethods_.empty()) {
const auto inputTensors = getAllInputShapes("forward");
CHECK_SIZE(inputTensors, 5);
CHECK_SIZE(inputTensors[0], 2);
CHECK_SIZE(inputTensors[1], 2);
CHECK_SIZE(inputTensors[2], 1);
forwardMethods_.emplace_back("forward", inputTensors[0][1]);
}
}

Result<std::vector<EValue>> Synthesizer::generate(std::span<const Token> tokens,
Expand Down Expand Up @@ -54,14 +70,19 @@ Result<std::vector<EValue>> Synthesizer::generate(std::span<const Token> tokens,
auto voiceRefTensor = make_tensor_ptr({1, constants::kVoiceRefSize},
ref_s.data(), ScalarType::Float);

// Execute the appropriate "forward_xyz" method, based on given method name
auto results = forward(
// Select appropriate forward method based on token count
auto it = std::ranges::find_if(forwardMethods_,
[noTokens](const auto &entry) { return static_cast<int32_t>(entry.second) >= noTokens; });
std::string selectedMethod = (it != forwardMethods_.end()) ? it->first : forwardMethods_.back().first;

// Execute the selected forward method
auto results = execute(selectedMethod,
{tokensTensor, textMaskTensor, indicesTensor, durTensor, voiceRefTensor});

if (!results.ok()) {
throw RnExecutorchError(
RnExecutorchErrorCode::InvalidModelOutput,
"[Kokoro::Synthesizer] Failed to execute method forward"
"[Kokoro::Synthesizer] Failed to execute method " + selectedMethod +
", error: " +
std::to_string(static_cast<uint32_t>(results.error())));
}
Expand All @@ -72,13 +93,12 @@ Result<std::vector<EValue>> Synthesizer::generate(std::span<const Token> tokens,
}

size_t Synthesizer::getTokensLimit() const {
  // Returns the largest supported token count: forwardMethods_ is sorted
  // ascending by input size, so the last entry holds the maximum second dim
  // of the tokens input (shape {1, T}). 0 if no methods were discovered.
  return forwardMethods_.empty() ? 0 : forwardMethods_.back().second;
}

size_t Synthesizer::getDurationLimit() const {
  // Returns the indices vector first dim (shape {D}) of the largest forward
  // method; forwardMethods_ is sorted ascending by input size, so the last
  // entry has the greatest duration capacity. 0 if no methods were discovered.
  if (forwardMethods_.empty()) {
    return 0;
  }
  return getInputShape(forwardMethods_.back().first, 2)[0];
}

} // namespace rnexecutorch::models::text_to_speech::kokoro
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class Synthesizer : public BaseModel {
size_t getDurationLimit() const;

private:
// Forward methods discovered at construction (e.g. forward_8, forward_64, forward_128)
std::vector<std::pair<std::string, size_t>> forwardMethods_;
// Shared model context
// A const reference to singleton in Kokoro.
const Context &context_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ std::span<const float> stripAudio(std::span<const float> audio, size_t margin) {
auto lbound = findAudioBound<false>(audio);
auto rbound = findAudioBound<true>(audio);

lbound = std::max(lbound - margin, size_t(0));
rbound = std::min(rbound + margin, audio.size() - 1);
lbound = lbound > margin ? lbound - margin : 0;
rbound = std::min(rbound + margin, audio.size() > 0 ? audio.size() - 1 : 0);

return audio.subspan(lbound, rbound >= lbound ? rbound - lbound + 1 : 0);
}
Expand Down Expand Up @@ -85,7 +85,7 @@ std::vector<Token> tokenize(const std::u32string &phonemes,
? constants::kVocab.at(p)
: constants::kInvalidToken;
});
auto validSeqEnd = std::partition(
auto validSeqEnd = std::stable_partition(
tokens.begin() + 1, tokens.begin() + effNoTokens + 1,
[](Token t) -> bool { return t != constants::kInvalidToken; });
std::fill(validSeqEnd, tokens.begin() + effNoTokens + 1,
Expand All @@ -94,4 +94,4 @@ std::vector<Token> tokenize(const std::u32string &phonemes,
return tokens;
}

} // namespace rnexecutorch::models::text_to_speech::kokoro::utils
} // namespace rnexecutorch::models::text_to_speech::kokoro::utils
Loading