From c053b38419d3d74591a1b42e8377947e0c14d113 Mon Sep 17 00:00:00 2001 From: Andreas Klos Date: Tue, 5 Sep 2023 14:02:36 +0200 Subject: [PATCH] ADD modelWrapper.py for facilitating mlflow logging; Credit: Blirona --- TTS/tts/models/modelWrapper.py | 38 ++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 TTS/tts/models/modelWrapper.py diff --git a/TTS/tts/models/modelWrapper.py b/TTS/tts/models/modelWrapper.py new file mode 100644 index 0000000000..46e74270ac --- /dev/null +++ b/TTS/tts/models/modelWrapper.py @@ -0,0 +1,38 @@ +from TTS.utils.synthesizer import Synthesizer +import mlflow + +try: + experiment = mlflow.get_experiment_by_name("tts: tts test") + if experiment is None: + experiment = mlflow.create_experiment("tts: tts test", artifact_location="s3://mlflow/TTS") + experiment = mlflow.get_experiment_by_name("tts: tts test") +except mlflow.exceptions.RestException as e: + experiment = mlflow.create_experiment("tts: tts test", artifact_location="s3://mlflow/TTS") + experiment = mlflow.get_experiment_by_name("tts: tts test") + + +class MyModel(mlflow.pyfunc.PythonModel): + ''' + def __init__(self, tts_path, tts_checkpoint): + import os + self.checkpoint = os.path.join(tts_path, tts_checkpoint) + self.config_path = os.path.join(tts_path, "config.json") + self.synthesizer = None + + def predict(self, context, model_input): + self.synthesizer = Synthesizer(self.checkpoint, self.config_path) + wav = self.synthesizer.tts(model_input) + self.synthesizer.save_wav(wav, 'output.wav') + + return wav + ''' + def __init__(self, synthesizer): + self.synthesizer = synthesizer + + + def predict(self, context, model_input): + synthesizer = self.synthesizer + wav = synthesizer.tts(model_input) + synthesizer.save_wav(wav, 'output.wav') + + return wav \ No newline at end of file