Skip to content

Commit

Permalink
Merge pull request #679 from ShiromiyaG/embedder-info
Browse files Browse the repository at this point in the history
Add embedder info
  • Loading branch information
blaisewf authored Sep 7, 2024
2 parents e9354aa + 4ed64e4 commit 1e25cb1
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 2 deletions.
13 changes: 12 additions & 1 deletion rvc/train/extract/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import concurrent.futures
import multiprocessing as mp
import json

# Zluda
if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):
Expand Down Expand Up @@ -258,7 +259,17 @@ def run_embedding_extraction(
os.makedirs(os.path.join(exp_dir, "f0"), exist_ok=True)
os.makedirs(os.path.join(exp_dir, "f0_voiced"), exist_ok=True)
os.makedirs(os.path.join(exp_dir, version + "_extracted"), exist_ok=True)

# write to model_info.json
chosen_embedder_model = (
embedder_model_custom if embedder_model_custom else embedder_model
)
with open(os.path.join(exp_dir, "model_info.json"), "w") as f:
json.dump(
{
"embedder_model": chosen_embedder_model,
},
f,
)
files = []
for file in glob.glob(os.path.join(wav_path, "*.wav")):
file_name = os.path.basename(file)
Expand Down
2 changes: 2 additions & 0 deletions rvc/train/process/extract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def extract_model(
with open(os.path.join(model_dir_path, "model_info.json"), "r") as f:
data = json.load(f)
dataset_lenght = data.get("total_dataset_duration", None)
embedder_model = data.get("embedder_model", None)
else:
dataset_lenght = None

Expand Down Expand Up @@ -101,6 +102,7 @@ def extract_model(
opt["dataset_lenght"] = dataset_lenght
opt["model_name"] = name
opt["author"] = model_author
opt["embedder_model"] = embedder_model

torch.save(opt, os.path.join(model_dir_path, pth_file))

Expand Down
2 changes: 2 additions & 0 deletions rvc/train/process/model_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def model_information(path):
model_hash = model_data.get("model_hash", None)
overtrain_info = model_data.get("overtrain_info", "None")
model_author = model_data.get("author", "None")
embedder_model = model_data.get("embedder_model", "None")

pitch_guidance = "True" if f0 == 1 else "False"

Expand All @@ -45,4 +46,5 @@ def model_information(path):
f"Creation Date: {creation_date_str}\n"
f"Hash (ID): {model_hash}\n"
f"Overtrain Info: {overtrain_info}"
f"Embedder Model: {embedder_model}"
)
2 changes: 1 addition & 1 deletion tabs/extra/model_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def model_information_tab():
label=i18n("Output Information"),
info=i18n("The output information will be displayed here."),
value="",
max_lines=11,
max_lines=12,
interactive=False,
)
model_information_button = gr.Button(i18n("See Model Information"))
Expand Down

0 comments on commit 1e25cb1

Please sign in to comment.