Merge pull request #712 from AznamirWoW/refactor1
refactoring for pipeline and training
blaisewf authored Sep 19, 2024
2 parents 1541038 + 105f626 commit 17180b2
Showing 5 changed files with 284 additions and 857 deletions.
core.py — 10 changes: 0 additions & 10 deletions

@@ -517,7 +517,6 @@ def run_train_script(
     index_algorithm: str = "Auto",
     cache_data_in_gpu: bool = False,
     custom_pretrained: bool = False,
-    use_cpu: bool = False,
     g_pretrained_path: str = None,
     d_pretrained_path: str = None,
 ):
@@ -561,7 +560,6 @@ def run_train_script(
                 overtraining_detector,
                 overtraining_threshold,
                 sync_graph,
-                use_cpu,
             ],
         ),
     ]
@@ -1473,13 +1471,6 @@ def parse_arguments():
         default="Auto",
         required=False,
     )
-    train_parser.add_argument(
-        "--use_cpu",
-        type=lambda x: bool(strtobool(x)),
-        choices=[True, False],
-        help="Force the use of CPU for training.",
-        default=False,
-    )

     # Parser for 'index' mode
     index_parser = subparsers.add_parser(
@@ -1784,7 +1775,6 @@ def main():
         sync_graph=args.sync_graph,
         index_algorithm=args.index_algorithm,
         cache_data_in_gpu=args.cache_data_in_gpu,
-        use_cpu=args.use_cpu,
         g_pretrained_path=args.g_pretrained_path,
         d_pretrained_path=args.d_pretrained_path,
     )
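All ten deletions remove the `--use_cpu` training flag end to end: the function parameter, the argument list passed to the training step, the CLI option, and the call site. For reference, the deleted argument used argparse's string-to-bool idiom; below is a minimal, self-contained sketch of that pattern. The parser scaffolding is illustrative (in core.py this was a "train" subparser), only the argument definition comes from the diff:

import argparse
from distutils.util import strtobool

# Hypothetical standalone parser, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_cpu",
    type=lambda x: bool(strtobool(x)),  # accepts "true"/"false", "yes"/"no", "1"/"0"
    choices=[True, False],
    help="Force the use of CPU for training.",
    default=False,
)
args = parser.parse_args(["--use_cpu", "True"])
assert args.use_cpu is True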
rvc/infer/pipeline.py — 123 changes: 43 additions & 80 deletions

@@ -259,7 +259,9 @@ def get_f0_hybrid(
         for method in methods:
             f0 = None
             if method == "crepe":
-                f0 = self.get_f0_crepe(x, f0_min, f0_max, p_len, int(hop_length))
+                f0 = self.get_f0_crepe_computation(
+                    x, f0_min, f0_max, p_len, int(hop_length)
+                )
             elif method == "rmvpe":
                 self.model_rmvpe = RMVPE0Predictor(
                     os.path.join("rvc", "models", "predictors", "rmvpe.pt"),
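This hunk only switches the crepe branch to the renamed get_f0_crepe_computation helper; the step that merges the per-method f0 curves lies outside the hunk. Purely as a hedged illustration of what a hybrid estimator typically does (median stacking is an assumption here, not code from this repository):

import numpy as np

def combine_f0(curves: list) -> np.ndarray:
    # Stack per-method f0 curves of equal length: shape (n_methods, p_len).
    stacked = np.stack(curves).astype(np.float64)
    stacked[stacked == 0] = np.nan      # treat unvoiced frames as missing
    f0 = np.nanmedian(stacked, axis=0)  # robust to a single method's outliers
    return np.nan_to_num(f0)            # frames unvoiced in every method -> 0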
@@ -412,82 +414,44 @@ def voice_conversion(
             version: Model version ("v1" or "v2").
             protect: Protection level for preserving the original pitch.
         """
-        feats = torch.from_numpy(audio0)
-        if self.is_half:
-            feats = feats.half()
-        else:
-            feats = feats.float()
-        if feats.dim() == 2:
-            feats = feats.mean(-1)
-        assert feats.dim() == 1, feats.dim()
-        feats = feats.view(1, -1)
-        padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
-
-        with torch.no_grad():
-            feats = model(feats.to(self.device))["last_hidden_state"]
-            feats = (
-                model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats
-            )
-        if protect < 0.5 and pitch != None and pitchf != None:
-            feats0 = feats.clone()
-        if (
-            isinstance(index, type(None)) == False
-            and isinstance(big_npy, type(None)) == False
-            and index_rate != 0
-        ):
-            npy = feats[0].cpu().numpy()
-            if self.is_half:
-                npy = npy.astype("float32")
-
-            score, ix = index.search(npy, k=8)
-            weight = np.square(1 / score)
-            weight /= weight.sum(axis=1, keepdims=True)
-            npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
-
-            if self.is_half:
-                npy = npy.astype("float16")
-            feats = (
-                torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
-                + (1 - index_rate) * feats
-            )
-
-        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
-        if protect < 0.5 and pitch != None and pitchf != None:
-            feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
-                0, 2, 1
-            )
-        p_len = audio0.shape[0] // self.window
-        if feats.shape[1] < p_len:
-            p_len = feats.shape[1]
-        if pitch != None and pitchf != None:
-            pitch = pitch[:, :p_len]
-            pitchf = pitchf[:, :p_len]
-
-        if protect < 0.5 and pitch != None and pitchf != None:
-            pitchff = pitchf.clone()
-            pitchff[pitchf > 0] = 1
-            pitchff[pitchf < 1] = protect
-            pitchff = pitchff.unsqueeze(-1)
-            feats = feats * pitchff + feats0 * (1 - pitchff)
-            feats = feats.to(feats0.dtype)
-        p_len = torch.tensor([p_len], device=self.device).long()
-        with torch.no_grad():
-            if pitch != None and pitchf != None:
-                audio1 = (
-                    (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0])
-                    .data.cpu()
-                    .float()
-                    .numpy()
-                )
-            else:
-                audio1 = (
-                    (net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()
-                )
-        del feats, p_len, padding_mask
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        pitch_guidance = pitch != None and pitchf != None
+        # prepare source audio
+        feats = torch.from_numpy(audio0).half() if self.is_half else torch.from_numpy(audio0).float()
+        feats = feats.mean(-1) if feats.dim() == 2 else feats
+        assert feats.dim() == 1, feats.dim()
+        feats = feats.view(1, -1).to(self.device)
+        # extract features
+        feats = model(feats)["last_hidden_state"]
+        feats = model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats
+        # make a copy for pitch guidance and protection
+        feats0 = feats.clone() if pitch_guidance else None
+        if index:  # set by parent function, only true if index is available, loaded, and index rate > 0
+            feats = self._retrieve_speaker_embeddings(feats, index, big_npy, index_rate)
+        # feature upsampling
+        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
+        # adjust the length if the audio is short
+        p_len = min(audio0.shape[0] // self.window, feats.shape[1])
+        if pitch_guidance:
+            feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
+            pitch, pitchf = pitch[:, :p_len], pitchf[:, :p_len]
+            # Pitch protection blending
+            if protect < 0.5:
+                pitchff = pitchf.clone()
+                pitchff[pitchf > 0] = 1
+                pitchff[pitchf < 1] = protect
+                feats = feats * pitchff.unsqueeze(-1) + feats0 * (1 - pitchff.unsqueeze(-1))
+                feats = feats.to(feats0.dtype)
+        else:
+            pitch, pitchf = None, None
+        p_len = torch.tensor([p_len], device=self.device).long()
+        audio1 = ((net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0]).data.cpu().float().numpy())
+        # clean up
+        del feats, feats0, p_len
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         return audio1

+    def _retrieve_speaker_embeddings(self, feats, index, big_npy, index_rate):
+        npy = feats[0].cpu().numpy()
+        npy = npy.astype("float32") if self.is_half else npy
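The protection blend in the rewritten voice_conversion above is a per-frame linear interpolation: voiced frames (pitchf > 0) keep the index-blended features, while unvoiced frames fall back toward the pre-retrieval copy feats0 by the protect factor, which preserves consonants and breaths. A toy run with fabricated shapes and values (only the masking logic mirrors the diff):

import torch

protect = 0.3
pitchf = torch.tensor([[0.0, 110.0, 0.0, 220.0]])  # f0 per frame, 0 = unvoiced
feats = torch.ones(1, 4, 2)                        # index-blended features
feats0 = torch.zeros(1, 4, 2)                      # copy taken before retrieval

pitchff = pitchf.clone()
pitchff[pitchf > 0] = 1        # voiced frames: keep blended features
pitchff[pitchf < 1] = protect  # unvoiced frames: mostly restore feats0
mix = feats * pitchff.unsqueeze(-1) + feats0 * (1 - pitchff.unsqueeze(-1))
print(mix[0, :, 0])  # tensor([0.3000, 1.0000, 0.3000, 1.0000])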
@@ -496,12 +460,9 @@ def _retrieve_speaker_embeddings(self, feats, index, big_npy, index_rate):
         weight /= weight.sum(axis=1, keepdims=True)
         npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
         npy = npy.astype("float16") if self.is_half else npy
-        feats = (
-            torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
-            + (1 - index_rate) * feats
-        )
+        feats = torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats
         return feats

     def pipeline(
         self,
         model,
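The extracted _retrieve_speaker_embeddings helper performs per-frame k-nearest-neighbor retrieval against the training-set embeddings and blends the result back by index_rate. A self-contained sketch of the same math on a tiny random FAISS index (the dimensions, data, and index type are illustrative assumptions, not from the commit):

import faiss
import numpy as np

dim, index_rate = 768, 0.75
big_npy = np.random.rand(1000, dim).astype("float32")  # stand-in training embeddings
index = faiss.IndexFlatL2(dim)
index.add(big_npy)

npy = np.random.rand(200, dim).astype("float32")  # stand-in frame-level features
score, ix = index.search(npy, k=8)                # squared L2 distances + neighbor ids
weight = np.square(1 / score)                     # closer neighbors weigh more
weight /= weight.sum(axis=1, keepdims=True)
retrieved = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
blended = retrieved * index_rate + (1 - index_rate) * npy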
@@ -689,7 +650,9 @@ def pipeline(
         audio_max = np.abs(audio_opt).max() / 0.99
         if audio_max > 1:
             audio_opt /= audio_max
-        del pitch, pitchf, sid
+        if pitch_guidance:
+            del pitch, pitchf
+        del sid
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         return audio_opt
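The guard added here mirrors the pitch_guidance flag introduced in voice_conversion: the f0 tensors are deleted only when pitch guidance actually created them, while sid is always released. A hedged sketch of the pattern with illustrative values:

import torch

pitch_guidance = True  # illustrative; derived from the f0 settings in the real pipeline
sid = torch.zeros(1, dtype=torch.long)
pitch = torch.zeros(1, 100, dtype=torch.long) if pitch_guidance else None
pitchf = torch.zeros(1, 100) if pitch_guidance else None

# ... inference would run here ...

if pitch_guidance:
    del pitch, pitchf  # drop f0 tensors only if pitch guidance created them
del sid
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # release cached GPU memory between conversions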
(The remaining three changed files are not shown.)
