-
Notifications
You must be signed in to change notification settings - Fork 61
fix: speech to text live transcription #816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
bab2ffb
ea943e4
b54e469
ce5a39a
278985c
2a37867
f42351b
2ee6d1d
915c8e7
7029184
7aac36d
6e84c3d
d253381
9041e0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,54 +1,53 @@ | ||
| #include <thread> | ||
|
|
||
| #include "SpeechToText.h" | ||
| #include "common/types/TranscriptionResult.h" | ||
| #include "whisper/ASR.h" | ||
| #include "whisper/OnlineASR.h" | ||
| #include <rnexecutorch/Error.h> | ||
| #include <rnexecutorch/ErrorCodes.h> | ||
| #include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h> | ||
|
|
||
| namespace rnexecutorch::models::speech_to_text { | ||
|
|
||
| using namespace ::executorch::extension; | ||
| using namespace asr; | ||
| using namespace types; | ||
| using namespace stream; | ||
|
|
||
| SpeechToText::SpeechToText(const std::string &encoderSource, | ||
| const std::string &decoderSource, | ||
| SpeechToText::SpeechToText(const std::string &modelName, | ||
| const std::string &modelSource, | ||
| const std::string &tokenizerSource, | ||
| std::shared_ptr<react::CallInvoker> callInvoker) | ||
| : callInvoker(std::move(callInvoker)), | ||
| encoder(std::make_unique<BaseModel>(encoderSource, this->callInvoker)), | ||
| decoder(std::make_unique<BaseModel>(decoderSource, this->callInvoker)), | ||
| tokenizer(std::make_unique<TokenizerModule>(tokenizerSource, | ||
| this->callInvoker)), | ||
| asr(std::make_unique<ASR>(this->encoder.get(), this->decoder.get(), | ||
| this->tokenizer.get())), | ||
| processor(std::make_unique<OnlineASRProcessor>(this->asr.get())), | ||
| isStreaming(false), readyToProcess(false) {} | ||
|
|
||
| void SpeechToText::unload() noexcept { | ||
| this->encoder->unload(); | ||
| this->decoder->unload(); | ||
| : callInvoker_(std::move(callInvoker)) { | ||
| // Switch between the ASR implementations based on model name | ||
| if (modelName == "whisper") { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. food for thought: as we discussed a few days back, think about how we can make it work so that the native side doesn't need the model name, but accepts a bunch of configurable pipeline steps. no need to do this now IMO, but just a note. Maybe we can have different ASR implementations based on whether the model does support timestamps or not? |
||
| transcriber_ = std::make_unique<whisper::ASR>(modelSource, tokenizerSource, | ||
| callInvoker_); | ||
| streamer_ = std::make_unique<whisper::stream::OnlineASR>( | ||
| static_cast<const whisper::ASR *>(transcriber_.get())); | ||
| } else { | ||
| throw rnexecutorch::RnExecutorchError( | ||
| rnexecutorch::RnExecutorchErrorCode::InvalidConfig, | ||
| "[SpeechToText]: Invalid model name: " + modelName); | ||
| } | ||
| } | ||
|
|
||
| void SpeechToText::unload() noexcept { transcriber_->unload(); } | ||
|
|
||
| std::shared_ptr<OwningArrayBuffer> | ||
| SpeechToText::encode(std::span<float> waveform) const { | ||
| std::vector<float> encoderOutput = this->asr->encode(waveform); | ||
| std::vector<float> encoderOutput = transcriber_->encode(waveform); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm thinking whether we need to return std::vector from the encoder? Maybe we would just return a span. We wrap this in OwningArrayBuffer, which copies the data. |
||
| return std::make_shared<OwningArrayBuffer>(encoderOutput); | ||
| } | ||
|
|
||
| std::shared_ptr<OwningArrayBuffer> | ||
| SpeechToText::decode(std::span<uint64_t> tokens, | ||
| std::span<float> encoderOutput) const { | ||
| std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput); | ||
| std::vector<float> decoderOutput = | ||
| transcriber_->decode(tokens, encoderOutput); | ||
| return std::make_shared<OwningArrayBuffer>(decoderOutput); | ||
| } | ||
|
|
||
| TranscriptionResult SpeechToText::transcribe(std::span<float> waveform, | ||
| std::string languageOption, | ||
| bool verbose) const { | ||
| DecodingOptions options(languageOption, verbose); | ||
| std::vector<Segment> segments = this->asr->transcribe(waveform, options); | ||
| std::vector<Segment> segments = transcriber_->transcribe(waveform, options); | ||
|
|
||
| std::string fullText; | ||
| for (const auto &segment : segments) { | ||
|
|
@@ -70,8 +69,7 @@ TranscriptionResult SpeechToText::transcribe(std::span<float> waveform, | |
| } | ||
|
|
||
| size_t SpeechToText::getMemoryLowerBound() const noexcept { | ||
| return this->encoder->getMemoryLowerBound() + | ||
| this->decoder->getMemoryLowerBound(); | ||
| return transcriber_->getMemoryLowerBound(); | ||
| } | ||
|
|
||
| namespace { | ||
|
|
@@ -83,7 +81,7 @@ TranscriptionResult wordsToResult(const std::vector<Word> &words, | |
|
|
||
| std::string fullText; | ||
| for (const auto &w : words) { | ||
| fullText += w.content; | ||
| fullText += w.content + w.punctations; | ||
| } | ||
| res.text = fullText; | ||
|
|
||
|
|
@@ -105,68 +103,70 @@ TranscriptionResult wordsToResult(const std::vector<Word> &words, | |
|
|
||
| void SpeechToText::stream(std::shared_ptr<jsi::Function> callback, | ||
| std::string languageOption, bool verbose) { | ||
| if (this->isStreaming) { | ||
| if (isStreaming_) { | ||
| throw RnExecutorchError(RnExecutorchErrorCode::StreamingInProgress, | ||
| "Streaming is already in progress!"); | ||
| } | ||
|
|
||
| auto nativeCallback = [this, callback, | ||
| verbose](const TranscriptionResult &committed, | ||
| const TranscriptionResult &nonCommitted, | ||
| bool isDone) { | ||
| // This moves execution to the JS thread | ||
| this->callInvoker->invokeAsync( | ||
| [callback, committed, nonCommitted, isDone, verbose](jsi::Runtime &rt) { | ||
| jsi::Value jsiCommitted = | ||
| rnexecutorch::jsi_conversion::getJsiValue(committed, rt); | ||
| jsi::Value jsiNonCommitted = | ||
| rnexecutorch::jsi_conversion::getJsiValue(nonCommitted, rt); | ||
|
|
||
| callback->call(rt, std::move(jsiCommitted), | ||
| std::move(jsiNonCommitted), jsi::Value(isDone)); | ||
| }); | ||
| }; | ||
|
|
||
| this->isStreaming = true; | ||
| auto nativeCallback = | ||
| [this, callback](const TranscriptionResult &committed, | ||
| const TranscriptionResult &nonCommitted, bool isDone) { | ||
| // This moves execution to the JS thread | ||
| callInvoker_->invokeAsync( | ||
| [callback, committed, nonCommitted, isDone](jsi::Runtime &rt) { | ||
| jsi::Value jsiCommitted = | ||
| rnexecutorch::jsi_conversion::getJsiValue(committed, rt); | ||
| jsi::Value jsiNonCommitted = | ||
| rnexecutorch::jsi_conversion::getJsiValue(nonCommitted, rt); | ||
|
|
||
| callback->call(rt, std::move(jsiCommitted), | ||
| std::move(jsiNonCommitted), jsi::Value(isDone)); | ||
| }); | ||
| }; | ||
|
|
||
| isStreaming_ = true; | ||
| DecodingOptions options(languageOption, verbose); | ||
|
|
||
| while (this->isStreaming) { | ||
| if (!this->readyToProcess || | ||
| this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) { | ||
| std::this_thread::sleep_for(std::chrono::milliseconds(100)); | ||
| continue; | ||
| } | ||
| while (isStreaming_) { | ||
| if (readyToProcess_ && streamer_->isReady()) { | ||
| ProcessResult res = streamer_->process(options); | ||
|
|
||
| ProcessResult res = this->processor->processIter(options); | ||
| TranscriptionResult cRes = | ||
| wordsToResult(res.committed, languageOption, verbose); | ||
| TranscriptionResult ncRes = | ||
| wordsToResult(res.nonCommitted, languageOption, verbose); | ||
|
|
||
| TranscriptionResult cRes = | ||
| wordsToResult(res.committed, languageOption, verbose); | ||
| TranscriptionResult ncRes = | ||
| wordsToResult(res.nonCommitted, languageOption, verbose); | ||
| nativeCallback(cRes, ncRes, false); | ||
| readyToProcess_ = false; | ||
| } | ||
|
|
||
| nativeCallback(cRes, ncRes, false); | ||
| this->readyToProcess = false; | ||
| // Add a minimal pause between transcriptions. | ||
| // The reasoning is very simple: with the current liberal threshold values, | ||
| // running transcriptions too rapidly (before the audio buffer is filled | ||
| // with significant amount of new data) can cause streamer to commit wrong | ||
| // phrases. | ||
| std::this_thread::sleep_for(std::chrono::milliseconds(75)); | ||
| } | ||
|
|
||
| std::vector<Word> finalWords = this->processor->finish(); | ||
| std::vector<Word> finalWords = streamer_->finish(); | ||
| TranscriptionResult finalRes = | ||
| wordsToResult(finalWords, languageOption, verbose); | ||
|
|
||
| nativeCallback(finalRes, {}, true); | ||
| this->resetStreamState(); | ||
| resetStreamState(); | ||
| } | ||
|
|
||
| void SpeechToText::streamStop() { this->isStreaming = false; } | ||
| void SpeechToText::streamStop() { isStreaming_ = false; } | ||
|
|
||
| void SpeechToText::streamInsert(std::span<float> waveform) { | ||
| this->processor->insertAudioChunk(waveform); | ||
| this->readyToProcess = true; | ||
| streamer_->insertAudioChunk(waveform); | ||
| readyToProcess_ = true; | ||
| } | ||
|
|
||
| void SpeechToText::resetStreamState() { | ||
| this->isStreaming = false; | ||
| this->readyToProcess = false; | ||
| this->processor = std::make_unique<OnlineASRProcessor>(this->asr.get()); | ||
| isStreaming_ = false; | ||
| readyToProcess_ = false; | ||
| streamer_->reset(); | ||
| } | ||
|
|
||
| } // namespace rnexecutorch::models::speech_to_text | ||
Uh oh!
There was an error while loading. Please reload this page.