Skip to content

Commit

Permalink
ADD modelWrapper.py for facilitating mlflow logging; Credit: Blirona
Browse files Browse the repository at this point in the history
  • Loading branch information
FeU-aKlos committed Sep 5, 2023
1 parent cdc971f commit c053b38
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions TTS/tts/models/modelWrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from TTS.utils.synthesizer import Synthesizer
import mlflow

try:
experiment = mlflow.get_experiment_by_name("tts: tts test")
if experiment is None:
experiment = mlflow.create_experiment("tts: tts test", artifact_location="s3://mlflow/TTS")
experiment = mlflow.get_experiment_by_name("tts: tts test")
except mlflow.exceptions.RestException as e:
experiment = mlflow.create_experiment("tts: tts test", artifact_location="s3://mlflow/TTS")
experiment = mlflow.get_experiment_by_name("tts: tts test")


class MyModel(mlflow.pyfunc.PythonModel):
'''
def __init__(self, tts_path, tts_checkpoint):
import os
self.checkpoint = os.path.join(tts_path, tts_checkpoint)
self.config_path = os.path.join(tts_path, "config.json")
self.synthesizer = None
def predict(self, context, model_input):
self.synthesizer = Synthesizer(self.checkpoint, self.config_path)
wav = self.synthesizer.tts(model_input)
self.synthesizer.save_wav(wav, 'output.wav')
return wav
'''
def __init__(self, synthesizer):
self.synthesizer = synthesizer


def predict(self, context, model_input):
synthesizer = self.synthesizer
wav = synthesizer.tts(model_input)
synthesizer.save_wav(wav, 'output.wav')

return wav

0 comments on commit c053b38

Please sign in to comment.