Skip to content

Commit

Permalink
Merge pull request #641 from ShiromiyaG/more-model-info
Browse files Browse the repository at this point in the history
Add more information to the model
  • Loading branch information
blaisewf authored Aug 27, 2024
2 parents 9d39f54 + d890d66 commit d13bd26
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 7 deletions.
5 changes: 4 additions & 1 deletion assets/i18n/languages/en_US.json
Original file line number Diff line number Diff line change
Expand Up @@ -301,5 +301,8 @@
"Folder Name": "Folder Name",
"Upload .bin": "Upload .bin",
"Upload .json": "Upload .json",
"Move files to custom embedder folder": "Move files to custom embedder folder"
"Move files to custom embedder folder": "Move files to custom embedder folder",
"model information": "model information",
"Model Creator": "Model Creator",
"Name of the model creator. (Default: Unknown)": "Name of the model creator. (Default: Unknown)"
}
26 changes: 26 additions & 0 deletions core.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,21 @@ def run_tts_script(
f0_file=f0_file,
embedder_model=embedder_model,
embedder_model_custom=embedder_model_custom,
formant_shifting=None,
formant_qfrency=None,
formant_timbre=None,
post_process=None,
reverb=None,
pitch_shift=None,
limiter=None,
gain=None,
distortion=None,
chorus=None,
bitcrush=None,
clipping=None,
compressor=None,
delay=None,
sliders=None,
)

return f"Text {tts_text} synthesized successfully.", output_rvc_path.replace(
Expand Down Expand Up @@ -443,6 +458,7 @@ def run_train_script(
custom_pretrained: bool = False,
g_pretrained_path: str = None,
d_pretrained_path: str = None,
model_creator: str = None,
):

if pretrained == True:
Expand Down Expand Up @@ -484,6 +500,7 @@ def run_train_script(
overtraining_detector,
overtraining_threshold,
sync_graph,
model_creator,
],
),
]
Expand Down Expand Up @@ -526,6 +543,7 @@ def run_model_extract_script(
# Model information
def run_model_information_script(pth_path: str):
    """Print and return a human-readable summary of a saved model checkpoint.

    Args:
        pth_path: Path to the .pth checkpoint file to inspect.

    Returns:
        str: The formatted model information produced by ``model_information``.
    """
    # Call model_information once and reuse the result; the original called it
    # twice, loading and parsing the checkpoint file a second time just to
    # return the same string.
    info = model_information(pth_path)
    print(info)
    return info


# Model blender
Expand Down Expand Up @@ -1351,6 +1369,13 @@ def parse_arguments():
help="Enable graph synchronization for distributed training.",
default=False,
)
train_parser.add_argument(
"--model_creator",
type=str,
help="Model creator name.",
default=None,
required=False,
)
train_parser.add_argument(
"--cache_data_in_gpu",
type=lambda x: bool(strtobool(x)),
Expand Down Expand Up @@ -1655,6 +1680,7 @@ def main():
pretrained=args.pretrained,
custom_pretrained=args.custom_pretrained,
sync_graph=args.sync_graph,
model_creator=args.model_creator,
index_algorithm=args.index_algorithm,
cache_data_in_gpu=args.cache_data_in_gpu,
g_pretrained_path=args.g_pretrained_path,
Expand Down
21 changes: 20 additions & 1 deletion rvc/train/process/extract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,20 @@ def replace_keys_in_dict(d, old_key_part, new_key_part):
return updated_dict


def extract_model(ckpt, sr, pitch_guidance, name, model_dir, epoch, step, version, hps):
def extract_model(
ckpt,
sr,
pitch_guidance,
name,
model_dir,
epoch,
step,
version,
hps,
model_creator,
overtrain_info,
dataset_lenght,
):
try:
print(f"Saved model '{model_dir}' (epoch {epoch} and step {step})")

Expand Down Expand Up @@ -70,6 +83,12 @@ def extract_model(ckpt, sr, pitch_guidance, name, model_dir, epoch, step, versio
hash_input = f"{str(ckpt)} {epoch} {step} {datetime.datetime.now().isoformat()}"
model_hash = hashlib.sha256(hash_input.encode()).hexdigest()
opt["model_hash"] = model_hash
opt["model_name"] = name
if model_creator is None:
model_creator = "Unknown"
opt["author"] = model_creator
opt["overtrain_info"] = overtrain_info
opt["dataset_lenght"] = dataset_lenght

torch.save(opt, os.path.join(model_dir_path, pth_file))

Expand Down
8 changes: 8 additions & 0 deletions rvc/train/process/model_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,32 @@ def model_information(path):

print(f"Loaded model from {path}")

model_name = model_data.get("model_name", "None")
model_creator = model_data.get("author", "None")
epochs = model_data.get("epoch", "None")
steps = model_data.get("step", "None")
sr = model_data.get("sr", "None")
f0 = model_data.get("f0", "None")
dataset_lenght = model_data.get("dataset_lenght", "None")
version = model_data.get("version", "None")
creation_date = model_data.get("creation_date", "None")
model_hash = model_data.get("model_hash", None)
overtrain_info = model_data.get("overtrain_info", "None")

pitch_guidance = "True" if f0 == 1 else "False"

creation_date_str = prettify_date(creation_date) if creation_date else "None"

return (
f"Model Name: {model_name}\n"
f"Model Creator: {model_creator}\n"
f"Epochs: {epochs}\n"
f"Steps: {steps}\n"
f"RVC Version: {version}\n"
f"Sampling Rate: {sr}\n"
f"Pitch Guidance: {pitch_guidance}\n"
f"Dataset Length: {dataset_lenght}\n"
f"Creation Date: {creation_date_str}\n"
f"Hash (ID): {model_hash}"
f"Overtrain Info: {overtrain_info}"
)
36 changes: 36 additions & 0 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
import torch.multiprocessing as mp
from pydub import AudioSegment

now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))
Expand Down Expand Up @@ -72,10 +73,12 @@
overtraining_detector = strtobool(sys.argv[14])
overtraining_threshold = int(sys.argv[15])
sync_graph = strtobool(sys.argv[16])
model_creator = sys.argv[17]

