Skip to content

Commit

Permalink
Merge pull request #605 from ShiromiyaG/disable-filters
Browse files Browse the repository at this point in the history
Disable filters in preprocess
  • Loading branch information
blaisewf authored Aug 17, 2024
2 parents ef0c8e2 + da1bbd4 commit d85e980
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 13 deletions.
2 changes: 2 additions & 0 deletions assets/i18n/languages/en_US.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
"Preprocess": "Preprocess",
"Audio cutting": "Audio cutting",
"It's recommended to deactivate this option if your dataset has already been processed.": "It's recommended to deactivate this option if your dataset has already been processed.",
"No Filters": "No Filters",
"Disables all preprocessing filters.": "Disables all preprocessing filters.",
"Model Name": "Model Name",
"Name of the new model.": "Name of the new model.",
"Enter model name": "Enter model name",
Expand Down
11 changes: 11 additions & 0 deletions core.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def run_preprocess_script(
sample_rate: int,
cpu_cores: int,
cut_preprocess: bool,
no_filters: bool,
):
config = get_config()
per = 3.0 if config.is_half else 3.7
Expand All @@ -276,6 +277,7 @@ def run_preprocess_script(
per,
cpu_cores,
cut_preprocess,
no_filters,
],
),
]
Expand Down Expand Up @@ -1048,6 +1050,14 @@ def parse_arguments():
default=True,
required=False,
)
preprocess_parser.add_argument(
"--no_filters",
type=lambda x: bool(strtobool(x)),
choices=[True, False],
help="Disable all filters during preprocessing.",
default=False,
required=False,
)

# Parser for 'extract' mode
extract_parser = subparsers.add_parser(
Expand Down Expand Up @@ -1515,6 +1525,7 @@ def main():
sample_rate=args.sample_rate,
cpu_cores=args.cpu_cores,
cut_preprocess=args.cut_preprocess,
no_filters=args.no_filters,
)
elif args.mode == "extract":
run_extract_script(
Expand Down
53 changes: 40 additions & 13 deletions rvc/train/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,17 @@ def _write_audio(self, audio: torch.Tensor, filename: str, sr: int):
audio = audio.cpu().numpy()
wavfile.write(filename, sr, audio.astype(np.float32))

def process_audio_segment(self, audio_segment: torch.Tensor, idx0: int, idx1: int):
normalized_audio = self._normalize_audio(audio_segment)
def process_audio_segment(
self,
audio_segment: torch.Tensor,
idx0: int,
idx1: int,
no_filters: bool,
):
if no_filters:
normalized_audio = audio_segment
else:
normalized_audio = self._normalize_audio(audio_segment)
if normalized_audio is None:
print(f"{idx0}-{idx1}-filtered")
return
Expand All @@ -82,13 +91,17 @@ def process_audio_segment(self, audio_segment: torch.Tensor, idx0: int, idx1: in
wav_16k_path = os.path.join(self.wavs16k_dir, f"{idx0}_{idx1}.wav")
self._write_audio(audio_16k, wav_16k_path, SAMPLE_RATE_16K)

def process_audio(self, path: str, idx0: int, cut_preprocess: bool):
def process_audio(
self, path: str, idx0: int, cut_preprocess: bool, no_filters: bool
):
try:
audio = load_audio(path, self.sr)
audio = torch.tensor(
signal.lfilter(self.b_high, self.a_high, audio), device=self.device
).float()

if no_filters:
audio = torch.tensor(audio, device=self.device).float()
else:
audio = torch.tensor(
signal.lfilter(self.b_high, self.a_high, audio), device=self.device
).float()
idx1 = 0
if cut_preprocess:
for audio_segment in self.slicer.slice(audio.cpu().numpy()):
Expand All @@ -103,26 +116,35 @@ def process_audio(self, path: str, idx0: int, cut_preprocess: bool):
tmp_audio = audio_segment[
start : start + int(self.per * self.sr)
]
self.process_audio_segment(tmp_audio, idx0, idx1)
self.process_audio_segment(
tmp_audio, idx0, idx1, no_filters
)
idx1 += 1
else:
tmp_audio = audio_segment[start:]
self.process_audio_segment(tmp_audio, idx0, idx1)
self.process_audio_segment(
tmp_audio, idx0, idx1, no_filters
)
idx1 += 1
break
else:
self.process_audio_segment(audio, idx0, idx1)
self.process_audio_segment(audio, idx0, idx1, no_filters)
except Exception as error:
print(f"An error occurred on {path} path: {error}")

def process_audio_file(self, file_path_idx, cut_preprocess):
def process_audio_file(self, file_path_idx, cut_preprocess, no_filters):
file_path, idx0 = file_path_idx
ext = os.path.splitext(file_path)[1].lower()
if ext not in [".wav"]:
audio = AudioSegment.from_file(file_path)
file_path = os.path.join("/tmp", f"{idx0}.wav")
audio.export(file_path, format="wav")
self.process_audio(file_path, idx0, cut_preprocess)
self.process_audio(file_path, idx0, cut_preprocess, no_filters)


def process_file(args):
pp, file, cut_preprocess, no_filters = args
pp.process_audio_file(file, cut_preprocess, no_filters)


def preprocess_training_set(
Expand All @@ -132,6 +154,7 @@ def preprocess_training_set(
exp_dir: str,
per: float,
cut_preprocess: bool,
no_filters: bool,
):
start_time = time.time()

Expand All @@ -146,7 +169,9 @@ def preprocess_training_set(

ctx = multiprocessing.get_context("spawn")
with ctx.Pool(processes=num_processes) as pool:
pool.starmap(pp.process_audio_file, [(file, cut_preprocess) for file in files])
pool.map(
process_file, [(pp, file, cut_preprocess, no_filters) for file in files]
)

elapsed_time = time.time() - start_time
print(f"Preprocess completed in {elapsed_time:.2f} seconds.")
Expand All @@ -161,6 +186,7 @@ def preprocess_training_set(
int(sys.argv[5]) if len(sys.argv) > 5 else multiprocessing.cpu_count()
)
cut_preprocess = strtobool(sys.argv[6])
no_filters = strtobool(sys.argv[7])

preprocess_training_set(
input_root,
Expand All @@ -169,4 +195,5 @@ def preprocess_training_set(
experiment_directory,
percentage,
cut_preprocess,
no_filters,
)
8 changes: 8 additions & 0 deletions tabs/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,13 @@ def train_tab():
value=True,
interactive=True,
visible=True,
)
no_filters = gr.Checkbox(
label=i18n("No Filters"),
info=i18n("Disables all preprocessing filters."),
value=False,
interactive=True,
visible=True,
)
preprocess_output_info = gr.Textbox(
label=i18n("Output Information"),
Expand All @@ -389,6 +396,7 @@ def train_tab():
sampling_rate,
cpu_cores_preprocess,
cut_preprocess,
no_filters,
],
outputs=[preprocess_output_info],
api_name="preprocess_dataset",
Expand Down

0 comments on commit d85e980

Please sign in to comment.