Change overwrite logic #58

Merged
merged 1 commit into from Sep 29, 2024
65 changes: 40 additions & 25 deletions amt/inference/transcribe.py
@@ -783,17 +783,17 @@ def get_save_path(
     file_path: str,
     input_dir: str | None,
     save_dir: str,
-    idx: int | str = "",
+    idx_str: int | str = "",
 ):
     if input_dir is None:
         save_path = os.path.join(
             save_dir,
-            os.path.splitext(os.path.basename(file_path))[0] + f"{idx}.mid",
+            os.path.splitext(os.path.basename(file_path))[0] + f"{idx_str}.mid",
         )
     else:
         input_rel_path = os.path.relpath(file_path, input_dir)
         save_path = os.path.join(
-            save_dir, os.path.splitext(input_rel_path)[0] + f"{idx}.mid"
+            save_dir, os.path.splitext(input_rel_path)[0] + f"{idx_str}.mid"
         )
     if not os.path.isdir(os.path.dirname(save_path)):
         os.makedirs(os.path.dirname(save_path), exist_ok=True)
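For context, a minimal sketch (not part of the PR) of how the renamed idx_str argument shapes the output path; the audio and output paths below are hypothetical:

```python
# Hypothetical usage of get_save_path after the idx -> idx_str rename.
from amt.inference.transcribe import get_save_path

# Whole-file transcription: empty suffix, so the MIDI keeps the audio stem.
get_save_path("audio/song.mp3", input_dir=None, save_dir="out")
# -> "out/song.mid"

# Segment transcription: the caller passes a pre-formatted "_<idx>" suffix.
get_save_path("audio/song.mp3", input_dir=None, save_dir="out", idx_str="_2")
# -> "out/song_2.mid"

# Note: the helper also creates the destination directory as a side effect.
```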
@@ -810,7 +810,7 @@ def process_file(
     save_dir: str,
     input_dir: str,
     logger: logging.Logger,
-    segments: List[Tuple[int, int]] | None = None,
+    segments: List[Tuple[int, Tuple[int, int]]] | None = None,
 ):
     def _save_seq(_seq: List, _save_path: str):
         if os.path.exists(_save_path):
@@ -852,12 +852,17 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
 
     pid = threading.get_ident()
     if segments is None:
-        segments = [None]
+        # process_file and get_wav_segments will interpret segment=None as
+        # processing the entire file
+        segments = [(None, None)]
 
-    for idx, segment in enumerate(segments):
+    if len(segments) == 0:
+        logger.info(f"No segments to transcribe, skipping file: {file_path}")
+
+    for idx, segment in segments:
+        idx_str = f"_{idx}" if idx is not None else ""
+        save_path = get_save_path(file_path, input_dir, save_dir, idx_str)
+
         try:
             seq = transcribe_file(
                 file_path,
@@ -876,15 +881,17 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
             logger.info(f"Removed {res_rmv_cnt} from result queue")
             continue
 
-        logger.info(f"Finished file: {file_path} (segment: {idx})")
+        logger.info(
+            f"Finished file: {file_path} (segment: {idx if idx is not None else 'full'})"
+        )
         if len(seq) < 500:
-            logger.info(f"Skipping seq - too short (segment {idx})")
+            logger.info(
+                f"Skipping seq - too short (segment {idx if idx is not None else 'full'})"
+            )
         else:
             logger.debug(
-                f"Saving seq of length {len(seq)} from file: {file_path} (segment: {idx})"
+                f"Saving seq of length {len(seq)} from file: {file_path} (segment: {idx if idx is not None else 'full'})"
             )
-            idx = f"_{idx}" if segment is not None else ""
-            save_path = get_save_path(file_path, input_dir, save_dir, idx)
             _save_seq(seq, save_path)
 
         logger.info(f"{file_queue.qsize()} file(s) remaining in queue")
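Taken together, the process_file hunks mean the function now receives (index, segment) pairs instead of bare segment tuples, so segment indices survive upstream filtering and the save path is resolved before transcription. A small sketch of the new shape, with made-up segment bounds (the diff does not say whether the bounds are seconds or samples):

```python
# Hypothetical segments argument under the new List[Tuple[int, Tuple[int, int]]] type.
# Index 1 is absent because, in this example, its MIDI output already exists on disk.
segments = [
    (0, (0, 30)),
    (2, (60, 90)),
]

for idx, segment in segments:
    idx_str = f"_{idx}" if idx is not None else ""
    # Produces "_0" and "_2": the suffixes track the original segment indices,
    # not the position in the filtered list.
    print(idx_str, segment)
```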
@@ -997,20 +1004,28 @@ def batch_transcribe(
         files_to_process, key=lambda x: os.path.getsize(x["path"]), reverse=True
     )
     for file_to_process in files_to_process:
-        # Only add to file_queue if transcription MIDI file doesn't exist
-        if (
-            os.path.isfile(
+        if "segments" in file_to_process:
+            # Process files with segments
+            unsaved_segments = []
+            for idx, segment in enumerate(file_to_process["segments"]):
+                segment_save_path = get_save_path(
+                    file_to_process["path"],
+                    input_dir,
+                    save_dir,
+                    idx_str=f"_{idx}",
+                )
+                if not os.path.isfile(segment_save_path):
+                    unsaved_segments.append((idx, segment))
+
+            if unsaved_segments:
+                file_to_process["segments"] = unsaved_segments
+                file_queue.put(file_to_process)
+        else:
+            # Process files without segments (whole file)
+            if not os.path.isfile(
                 get_save_path(file_to_process["path"], input_dir, save_dir)
-            )
-            is False
-        ) and os.path.isfile(
-            get_save_path(
-                file_to_process["path"], input_dir, save_dir, idx="_0"
-            )
-        ) is False:
-            file_queue.put(file_to_process)
-        elif len(files_to_process) == 1:
-            file_queue.put(file_to_process)
+            ):
+                file_queue.put(file_to_process)
 
     logger.info(
         f"Files to process: {file_queue.qsize()}/{len(files_to_process)}"
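The replacement enqueue logic splits on whether a file carries a "segments" entry: segmented files are re-queued with only the segments whose per-segment MIDI is missing, while unsegmented files need a single existence check on the un-suffixed path. A standalone sketch of the per-segment filter, assuming get_save_path is importable from amt.inference.transcribe (the helper name segments_still_needed is made up):

```python
import os

from amt.inference.transcribe import get_save_path


def segments_still_needed(file_to_process: dict, input_dir: str, save_dir: str):
    # Keep (idx, segment) pairs whose "<stem>_<idx>.mid" output does not exist yet;
    # an empty result means the whole file can be skipped.
    return [
        (idx, segment)
        for idx, segment in enumerate(file_to_process["segments"])
        if not os.path.isfile(
            get_save_path(
                file_to_process["path"], input_dir, save_dir, idx_str=f"_{idx}"
            )
        )
    ]
```

This also retires the old check for a "_0"-suffixed file, which the per-segment loop makes redundant.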
@@ -1026,7 +1041,7 @@
         file_queue.qsize(),
     )
     num_processes_per_worker = min(
-        3 * (batch_size // num_workers), file_queue.qsize() // num_workers
+        5 * (batch_size // num_workers), file_queue.qsize() // num_workers
     )
 
     mp_manager = Manager()
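With hypothetical numbers (the batch_size, num_workers, and queue size below are made up), the multiplier bump changes the per-worker process cap as follows:

```python
# Illustrative effect of raising the multiplier from 3 to 5.
batch_size, num_workers, queue_size = 32, 4, 1000

old_cap = min(3 * (batch_size // num_workers), queue_size // num_workers)  # min(24, 250) = 24
new_cap = min(5 * (batch_size // num_workers), queue_size // num_workers)  # min(40, 250) = 40
```

The queue-size term still bounds the result, so small queues are unaffected by the change.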