Skip to content

Commit

Permalink
Merge pull request #69 from jambonz/feat/aws_polly_rolearn
Browse files Browse the repository at this point in the history
support AWS Polly RoleArn credential
  • Loading branch information
davehorton authored May 2, 2024
2 parents ba61f20 + 79289a7 commit f0a1ab1
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 39 deletions.
4 changes: 2 additions & 2 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ module.exports = (opts, logger) => {
getTtsSize: require('./lib/get-tts-size').bind(null, client, logger),
purgeTtsCache: require('./lib/purge-tts-cache').bind(null, client, logger),
addFileToCache: require('./lib/add-file-to-cache').bind(null, client, logger),
synthAudio: require('./lib/synth-audio').bind(null, client, logger),
synthAudio: require('./lib/synth-audio').bind(null, client, createHash, retrieveHash, logger),
getNuanceAccessToken: require('./lib/get-nuance-access-token').bind(null, client, logger),
getIbmAccessToken: require('./lib/get-ibm-access-token').bind(null, client, logger),
getAwsAuthToken: require('./lib/get-aws-sts-token').bind(null, logger, createHash, retrieveHash),
getTtsVoices: require('./lib/get-tts-voices').bind(null, client, logger),
getTtsVoices: require('./lib/get-tts-voices').bind(null, client, createHash, retrieveHash, logger),
};
};
39 changes: 24 additions & 15 deletions lib/get-aws-sts-token.js
Original file line number Diff line number Diff line change
@@ -1,32 +1,41 @@
const { STSClient, GetSessionTokenCommand } = require('@aws-sdk/client-sts');
const { STSClient, GetSessionTokenCommand, AssumeRoleCommand } = require('@aws-sdk/client-sts');
const {makeAwsKey, noopLogger} = require('./utils');
const debug = require('debug')('jambonz:speech-utils');
const EXPIRY = 3600;

