Skip to content
Open
8 changes: 4 additions & 4 deletions apps/speech/screens/SpeechToTextScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
const [liveTranscribing, setLiveTranscribing] = useState(false);
const scrollViewRef = useRef<ScrollView>(null);

const recorder = new AudioRecorder();
const recorder = useRef(new AudioRecorder());

useEffect(() => {
AudioManager.setAudioSessionOptions({
Expand Down Expand Up @@ -115,7 +115,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {

const sampleRate = 16000;

recorder.onAudioReady(
recorder.current.onAudioReady(
{
sampleRate,
bufferLength: 0.1 * sampleRate,
Expand All @@ -131,7 +131,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
if (!success) {
console.warn('Cannot start audio session correctly');
}
const result = recorder.start();
const result = recorder.current.start();
if (result.status === 'error') {
console.warn('Recording problems: ', result.message);
}
Expand Down Expand Up @@ -177,7 +177,7 @@ export const SpeechToTextScreen = ({ onBack }: { onBack: () => void }) => {
const handleStopTranscribeFromMicrophone = () => {
isRecordingRef.current = false;

recorder.stop();
recorder.current.stop();
model.streamStop();
console.log('Live transcription stopped');
setLiveTranscribing(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
#include <rnexecutorch/models/object_detection/Constants.h>
#include <rnexecutorch/models/object_detection/Types.h>
#include <rnexecutorch/models/ocr/Types.h>
#include <rnexecutorch/models/speech_to_text/types/Segment.h>
#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>
#include <rnexecutorch/models/speech_to_text/common/types/Segment.h>
#include <rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h>
#include <rnexecutorch/models/voice_activity_detection/Types.h>

using namespace rnexecutorch::models::speech_to_text::types;
using namespace rnexecutorch::models::speech_to_text;

namespace rnexecutorch::jsi_conversion {

Expand Down Expand Up @@ -507,7 +507,8 @@ inline jsi::Value getJsiValue(const Segment &seg, jsi::Runtime &runtime) {
jsi::Object wordObj(runtime);
wordObj.setProperty(
runtime, "word",
jsi::String::createFromUtf8(runtime, seg.words[i].content));
jsi::String::createFromUtf8(runtime, seg.words[i].content +
seg.words[i].punctations));
wordObj.setProperty(runtime, "start",
static_cast<double>(seg.words[i].start));
wordObj.setProperty(runtime, "end", static_cast<double>(seg.words[i].end));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
#include <string>
#include <vector>

#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
#include <ReactCommon/CallInvoker.h>
#include <executorch/extension/module/module.h>
#include <jsi/jsi.h>
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>

namespace rnexecutorch {
namespace models {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,54 +1,53 @@
#include <thread>

#include "SpeechToText.h"
#include "common/types/TranscriptionResult.h"
#include "whisper/ASR.h"
#include "whisper/OnlineASR.h"
#include <rnexecutorch/Error.h>
#include <rnexecutorch/ErrorCodes.h>
#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>

namespace rnexecutorch::models::speech_to_text {

using namespace ::executorch::extension;
using namespace asr;
using namespace types;
using namespace stream;

SpeechToText::SpeechToText(const std::string &encoderSource,
const std::string &decoderSource,
SpeechToText::SpeechToText(const std::string &modelName,
const std::string &modelSource,
const std::string &tokenizerSource,
std::shared_ptr<react::CallInvoker> callInvoker)
: callInvoker(std::move(callInvoker)),
encoder(std::make_unique<BaseModel>(encoderSource, this->callInvoker)),
decoder(std::make_unique<BaseModel>(decoderSource, this->callInvoker)),
tokenizer(std::make_unique<TokenizerModule>(tokenizerSource,
this->callInvoker)),
asr(std::make_unique<ASR>(this->encoder.get(), this->decoder.get(),
this->tokenizer.get())),
processor(std::make_unique<OnlineASRProcessor>(this->asr.get())),
isStreaming(false), readyToProcess(false) {}

void SpeechToText::unload() noexcept {
this->encoder->unload();
this->decoder->unload();
: callInvoker_(std::move(callInvoker)) {
// Switch between the ASR implementations based on model name
if (modelName == "whisper") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Food for thought: as we discussed a few days back, think about how we can make this work so that the native side doesn't need the model name, but instead accepts a set of configurable pipeline steps. No need to do this now IMO; this is just a note.

Maybe we could have different ASR implementations depending on whether or not the model supports timestamps?

transcriber_ = std::make_unique<whisper::ASR>(modelSource, tokenizerSource,
callInvoker_);
streamer_ = std::make_unique<whisper::stream::OnlineASR>(
static_cast<const whisper::ASR *>(transcriber_.get()));
} else {
throw rnexecutorch::RnExecutorchError(
rnexecutorch::RnExecutorchErrorCode::InvalidConfig,
"[SpeechToText]: Invalid model name: " + modelName);
}
}

void SpeechToText::unload() noexcept { transcriber_->unload(); }

std::shared_ptr<OwningArrayBuffer>
SpeechToText::encode(std::span<float> waveform) const {
std::vector<float> encoderOutput = this->asr->encode(waveform);
std::vector<float> encoderOutput = transcriber_->encode(waveform);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether we need to return a std::vector from the encoder. Maybe we could just return a span, since we wrap the result in OwningArrayBuffer, which copies the data.

return std::make_shared<OwningArrayBuffer>(encoderOutput);
}

std::shared_ptr<OwningArrayBuffer>
SpeechToText::decode(std::span<uint64_t> tokens,
std::span<float> encoderOutput) const {
std::vector<float> decoderOutput = this->asr->decode(tokens, encoderOutput);
std::vector<float> decoderOutput =
transcriber_->decode(tokens, encoderOutput);
return std::make_shared<OwningArrayBuffer>(decoderOutput);
}

TranscriptionResult SpeechToText::transcribe(std::span<float> waveform,
std::string languageOption,
bool verbose) const {
DecodingOptions options(languageOption, verbose);
std::vector<Segment> segments = this->asr->transcribe(waveform, options);
std::vector<Segment> segments = transcriber_->transcribe(waveform, options);

std::string fullText;
for (const auto &segment : segments) {
Expand All @@ -70,8 +69,7 @@ TranscriptionResult SpeechToText::transcribe(std::span<float> waveform,
}

size_t SpeechToText::getMemoryLowerBound() const noexcept {
return this->encoder->getMemoryLowerBound() +
this->decoder->getMemoryLowerBound();
return transcriber_->getMemoryLowerBound();
}

namespace {
Expand All @@ -83,7 +81,7 @@ TranscriptionResult wordsToResult(const std::vector<Word> &words,

std::string fullText;
for (const auto &w : words) {
fullText += w.content;
fullText += w.content + w.punctations;
}
res.text = fullText;

Expand All @@ -105,68 +103,70 @@ TranscriptionResult wordsToResult(const std::vector<Word> &words,

void SpeechToText::stream(std::shared_ptr<jsi::Function> callback,
std::string languageOption, bool verbose) {
if (this->isStreaming) {
if (isStreaming_) {
throw RnExecutorchError(RnExecutorchErrorCode::StreamingInProgress,
"Streaming is already in progress!");
}

auto nativeCallback = [this, callback,
verbose](const TranscriptionResult &committed,
const TranscriptionResult &nonCommitted,
bool isDone) {
// This moves execution to the JS thread
this->callInvoker->invokeAsync(
[callback, committed, nonCommitted, isDone, verbose](jsi::Runtime &rt) {
jsi::Value jsiCommitted =
rnexecutorch::jsi_conversion::getJsiValue(committed, rt);
jsi::Value jsiNonCommitted =
rnexecutorch::jsi_conversion::getJsiValue(nonCommitted, rt);

callback->call(rt, std::move(jsiCommitted),
std::move(jsiNonCommitted), jsi::Value(isDone));
});
};

this->isStreaming = true;
auto nativeCallback =
[this, callback](const TranscriptionResult &committed,
const TranscriptionResult &nonCommitted, bool isDone) {
// This moves execution to the JS thread
callInvoker_->invokeAsync(
[callback, committed, nonCommitted, isDone](jsi::Runtime &rt) {
jsi::Value jsiCommitted =
rnexecutorch::jsi_conversion::getJsiValue(committed, rt);
jsi::Value jsiNonCommitted =
rnexecutorch::jsi_conversion::getJsiValue(nonCommitted, rt);

callback->call(rt, std::move(jsiCommitted),
std::move(jsiNonCommitted), jsi::Value(isDone));
});
};

isStreaming_ = true;
DecodingOptions options(languageOption, verbose);

while (this->isStreaming) {
if (!this->readyToProcess ||
this->processor->audioBuffer.size() < SpeechToText::kMinAudioSamples) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
continue;
}
while (isStreaming_) {
if (readyToProcess_ && streamer_->isReady()) {
ProcessResult res = streamer_->process(options);

ProcessResult res = this->processor->processIter(options);
TranscriptionResult cRes =
wordsToResult(res.committed, languageOption, verbose);
TranscriptionResult ncRes =
wordsToResult(res.nonCommitted, languageOption, verbose);

TranscriptionResult cRes =
wordsToResult(res.committed, languageOption, verbose);
TranscriptionResult ncRes =
wordsToResult(res.nonCommitted, languageOption, verbose);
nativeCallback(cRes, ncRes, false);
readyToProcess_ = false;
}

nativeCallback(cRes, ncRes, false);
this->readyToProcess = false;
// Add a minimal pause between transcriptions.
// The reasoning is simple: with the current liberal threshold values,
// running transcriptions too rapidly (before the audio buffer has been filled
// with a significant amount of new data) can cause the streamer to commit
// wrong phrases.
std::this_thread::sleep_for(std::chrono::milliseconds(75));
}

std::vector<Word> finalWords = this->processor->finish();
std::vector<Word> finalWords = streamer_->finish();
TranscriptionResult finalRes =
wordsToResult(finalWords, languageOption, verbose);

nativeCallback(finalRes, {}, true);
this->resetStreamState();
resetStreamState();
}

void SpeechToText::streamStop() { this->isStreaming = false; }
void SpeechToText::streamStop() { isStreaming_ = false; }

void SpeechToText::streamInsert(std::span<float> waveform) {
this->processor->insertAudioChunk(waveform);
this->readyToProcess = true;
streamer_->insertAudioChunk(waveform);
readyToProcess_ = true;
}

void SpeechToText::resetStreamState() {
this->isStreaming = false;
this->readyToProcess = false;
this->processor = std::make_unique<OnlineASRProcessor>(this->asr.get());
isStreaming_ = false;
readyToProcess_ = false;
streamer_->reset();
}

} // namespace rnexecutorch::models::speech_to_text
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
#pragma once

#include "rnexecutorch/models/speech_to_text/stream/OnlineASRProcessor.h"
#include <rnexecutorch/models/speech_to_text/types/TranscriptionResult.h>
#include <span>
#include <string>
#include <vector>

#include "common/schema/ASR.h"
#include "common/schema/OnlineASR.h"
#include "common/types/TranscriptionResult.h"

namespace rnexecutorch {

namespace models::speech_to_text {

class SpeechToText {
public:
explicit SpeechToText(const std::string &encoderSource,
const std::string &decoderSource,
explicit SpeechToText(const std::string &modelName,
const std::string &modelSource,
const std::string &tokenizerSource,
std::shared_ptr<react::CallInvoker> callInvoker);

Expand All @@ -25,9 +27,9 @@ class SpeechToText {
"Registered non-void function")]] std::shared_ptr<OwningArrayBuffer>
decode(std::span<uint64_t> tokens, std::span<float> encoderOutput) const;
[[nodiscard("Registered non-void function")]]
types::TranscriptionResult transcribe(std::span<float> waveform,
std::string languageOption,
bool verbose) const;
TranscriptionResult transcribe(std::span<float> waveform,
std::string languageOption,
bool verbose) const;

[[nodiscard("Registered non-void function")]]
std::vector<char> transcribeStringOnly(std::span<float> waveform,
Expand All @@ -42,20 +44,17 @@ class SpeechToText {
void streamInsert(std::span<float> waveform);

private:
std::shared_ptr<react::CallInvoker> callInvoker;
std::unique_ptr<BaseModel> encoder;
std::unique_ptr<BaseModel> decoder;
std::unique_ptr<TokenizerModule> tokenizer;
std::unique_ptr<asr::ASR> asr;
void resetStreamState();

// Stream
std::unique_ptr<stream::OnlineASRProcessor> processor;
bool isStreaming;
bool readyToProcess;
std::shared_ptr<react::CallInvoker> callInvoker_;

constexpr static int32_t kMinAudioSamples = 16000; // 1 second
// ASR-like module (both static transcription & streaming)
std::unique_ptr<schema::ASR> transcriber_ = nullptr;

void resetStreamState();
// Online ASR-like module (streaming only)
std::unique_ptr<schema::OnlineASR> streamer_ = nullptr;
bool isStreaming_ = false;
bool readyToProcess_ = false;
};

} // namespace models::speech_to_text
Expand Down
Loading