Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create good defaults in accelerate launch #553

Merged
merged 7 commits into from
Jul 22, 2022
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/accelerate/commands/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import argparse
import importlib
import logging
import os
import subprocess
import sys
Expand All @@ -24,6 +25,8 @@
from pathlib import Path
from typing import Dict, List

import torch

from accelerate.commands.config import default_config_file, load_config_from_file
from accelerate.commands.config.config_args import SageMakerConfig
from accelerate.utils import (
Expand All @@ -40,6 +43,9 @@
from accelerate.utils.dataclasses import SageMakerDistributedType


logger = logging.getLogger(__name__)


def launch_command_parser(subparsers=None):
if subparsers is not None:
parser = subparsers.add_parser("launch")
Expand Down Expand Up @@ -684,15 +690,33 @@ def launch_command(args):
and getattr(args, name, None) is None
):
setattr(args, name, attr)

if not args.mixed_precision:
if args.fp16:
args.mixed_precision = "fp16"
else:
args.mixed_precision = defaults.mixed_precision
else:
warned = False
if args.num_processes is None:
logger.warn("`--num_processes` was not set, using a value of `1`.")
warned = True
args.num_processes = 1
if args.num_machines is None:
warned = True
logger.warn("`--num_machines` was not set, using a value of `1`.")
args.num_machines = 1
if args.mixed_precision is None:
warned = True
logger.warn("`--mixed_precision` was not set, using a value of `'no'`.")
args.mixed_precision = "no"
if not hasattr(args, "use_cpu"):
args.use_cpu = args.cpu
if warned:
logger.warn("To avoid these warnings pass in values for each of the problematic parameters")
muellerzr marked this conversation as resolved.
Show resolved Hide resolved

if args.multi_gpu and args.num_processes == 1:
logger.warn("`--multi_gpu` was passed but `num_processes` was not set. Automatically using all available GPUs")
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
args.num_processes = torch.cuda.device_count()

# Use the proper launcher
if args.use_deepspeed and not args.cpu:
Expand Down