async function getAwsAuthToken(
logger,
createHash, retrieveHash,
awsAccessKeyId, awsSecretAccessKey, awsRegion) {
logger, createHash, retrieveHash,
awsAccessKeyId, awsSecretAccessKey, awsRegion, roleArn = null) {
logger = logger || noopLogger;
try {
const key = makeAwsKey(awsAccessKeyId);
const key = makeAwsKey(roleArn || awsAccessKeyId);
const obj = await retrieveHash(key);
if (obj) return {...obj, servedFromCache: true};

/* access token not found in cache, so generate it using STS */
const stsClient = new STSClient({
region: awsRegion,
credentials: {
accessKeyId: awsAccessKeyId,
secretAccessKey: awsSecretAccessKey,
}
});
const command = new GetSessionTokenCommand({DurationSeconds: EXPIRY});
const data = await stsClient.send(command);
let data;
if (roleArn) {
const stsClient = new STSClient({ region: awsRegion});
const roleToAssume = { RoleArn: roleArn, RoleSessionName: 'Jambonz_Speech', DurationSeconds: EXPIRY};
const command = new AssumeRoleCommand(roleToAssume);

data = await stsClient.send(command);
} else {
/* access token not found in cache, so generate it using STS */
const stsClient = new STSClient({
region: awsRegion,
credentials: {
accessKeyId: awsAccessKeyId,
secretAccessKey: awsSecretAccessKey,
}
});
const command = new GetSessionTokenCommand({DurationSeconds: EXPIRY});
data = await stsClient.send(command);
}

const credentials = {
accessKeyId: data.Credentials.AccessKeyId,
secretAccessKey: data.Credentials.SecretAccessKey,
sessionToken: data.Credentials.SessionToken,
securityToken: data.Credentials.SessionToken
};

Expand Down
33 changes: 22 additions & 11 deletions lib/get-tts-voices.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ const TextToSpeechV1 = require('ibm-watson/text-to-speech/v1');
const { IamAuthenticator } = require('ibm-watson/auth');
const ttsGoogle = require('@google-cloud/text-to-speech');
const { PollyClient, DescribeVoicesCommand } = require('@aws-sdk/client-polly');
const getAwsAuthToken = require('./get-aws-sts-token');

const getIbmVoices = async(client, logger, credentials) => {
const {tts_region, tts_api_key} = credentials;
Expand Down Expand Up @@ -87,16 +88,26 @@ const getGoogleVoices = async(_client, logger, credentials) => {
return await client.listVoices();
};

const getAwsVoices = async(_client, logger, credentials) => {
const getAwsVoices = async(_client, createHash, retrieveHash, logger, credentials) => {
try {
const {region, accessKeyId, secretAccessKey} = credentials;
const client = new PollyClient({
region,
credentials: {
accessKeyId,
secretAccessKey
}
});
const {region, accessKeyId, secretAccessKey, roleArn} = credentials;
let client = null;
if (accessKeyId && secretAccessKey) {
client = new PollyClient({
region,
credentials: {
accessKeyId,
secretAccessKey
}
});
} else if (roleArn) {
client = new PollyClient({
region,
credentials: await getAwsAuthToken(logger, createHash, retrieveHash, null, null, region, roleArn),
});
} else {
client = new PollyClient({region});
}
const command = new DescribeVoicesCommand({});
const response = await client.send(command);
return response;
Expand All @@ -122,7 +133,7 @@ const getAwsVoices = async(_client, logger, credentials) => {
* @returns object containing filepath to an mp3 file in the /tmp folder containing
* the synthesized audio, and a variable indicating whether it was served from cache
*/
async function getTtsVoices(client, logger, {vendor, credentials}) {
async function getTtsVoices(client, createHash, retrieveHash, logger, {vendor, credentials}) {
logger = logger || noopLogger;

assert.ok(['nuance', 'ibm', 'google', 'aws', 'polly'].includes(vendor),
Expand All @@ -137,7 +148,7 @@ async function getTtsVoices(client, logger, {vendor, credentials}) {
return getGoogleVoices(client, logger, credentials);
case 'aws':
case 'polly':
return getAwsVoices(client, logger, credentials);
return getAwsVoices(client, createHash, retrieveHash, logger, credentials);
default:
break;
}
Expand Down
36 changes: 25 additions & 11 deletions lib/synth-audio.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const debug = require('debug')('jambonz:realtimedb-helpers');
const EXPIRES = (process.env.JAMBONES_TTS_CACHE_DURATION_MINS || 4 * 60) * 60; // cache tts for 4 hours
const TMP_FOLDER = '/tmp';
const OpenAI = require('openai');
const getAwsAuthToken = require('./get-aws-sts-token');


const trimTrailingSilence = (buffer) => {
Expand Down Expand Up @@ -75,7 +76,7 @@ const trimTrailingSilence = (buffer) => {
* @returns object containing filepath to an mp3 file in the /tmp folder containing
* the synthesized audio, and a variable indicating whether it was served from cache
*/
async function synthAudio(client, logger, stats, { account_sid,
async function synthAudio(client, createHash, retrieveHash, logger, stats, { account_sid,
vendor, language, voice, gender, text, engine, salt, model, credentials, deploymentId,
disableTtsCache, renderForCaching, disableTtsStreaming, options
}) {
Expand Down Expand Up @@ -187,7 +188,8 @@ async function synthAudio(client, logger, stats, { account_sid,
case 'aws':
case 'polly':
vendorLabel = 'aws';
audioBuffer = await synthPolly(logger, {credentials, stats, language, voice, text, engine});
audioBuffer = await synthPolly(createHash, retrieveHash, logger,
{credentials, stats, language, voice, text, engine});
break;
case 'azure':
case 'microsoft':
Expand Down Expand Up @@ -263,16 +265,28 @@ async function synthAudio(client, logger, stats, { account_sid,
});
}

const synthPolly = async(logger, {credentials, stats, language, voice, engine, text}) => {
const synthPolly = async(createHash, retrieveHash, logger,
{credentials, stats, language, voice, engine, text}) => {
try {
const {region, accessKeyId, secretAccessKey} = credentials;
const polly = new PollyClient({
region,
credentials: {
accessKeyId,
secretAccessKey
}
});
const {region, accessKeyId, secretAccessKey, roleArn} = credentials;
let polly;
if (accessKeyId && secretAccessKey) {
polly = new PollyClient({
region,
credentials: {
accessKeyId,
secretAccessKey
}
});
} else if (roleArn) {
polly = new PollyClient({
region,
credentials: await getAwsAuthToken(logger, createHash, retrieveHash, null, null, region, roleArn),
});
} else {
// AWS RoleArn assigned to Instance profile
polly = new PollyClient({region});
}
const opts = {
Engine: engine,
OutputFormat: 'mp3',
Expand Down
27 changes: 27 additions & 0 deletions test/synth.js
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,33 @@ test('AWS speech synth tests', async(t) => {
client.quit();
});

test('AWS speech synth tests by RoleArn', async(t) => {
const fn = require('..');
const {synthAudio, client} = fn(opts, logger);

if (!process.env.AWS_ROLE_ARN || !process.env.AWS_REGION) {
t.pass('skipping AWS speech synth tests by RoleArn since AWS_ROLE_ARN or AWS_REGION not provided');
return t.end();
}
try {
let opts = await synthAudio(stats, {
vendor: 'aws',
credentials: {
roleArn: process.env.AWS_ROLE_ARN,
region: process.env.AWS_REGION,
},
language: 'en-US',
voice: 'Joey',
text: 'This is a test. This is only a test',
});
t.ok(!opts.servedFromCache, `successfully synthesized aws by roleArn audio to ${opts.filePath}`);
} catch (err) {
console.error(err);
t.end(err);
}
client.quit();
});

test('Azure speech synth tests', async(t) => {
const fn = require('..');
const {synthAudio, client} = fn(opts, logger);
Expand Down

0 comments on commit f0a1ab1

Please sign in to comment.