Skip to content

Commit

Permalink
some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Dec 8, 2023
1 parent c1018fa commit 90edbd3
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
2 changes: 2 additions & 0 deletions ovos_plugin_manager/templates/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def get_from_cache(self, sentence, audio_ext="wav", cache_config=None):
sentence_hash = hash_sentence(sentence)
phonemes = None
cache = self.get_cache(audio_ext, cache_config)
if sentence_hash not in cache:
raise FileNotFoundError(f"sentence is not cached, {sentence_hash}.{audio_ext}")
audio_file, pho_file = cache.cached_sentences[sentence_hash]
LOG.info(f"Found {audio_file.name} in TTS cache")
if pho_file:
Expand Down
73 changes: 71 additions & 2 deletions test/unittests/test_tts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import unittest
from unittest.mock import patch, Mock
from unittest.mock import MagicMock, patch
from unittest.mock import Mock

from ovos_plugin_manager.templates.tts import TTS, TTSContext
from ovos_plugin_manager.utils import PluginTypes, PluginConfigTypes
from ovos_plugin_manager.templates.tts import TTS


class TestTTSTemplate(unittest.TestCase):
Expand Down Expand Up @@ -262,3 +264,70 @@ def test_create(self, get_class):
get_class.assert_called_with(expected_config)
plugin_class.assert_called_with(lang=None, config=expected_config)
self.assertEqual(plugin, plugin_class())


class TestTTSContext(unittest.TestCase):
def test_tts_context_init(self):
session_mock = MagicMock()
tts_context = TTSContext(session=session_mock)
self.assertEqual(tts_context.session, session_mock)
self.assertEqual(tts_context.lang, session_mock.lang)

@patch("ovos_plugin_manager.templates.tts.TextToSpeechCache", autospec=True)
def test_tts_context_get_cache(self, cache_mock):
session_mock = MagicMock()
tts_context = TTSContext(session=session_mock)

cache_config = {
"min_free_percent": 75,
"persist_cache": False,
"persist_thresh": 1,
"preloaded_cache": "/fake/cache/path/{}/{}/{}".format(
session_mock.tts_preferences['plugin_id'],
session_mock.tts_preferences['config']['voice'],
session_mock.lang
)
}

result = tts_context.get_cache(cache_config=cache_config)

self.assertEqual(result, cache_mock.return_value)
self.assertEqual(result, tts_context._caches[tts_context.tts_id])


class TestTTSCache(unittest.TestCase):
def setUp(self):
self.tts_mock = TTS(lang="en-us", config={"some_config_key": "some_config_value"})
self.tts_mock.stopwatch = MagicMock()
self.tts_mock.queue = MagicMock()
self.tts_mock.playback = MagicMock()

@patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash")
@patch("ovos_plugin_manager.templates.tts.TTSContext", autospec=True)
def test_tts_synth(self, tts_context_mock, hash_sentence_mock):
tts_context_mock.get_cache.return_value = MagicMock()
tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path"

sentence = "Hello world!"
result = self.tts_mock.synth(sentence, tts_context_mock)

tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config)
tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash")
self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None))

@patch("ovos_plugin_manager.templates.tts.hash_sentence", return_value="fake_hash")
def test_tts_synth_cache_enabled(self, hash_sentence_mock):
tts_context_mock = MagicMock()
tts_context_mock.tts_id = "fake_tts_id"
tts_context_mock.get_cache.return_value = MagicMock()
tts_context_mock.get_cache.return_value.cached_sentences = {}
tts_context_mock.get_cache.return_value.define_audio_file.return_value.path = "fake_audio_path"
tts_context_mock._caches = {tts_context_mock.tts_id: tts_context_mock.get_cache.return_value}

sentence = "Hello world!"
result = self.tts_mock.synth(sentence, tts_context_mock)

tts_context_mock.get_cache.assert_called_once_with("wav", self.tts_mock.config)
tts_context_mock.get_cache.return_value.define_audio_file.assert_called_once_with("fake_hash")
self.assertEqual(result, (tts_context_mock.get_cache.return_value.define_audio_file.return_value, None))
self.assertIn("fake_hash", tts_context_mock.get_cache.return_value.cached_sentences)

0 comments on commit 90edbd3

Please sign in to comment.