Skip to content

Commit

Permalink
Move warning (#598)
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr authored Aug 2, 2022
1 parent b52b793 commit 15a8c6c
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions src/accelerate/commands/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,17 +339,6 @@ def simple_launcher(args):
mixed_precision = "fp16"

current_env["MIXED_PRECISION"] = str(mixed_precision)
if args.num_cpu_threads_per_process is None:
local_size = get_int_from_env(
["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1
)
args.num_cpu_threads_per_process = int(psutil.cpu_count(logical=False) / local_size)
if args.num_cpu_threads_per_process == 0:
args.num_cpu_threads_per_process = 1
logger.info(
f"num_cpu_threads_per_process unset, we set it at {args.num_cpu_threads_per_process} to improve oob performance."
)

current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)

process = subprocess.Popen(cmd, env=current_env)
Expand Down Expand Up @@ -447,9 +436,6 @@ def multi_gpu_launcher(args):
current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch_policy)
if args.fsdp_state_dict_type is not None:
current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type)
if args.num_cpu_threads_per_process is None:
args.num_cpu_threads_per_process = 1
logger.info(f"num_cpu_threads_per_process unset, we set it at {args.num_cpu_threads_per_process}.")
current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
process = subprocess.Popen(cmd, env=current_env)
process.wait()
Expand Down Expand Up @@ -803,6 +789,17 @@ def launch_command(args):
if "--num_processes" in warn:
warned[i] = warn.replace("`1`", f"`{args.num_processes}`")

if args.num_cpu_threads_per_process is None:
local_size = get_int_from_env(
["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1
)
args.num_cpu_threads_per_process = int(psutil.cpu_count(logical=False) / local_size)
if args.num_cpu_threads_per_process == 0:
args.num_cpu_threads_per_process = 1
warned.append(
f"\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance"
)

if any(warned):
message = "The following values were not passed to `accelerate launch` and had defaults used instead:\n"
message += "\n".join(warned)
Expand Down

0 comments on commit 15a8c6c

Please sign in to comment.