Skip to content

Commit

Permalink
Merge pull request #279 from bolna-ai/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
marmikcfc authored Jun 22, 2024
2 parents b16da0f + bdddb44 commit 53cd006
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 22 deletions.
1 change: 1 addition & 0 deletions bolna/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class FourieConfig(BaseModel):

class DeepgramConfig(BaseModel):
voice: str
model: str


class MeloConfig(BaseModel):
Expand Down
82 changes: 63 additions & 19 deletions bolna/synthesizer/deepgram_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import os
from dotenv import load_dotenv
from bolna.helpers.logger_config import configure_logger
from bolna.helpers.utils import create_ws_data_packet
from bolna.helpers.utils import convert_audio_to_wav, create_ws_data_packet
from bolna.memory.cache.inmemory_scalar_cache import InmemoryScalarCache
from .base_synthesizer import BaseSynthesizer

logger = configure_logger(__name__)
Expand All @@ -12,24 +13,39 @@


class DeepgramSynthesizer(BaseSynthesizer):
def __init__(self, voice, audio_format="pcm", sampling_rate="8000", stream=False, buffer_size=400,
**kwargs):
def __init__(self, voice, audio_format="pcm", sampling_rate="8000", stream=False, buffer_size=400, caching = True,
model = "aura-zeus-en", **kwargs):
super().__init__(stream, buffer_size)
self.format = "linear16" if audio_format == "pcm" else audio_format
self.format = "linear16" if audio_format in ["pcm", 'wav'] else audio_format
self.voice = voice
self.sample_rate = str(sampling_rate)
self.model = model
self.first_chunk_generated = False
self.api_key = kwargs.get("transcriber_key", os.getenv('DEEPGRAM_AUTH_TOKEN'))

self.synthesized_characters = 0
self.caching = caching
if caching:
self.cache = InmemoryScalarCache()


def get_synthesized_characters(self):
return self.synthesized_characters

def get_engine(self):
return self.model

async def __generate_http(self, text):
headers = {
"Authorization": "Token {}".format(self.api_key),
"Content-Type": "application/json"
}
url = DEEPGRAM_TTS_URL + "?encoding={}&container=none&sample_rate={}&model={}".format(
self.format, self.sample_rate, self.voice
self.format, self.sample_rate, self.model
)

logger.info(f"Sending deepgram request {url}")

payload = {
"text": text
}
Expand All @@ -39,32 +55,60 @@ async def __generate_http(self, text):
async with session.post(url, headers=headers, json=payload) as response:
if response.status == 200:
chunk = await response.read()
yield chunk
return chunk
else:
logger.info("Payload was null")

def supports_websocket(self):
return False

async def open_connection(self):
pass
pass

async def synthesize(self, text):
# This is used for one off synthesis mainly for use cases like voice lab and IVR
try:
audio = await self.__generate_http(text)
if self.format == "mp3":
audio = convert_audio_to_wav(audio, source_format="mp3")
return audio
except Exception as e:
logger.error(f"Could not synthesize {e}")

async def generate(self):
while True:
logger.info("Generating TTS response")
message = await self.internal_queue.get()
logger.info(f"Generating TTS response for message: {message}")

meta_info, text = message.get("meta_info"), message.get("data")
async for message in self.__generate_http(text):
if not self.first_chunk_generated:
meta_info["is_first_chunk"] = True
self.first_chunk_generated = True
if self.caching:
logger.info(f"Caching is on")
if self.cache.get(text):
logger.info(f"Cache hit and hence returning quickly {text}")
message = self.cache.get(text)
else:
meta_info["is_first_chunk"] = False
if "end_of_llm_stream" in meta_info and meta_info["end_of_llm_stream"]:
meta_info["end_of_synthesizer_stream"] = True
self.first_chunk_generated = False
logger.info(f"Not a cache hit {list(self.cache.data_dict)}")
self.synthesized_characters += len(text)
message = await self.__generate_http(text)
self.cache.set(text, message)
else:
logger.info(f"No caching present")
self.synthesized_characters += len(text)
message = await self.__generate_http(text)

meta_info['text'] = text
meta_info['format'] = self.format
yield create_ws_data_packet(message, meta_info)
if self.format == "mp3":
message = convert_audio_to_wav(message, source_format="mp3")
if not self.first_chunk_generated:
meta_info["is_first_chunk"] = True
self.first_chunk_generated = True
else:
meta_info["is_first_chunk"] = False
if "end_of_llm_stream" in meta_info and meta_info["end_of_llm_stream"]:
meta_info["end_of_synthesizer_stream"] = True
self.first_chunk_generated = False
meta_info['text'] = text
meta_info['format'] = 'wav'
yield create_ws_data_packet(message, meta_info)

async def push(self, message):
logger.info("Pushed message to internal queue")
Expand Down
6 changes: 4 additions & 2 deletions bolna/synthesizer/melo_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ async def synthesize(self, text):

async def open_connection(self):
pass


def supports_websocket(self):
return False
async def generate(self):
while True:
message = await self.internal_queue.get()
Expand All @@ -100,7 +102,7 @@ async def generate(self):
logger.info(f"Not a cache hit {list(self.cache.data_dict)}")
self.synthesized_characters += len(text)
audio = await self.__generate_http(text)
self.cache.set(text, message)
self.cache.set(text, audio)
else:
logger.info(f"No caching present")
self.synthesized_characters += len(text)
Expand Down
10 changes: 9 additions & 1 deletion bolna/synthesizer/styletts_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ async def __generate_http(self, text):
chunk = base64.b64decode(res_json["audio"])
return chunk


def supports_websocket(self):
return False

def get_synthesized_characters(self):
return self.synthesized_characters


async def open_connection(self):
pass

Expand Down Expand Up @@ -90,7 +98,7 @@ async def generate(self):
logger.info(f"Not a cache hit {list(self.cache.data_dict)}")
self.synthesized_characters += len(text)
audio = await self.__generate_http(text)
self.cache.set(text, message)
self.cache.set(text, audio)
else:
logger.info(f"No caching present")
self.synthesized_characters += len(text)
Expand Down

0 comments on commit 53cd006

Please sign in to comment.