current_dir = os.getcwd()
experiment_dir = os.path.join(current_dir, "logs", model_name)
config_save_path = os.path.join(experiment_dir, "config.json")
dataset_path = os.path.join(experiment_dir, "sliced_audios")

with open(config_save_path, "r") as f:
config = json.load(f)
Expand All @@ -97,6 +100,8 @@
smoothed_loss_disc_history = []
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
training_file_path = os.path.join(experiment_dir, "training_data.json")
dataset_duration = 0
overtrain_info = None

import logging

Expand Down Expand Up @@ -124,6 +129,24 @@ def record(self):
return f"time={current_time} | training_speed={elapsed_time_str}"


def ms_to_min_sec(ms):
    """Format a duration given in milliseconds as a "minutes:seconds" string.

    Args:
        ms: Duration in milliseconds (integer).

    Returns:
        str: The duration as "M:SS", with seconds zero-padded to two digits.
    """
    total_seconds = ms // 1000
    minutes, remainder = divmod(total_seconds, 60)
    return f"{minutes}:{remainder:02}"


def get_audio_durations(dataset_path):
    """Collect the formatted duration of every .wav file in a dataset folder.

    Args:
        dataset_path: Directory containing the sliced training audio files.

    Returns:
        list[str]: One "minutes:seconds" string per .wav file found.
    """
    durations = []
    for filename in os.listdir(dataset_path):
        # The training pipeline slices audio into .wav files; match the
        # extension case-insensitively so .WAV files are not silently skipped.
        if filename.lower().endswith(".wav"):
            audio_path = os.path.join(dataset_path, filename)
            audio = AudioSegment.from_wav(audio_path)
            # len(AudioSegment) is the clip length in milliseconds.
            durations.append(ms_to_min_sec(len(audio)))
    return durations


def main():
"""
Main function to start the training process.
Expand Down Expand Up @@ -203,6 +226,8 @@ def continue_overtrain_detector(training_file_path):
print("GPU not detected, reverting to CPU (not recommended)")
n_gpus = 1

dataset_duration = get_audio_durations(dataset_path)

if sync_graph == True:
print(
"Sync graph is now activated! With sync graph enabled, the model undergoes a single epoch of training. Once the graphs are synchronized, training proceeds for the previously specified number of epochs."
Expand Down Expand Up @@ -821,6 +846,9 @@ def train_and_evaluate(
step=global_step,
version=version,
hps=hps,
model_creator=model_creator,
overtrain_info=overtrain_info,
dataset_lenght=dataset_duration,
)

def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
Expand Down Expand Up @@ -917,6 +945,8 @@ def save_to_json(
consecutive_increases_gen += 1
else:
consecutive_increases_gen = 0

overtrain_info = f"Smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}"
# Save the data in the JSON file if the epoch is divisible by save_every_epoch
if epoch % save_every_epoch == 0:
save_to_json(
Expand Down Expand Up @@ -965,6 +995,9 @@ def save_to_json(
step=global_step,
version=version,
hps=hps,
model_creator=model_creator,
overtrain_info=overtrain_info,
dataset_lenght=dataset_duration,
)

# Print training progress
Expand Down Expand Up @@ -1025,6 +1058,9 @@ def save_to_json(
step=global_step,
version=version,
hps=hps,
model_creator=model_creator,
overtrain_info=overtrain_info,
dataset_lenght=dataset_duration,
)
sleep(1)
os._exit(2333333)
Expand Down
5 changes: 2 additions & 3 deletions tabs/extra/extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ def extra_tab():
"This section contains some extra utilities that often may be in experimental phases."
)
)
with gr.TabItem(i18n("Model information")):
processing_tab()

with gr.TabItem(i18n("F0 Curve")):
f0_extractor_tab()

with gr.TabItem(i18n("Processing")):
processing_tab()

with gr.TabItem(i18n("Audio Analyzer")):
analyzer_tab()
2 changes: 1 addition & 1 deletion tabs/extra/model_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def model_information_tab():
label=i18n("Output Information"),
info=i18n("The output information will be displayed here."),
value="",
max_lines=8,
max_lines=11,
interactive=False,
)
model_information_button = gr.Button(i18n("See Model Information"))
Expand Down
2 changes: 1 addition & 1 deletion tabs/extra/processing/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def processing_tab():
label=i18n("Output Information"),
info=i18n("The output information will be displayed here."),
value="",
max_lines=8,
max_lines=11,
)
model_view_button = gr.Button(i18n("View"), variant="primary")
model_view_button.click(
Expand Down
9 changes: 9 additions & 0 deletions tabs/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,14 @@ def train_tab():
interactive=True,
allow_custom_value=True,
)
model_creator = gr.Dropdown(
label=i18n("Model Creator"),
info=i18n("Name of the model creator. (Default: Unknown)"),
value=None,
interactive=True,
visible=True,
allow_custom_value=True,
)
sampling_rate = gr.Radio(
label=i18n("Sampling Rate"),
info=i18n("The sampling rate of the audio files."),
Expand Down Expand Up @@ -752,6 +760,7 @@ def train_tab():
custom_pretrained,
g_pretrained_path,
d_pretrained_path,
model_creator,
],
outputs=[train_output_info],
api_name="start_training",
Expand Down

0 comments on commit d13bd26

Please sign in to comment.