Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add style and template #4

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ ten_package("azure_tts") {
"include",

# The build flags for in-app building
"//ten_packages/system/nlohmann_json/include",
"//ten_packages/system/ten_runtime/include",
"//ten_packages/system/azure_speech_sdk/include/microsoft/c_api",
"//ten_packages/system/azure_speech_sdk/include/microsoft/cxx_api",

# The build flags for standalone building.
".ten/app/ten_packages/system/nlohmann_json/include",
".ten/app/ten_packages/system/ten_runtime/include",
".ten/app/ten_packages/system/azure_speech_sdk/include/microsoft/c_api",
".ten/app/ten_packages/system/azure_speech_sdk/include/microsoft/cxx_api",
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ TEN extension of [azure Text to speech service](https://learn.microsoft.com/en-u
| `azure_subscription_key` | `string` | `""` | Azure Speech service subscription key |
| `azure_subscription_region` | `string` | `""` | Azure Speech service subscription region |
| `azure_synthesis_voice_name` | `string` | `""` | e.g., `en-US-AdamMultilingualNeural`, check more available voices in [languages and voices support](https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts) |
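| `prosody` | `string` | `""` | Attributes inserted verbatim into the SSML `<prosody>` element, e.g. `rate="+10%"` |
| `language` | `string` | `""` | Value of the SSML `xml:lang` attribute, e.g. `en-US`. SSML is only generated when `language` is set together with at least one of `prosody`, `role` or `style`; otherwise plain text is synthesized |
| `role` | `string` | `""` | Value of the `role` attribute of the SSML `mstts:express-as` element |
| `style` | `string` | `""` | Value of the `style` attribute of the SSML `mstts:express-as` element |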

## Development

Expand Down
19 changes: 18 additions & 1 deletion manifest.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"type": "extension",
"name": "azure_tts",
"version": "0.6.0",
"version": "0.6.1",
"dependencies": [
{
"type": "system",
Expand All @@ -12,6 +12,11 @@
"type": "system",
"name": "azure_speech_sdk",
"version": "1.38.0"
},
{
"type": "system",
"name": "nlohmann_json",
"version": "=3.11.2"
}
],
"api": {
Expand All @@ -24,6 +29,18 @@
},
"azure_synthesis_voice_name": {
"type": "string"
},
"style": {
"type": "string"
},
"prosody": {
"type": "string"
},
"language": {
"type": "string"
},
"role": {
"type": "string"
}
},
"data_in": [
Expand Down
57 changes: 52 additions & 5 deletions src/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,37 @@
#include <cstdint>
#include <cstdio>
#include <memory>
#include <string>

#include "log.h"
#include "ten_runtime/binding/cpp/ten.h"
#include "ten_utils/macro/check.h"
#include "tts.h"
#include "tmpl.h"

namespace azure_tts_extension {

// Returns a copy of `input` with every escaped line-break / tab marker removed.
//
// Note that the patterns are the literal two-character sequences backslash+'n',
// backslash+'r' and backslash+'t' (as they appear in escaped text), not the
// actual control characters; real newlines, carriage returns and tabs are left
// untouched.
//
// The sequences are stripped in a fixed order (\n, then \r, then \t). Each pass
// searches again from the start of the string after every removal, so a
// sequence that only comes into existence because of an earlier removal in the
// same pass is removed as well.
std::string trimString(const std::string& input) {
  const char* const kEscapedMarkers[] = {"\\n", "\\r", "\\t"};

  std::string cleaned(input);
  for (const char* marker : kEscapedMarkers) {
    for (std::string::size_type at = cleaned.find(marker);
         at != std::string::npos;
         at = cleaned.find(marker)) {
      // Every marker is exactly two characters long.
      cleaned.erase(at, 2);
    }
  }
  return cleaned;
}

class azure_tts_extension_t : public ten::extension_t {
public:
explicit azure_tts_extension_t(const std::string &name) : extension_t(name) {}
Expand All @@ -34,14 +57,19 @@ class azure_tts_extension_t : public ten::extension_t {
// read properties
auto key = ten.get_property_string("azure_subscription_key");
auto region = ten.get_property_string("azure_subscription_region");
auto voice_name = ten.get_property_string("azure_synthesis_voice_name");
if (key.empty() || region.empty() || voice_name.empty()) {
voice_ = ten.get_property_string("azure_synthesis_voice_name");
if (key.empty() || region.empty() || voice_.empty()) {
AZURE_TTS_LOGE(
"azure_subscription_key, azure_subscription_region, azure_synthesis_voice_name should not be empty, start "
"failed");
return;
}

style_ = ten.get_property_string("style");
prosody_ = ten.get_property_string("prosody");
language_ = ten.get_property_string("language");
role_ = ten.get_property_string("role");

ten_proxy_ = std::unique_ptr<ten::ten_env_proxy_t>(ten::ten_env_proxy_t::create(ten));
TEN_ASSERT(ten_proxy_ != nullptr, "ten_proxy should not be nullptr");

Expand Down Expand Up @@ -83,7 +111,7 @@ class azure_tts_extension_t : public ten::extension_t {
azure_tts_ = std::make_unique<AzureTTS>(
key,
region,
voice_name,
voice_,
Microsoft::CognitiveServices::Speech::SpeechSynthesisOutputFormat::Raw16Khz16BitMonoPcm,
pcm_frame_size,
std::move(pcm_callback));
Expand Down Expand Up @@ -132,10 +160,22 @@ class azure_tts_extension_t : public ten::extension_t {
AZURE_TTS_LOGD("input text is empty, ignored");
return;
}
AZURE_TTS_LOGI("input text: [%s]", text.c_str());

text = trimString(text);
// push received text to tts queue for synthesis
azure_tts_->Push(text);
if (!language_.empty() && (!prosody_.empty() || !role_.empty() || !style_.empty())) {
MsttsTemplate tmpl;
auto ssml_text = tmpl.replace(json{{"role", role_},
{"voice", voice_},
{"lang", language_},
{"style", style_},
{"prosody", prosody_},
{"text", text}});
AZURE_TTS_LOGI("input ssml text: [%s]", ssml_text.c_str());
azure_tts_->Push(ssml_text, true);
} else {
azure_tts_->Push(text, false);
}
}

// on_stop will be called when the extension is stopping.
Expand All @@ -157,8 +197,15 @@ class azure_tts_extension_t : public ten::extension_t {

std::unique_ptr<AzureTTS> azure_tts_;

std::string voice_;
std::string prosody_;
std::string language_;
std::string role_;
std::string style_;

const std::string kCmdNameFlush{"flush"};
const std::string kDataFieldText{"text"};
const std::string kDataFieldSSML{"ssml"};
};

TEN_CPP_REGISTER_ADDON_AS_EXTENSION(azure_tts, azure_tts_extension_t);
Expand Down
70 changes: 70 additions & 0 deletions src/tmpl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include <nlohmann/json.hpp>
#include <string>
#include <iostream>
#include <regex>

namespace azure_tts_extension {

using json = nlohmann::json;

// Renders an SSML document for Azure Speech from a fixed template and a JSON
// object of string parameters ("lang", "voice", "style", "role", "prosody",
// "text").
//
// NOTE(review): parameter values are inserted verbatim, without XML escaping.
// Text containing characters such as '&' or '<' will end up unescaped in the
// generated SSML; escape upstream if the input is not trusted.
class MsttsTemplate {
 public:
  MsttsTemplate() = default;
  ~MsttsTemplate() = default;

  // Builds the SSML string for `params`.
  //
  // The template that wraps the text in a <prosody> element is used only when
  // `params` has a non-empty "prosody" entry; otherwise the plain template is
  // used. Every placeholder whose parameter is missing or empty is removed
  // from the output.
  //
  // Every entry that is read must hold a JSON string; nlohmann's
  // get<std::string>() throws otherwise.
  std::string replace(const json& params) const {
    const bool hasProsody =
        params.contains("prosody") && !params["prosody"].get<std::string>().empty();
    std::string result = hasProsody ? templateProsodyStr_ : templateStr_;

    replacePlaceholder(result, "lang", params, "xml:lang=\"{lang}\"");
    replacePlaceholder(result, "voice", params, "name=\"{voice}\"");
    replacePlaceholder(result, "style", params, "style=\"{style}\"");
    replacePlaceholder(result, "role", params, "role=\"{role}\"");
    replacePlaceholder(result, "prosody", params, "{prosody}");
    replacePlaceholder(result, "text", params, "{text}");
    return result;
  }

 private:
  // Template used when a prosody attribute string is supplied.
  std::string templateProsodyStr_ = R"(
<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xmlns:mstts="https://www.w3.org/2001/mstts" {lang}>
<voice {voice}>
<mstts:express-as {role} {style} >
<prosody {prosody}>
{text}
</prosody>
</mstts:express-as>
</voice>
</speak>
)";

  // Template used when no prosody is supplied.
  std::string templateStr_ = R"(
<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xmlns:mstts="https://www.w3.org/2001/mstts" {lang}>
<voice {voice}>
<mstts:express-as {role} {style} >
{text}
</mstts:express-as>
</voice>
</speak>
)";

  // Replaces every occurrence of `token` in `subject` with `replacement`,
  // treating both as literal text. The search resumes after the inserted
  // replacement, so a replacement that itself contains `token` cannot cause an
  // endless loop.
  static void replaceAll(std::string& subject, const std::string& token, const std::string& replacement) {
    if (token.empty()) {
      return;
    }
    for (std::string::size_type pos = subject.find(token);
         pos != std::string::npos;
         pos = subject.find(token, pos + replacement.size())) {
      subject.replace(pos, token.size(), replacement);
    }
  }

  // Substitutes the "{placeholder}" token in `result`.
  //
  // `templateStr` is the snippet that should appear in place of the token
  // (e.g. `role="{role}"`); the token inside the snippet is filled with the
  // parameter value first. When the parameter is missing or empty, the token
  // is simply removed from `result`.
  //
  // Substitution is done with literal string replacement rather than
  // std::regex_replace: regex_replace interprets '$' sequences in the
  // replacement ("$5", "$&", "$$", ...) as format specifiers, which silently
  // altered any value containing a '$'.
  void replacePlaceholder(std::string& result, const std::string& placeholder, const json& params, const std::string& templateStr) const {
    const std::string token = "{" + placeholder + "}";

    std::string value;
    if (params.contains(placeholder)) {
      value = params[placeholder].get<std::string>();
    }

    std::string replacement;
    if (!value.empty()) {
      replacement = templateStr;
      replaceAll(replacement, token, value);
    }
    replaceAll(result, token, replacement);
  }
};

} // namespace azure_tts_extension
68 changes: 45 additions & 23 deletions src/tts.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#include "tts.h"

#include <sys/types.h>
Expand Down Expand Up @@ -44,7 +43,7 @@ bool AzureTTS::Start() {
tasks_.pop();
}

SpeechText(task->text, task->ts);
SpeechText(task->text, task->ts, task->ssml);
}

AZURE_TTS_LOGI("tts_thread stopped");
Expand All @@ -67,16 +66,17 @@ bool AzureTTS::Stop() {
return true;
}

void AzureTTS::Push(const std::string& text) noexcept {
void AzureTTS::Push(const std::string& text, bool ssml) noexcept {
auto ts = time_since_epoch_in_us();

{
std::unique_lock<std::mutex> lock(tasks_mutex_);
tasks_.emplace(std::make_unique<Task>(text, ts));
tasks_.emplace(std::make_unique<Task>(text, ts, ssml));
tasks_cv_.notify_one();

AZURE_TTS_LOGD("task pushed for text: [%s], text_recv_ts: %" PRId64 ", queue size %d",
AZURE_TTS_LOGD("task pushed for text: [%s], ssml: %d, text_recv_ts: %" PRId64 ", queue size %d",
text.c_str(),
ssml,
ts,
int(tasks_.size()));
}
Expand All @@ -94,9 +94,9 @@ void AzureTTS::Flush() noexcept {
}
}

void AzureTTS::SpeechText(const std::string& text, int64_t text_recv_ts) {
void AzureTTS::SpeechText(const std::string& text, int64_t text_recv_ts, bool ssml) {
auto start_time = time_since_epoch_in_us();
AZURE_TTS_LOGD("task starting for text: [%s], text_recv_ts: %" PRId64, text.c_str(), text_recv_ts);
AZURE_TTS_LOGD("task starting for text: [%s], ssml: %d text_recv_ts: %" PRId64, text.c_str(), ssml, text_recv_ts);

if (text_recv_ts < outdate_ts_.load()) {
AZURE_TTS_LOGI("task discard for text: [%s], text_recv_ts: %" PRId64 ", outdate_ts: %" PRId64,
Expand All @@ -108,24 +108,46 @@ void AzureTTS::SpeechText(const std::string& text, int64_t text_recv_ts) {

using namespace Microsoft::CognitiveServices;

std::shared_ptr<Speech::SpeechSynthesisResult> result;
// async mode
auto result = speech_synthesizer_->StartSpeakingTextAsync(text).get();
if (result->Reason == Speech::ResultReason::Canceled) {
auto cancellation = Speech::SpeechSynthesisCancellationDetails::FromResult(result);
AZURE_TTS_LOGW("task canceled for text: [%s], text_recv_ts: %" PRId64 ", reason: %d",
text.c_str(),
text_recv_ts,
(int)cancellation->Reason);

if (cancellation->Reason == Speech::CancellationReason::Error) {
AZURE_TTS_LOGW("task canceled on error for text: [%s], text_recv_ts: %" PRId64
", errorcode: %d, details: %s, did you update the subscription info?",
text.c_str(),
text_recv_ts,
(int)cancellation->ErrorCode,
cancellation->ErrorDetails.c_str());
if (ssml) {
result = speech_synthesizer_->StartSpeakingSsmlAsync(text).get();
if (result->Reason == Speech::ResultReason::Canceled) {
auto cancellation = Speech::SpeechSynthesisCancellationDetails::FromResult(result);
AZURE_TTS_LOGW("task canceled for ssml: [%s], text_recv_ts: %" PRId64 ", reason: %d",
text.c_str(),
text_recv_ts,
(int)cancellation->Reason);

if (cancellation->Reason == Speech::CancellationReason::Error) {
AZURE_TTS_LOGW("task canceled on error for ssml: [%s], text_recv_ts: %" PRId64
", errorcode: %d, details: %s, did you update the subscription info?",
text.c_str(),
text_recv_ts,
(int)cancellation->ErrorCode,
cancellation->ErrorDetails.c_str());
}
return;
}
} else {
result = speech_synthesizer_->StartSpeakingTextAsync(text).get();
if (result->Reason == Speech::ResultReason::Canceled) {
auto cancellation = Speech::SpeechSynthesisCancellationDetails::FromResult(result);
AZURE_TTS_LOGW("task canceled for text: [%s], text_recv_ts: %" PRId64 ", reason: %d",
text.c_str(),
text_recv_ts,
(int)cancellation->Reason);

if (cancellation->Reason == Speech::CancellationReason::Error) {
AZURE_TTS_LOGW("task canceled on error for text: [%s], text_recv_ts: %" PRId64
", errorcode: %d, details: %s, did you update the subscription info?",
text.c_str(),
text_recv_ts,
(int)cancellation->ErrorCode,
cancellation->ErrorDetails.c_str());
}
return;
}
return;
}

auto audioDataStream = Speech::AudioDataStream::FromResult(result);
Expand Down
8 changes: 4 additions & 4 deletions src/tts.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#include <speechapi_cxx.h>

#include <atomic>
Expand Down Expand Up @@ -40,10 +39,10 @@ class AzureTTS {

void Flush() noexcept;

void Push(const std::string &text) noexcept;
void Push(const std::string &text, bool ssml) noexcept;

private:
void SpeechText(const std::string &text, int64_t text_recv_ts);
void SpeechText(const std::string &text, int64_t text_recv_ts, bool ssml);

int64_t time_since_epoch_in_us() const;

Expand All @@ -59,10 +58,11 @@ class AzureTTS {
std::atomic_int64_t outdate_ts_{0}; // for flushing

struct Task {
Task(const std::string &t, int64_t ts) : ts(ts), text(t) {}
Task(const std::string &t, int64_t ts, bool ssml) : ts(ts), text(t), ssml(ssml) {}

int64_t ts{0};
std::string text;
bool ssml;
};

std::queue<std::unique_ptr<Task>> tasks_;
Expand Down