Apply precision config env vars in the root process. (#6152)
After recent changes to the main branch, mutating os.environ was no longer sufficient for the subprocess to pick up new env vars.
In this PR we apply the necessary workaround in the root process, which launches a subprocess for each experiment. The new flags are passed via the process_env variable.
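
For context, a minimal sketch of the mechanism (the names here are illustrative, not the actual runner code): rather than mutating os.environ inside the child and hoping XLA re-reads it, the root process builds the full environment up front and hands it to the subprocess explicitly.

import os
import subprocess
import sys

def launch_experiment(extra_env):
  # Start from the parent's environment, then overlay per-experiment flags,
  # e.g. {"XLA_USE_FP16": "1"}.
  process_env = os.environ.copy()
  process_env.update(extra_env)
  # The child sees exactly this mapping at interpreter startup, regardless of
  # any later os.environ mutations in the parent or the child.
  subprocess.check_call(
      [sys.executable, "-c", "import os; print(os.environ['XLA_USE_FP16'])"],
      env=process_env)

launch_experiment({"XLA_USE_FP16": "1"})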
golechwierowicz authored Dec 15, 2023
1 parent c3e341c commit dfcf306
Showing 3 changed files with 51 additions and 21 deletions.
7 changes: 7 additions & 0 deletions benchmarks/benchmark_model.py
@@ -157,3 +157,10 @@ def to_dict(self):
     d["suite_name"] = self.suite_name
     d["model_name"] = self.model_name
     return d
+
+  @property
+  def default_precision_flag(self):
+    return None
+
+  def extend_process_env(self, process_env):
+    return process_env
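
These two base-class hooks are deliberately no-ops; a model that needs a precision env var overrides them. A hypothetical subclass, assuming the class above is importable as benchmarks.benchmark_model.BenchmarkModel, might look like:

from benchmarks.benchmark_model import BenchmarkModel

class MyBF16Model(BenchmarkModel):

  @property
  def default_precision_flag(self):
    # Hypothetical: this model always wants XLA bf16 arithmetic.
    return 'XLA_USE_BF16'

  def extend_process_env(self, process_env):
    # Translate the flag name into an env var the subprocess will inherit.
    flag = self.default_precision_flag
    if flag is not None:
      process_env[flag] = '1'
    return process_env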
1 change: 1 addition & 0 deletions benchmarks/experiment_runner.py
@@ -87,6 +87,7 @@ def generate_and_run_all_configs(self):
             experiment_config)
         dummy_benchmark_model = self.model_loader.load_model(
             model_config, dummy_benchmark_experiment, dummy=True)
+        process_env = dummy_benchmark_model.extend_process_env(process_env)
         experiment_config["process_env"] = process_env
         command = ([sys.executable] + sys.argv +
                    [f"--experiment-config={experiment_config_str}"] +
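
The diff only shows the config being extended; as a hedged sketch of the downstream plumbing (the launch code itself is outside this hunk), the runner presumably passes the stored mapping as the child's environment:

import subprocess

def run_single_config(command, experiment_config):
  # process_env was filled in by extend_process_env above. Passing it as env=
  # replaces the child's environment wholesale, so it must have been built
  # from a copy of os.environ plus the new flags.
  return subprocess.run(command, env=experiment_config["process_env"], check=True)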
64 changes: 43 additions & 21 deletions benchmarks/torchbench_model.py
@@ -159,6 +159,24 @@ def set_up(self):
     """
     self.optimizer_class = torch.optim.Adam
 
+    benchmark = self.load_benchmark()
+
+    self.module, self.example_inputs = benchmark.get_module()
+
+    self.benchmark_experiment.batch_size = benchmark.batch_size
+
+    # Torchbench has quite different setup for yolov3, so directly passing
+    # the right example_inputs
+    if self.model_name == "yolov3":
+      self.example_inputs = (torch.rand(self.benchmark_experiment.batch_size, 3,
+                                        384, 512),)
+    if self.benchmark_experiment.test == "train" and self.model_name in DETECTRON2_MODELS:
+      self.optimizer = benchmark.optimizer
+
+    del benchmark
+    gc.collect()
+
+  def load_benchmark(self):
     try:
       module = importlib.import_module(
           f"torchbenchmark.models.{self.model_name}")
@@ -182,30 +200,16 @@ def set_up(self):
     else:
       device = str(self.benchmark_experiment.get_device())
 
-    benchmark = benchmark_cls(
+    return benchmark_cls(
         test=self.benchmark_experiment.test,
         device=device,
         batch_size=self.benchmark_experiment.batch_size,
     )
 
-    self.module, self.example_inputs = benchmark.get_module()
-
-    self.benchmark_experiment.batch_size = benchmark.batch_size
-
-    # Torchbench has quite different setup for yolov3, so directly passing
-    # the right example_inputs
-    if self.model_name == "yolov3":
-      self.example_inputs = (torch.rand(self.benchmark_experiment.batch_size, 3,
-                                        384, 512),)
-    if self.benchmark_experiment.test == "train" and self.model_name in DETECTRON2_MODELS:
-      self.optimizer = benchmark.optimizer
-
-    del benchmark
-    gc.collect()
-
-  def apply_default_precision_config(self, test, benchmark):
+  @property
+  def default_precision_flag(self):
     """
-    Apply default precision config to XLA, if present.
+    Get the default precision config to XLA, if present.
 
     Whenever a model has a default precision for cuda set
     we need to set proper environment flags so XLA catches
@@ -215,25 +219,43 @@ def apply_default_precision_config(self, test, benchmark):
     changes to the PT/XLA bridge so that the input shape
     is properly inferred after issuing converts to `torch.nn.Module`.
     """
+    test = self.benchmark_experiment.test
+    try:
+      benchmark = self.load_benchmark()
+    except Exception:
+      logger.exception("Cannot load benchmark model")
+      return None
+
     if test == "eval" and hasattr(benchmark, 'DEFAULT_EVAL_CUDA_PRECISION'):
       precision = benchmark.DEFAULT_EVAL_CUDA_PRECISION
     elif test == "train" and hasattr(benchmark, 'DEFAULT_TRAIN_CUDA_PRECISION'):
       precision = benchmark.DEFAULT_TRAIN_CUDA_PRECISION
     else:
       logger.warning("No default precision set. No patching needed.")
-      return
+      return None
 
+    del benchmark
+    gc.collect()
+
+    precision_flag = None
     if precision == "fp16":
-      os.environ['XLA_USE_FP16'] = '1'
+      precision_flag = 'XLA_USE_FP16'
     elif precision == "amp":
       raise ValueError(
           f"AMP for PT/XLA:GPU is not implemented yet for torchbench models")
     elif precision == "bf16":
-      os.environ['XLA_USE_BF16'] = '1'
+      precision_flag = 'XLA_USE_BF16'
     elif precision == "fp32":
       logger.warning("Sticking with the default fp32 precision.")
     else:
       raise ValueError(f"Unknown precision: {precision}")
+    return precision_flag
+
+  def extend_process_env(self, process_env):
+    precision_flag = self.default_precision_flag
+    if precision_flag is not None:
+      process_env[precision_flag] = '1'
+    return process_env
 
   def pick_grad(self):
     # special case
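
Taken together, the two hooks let the root process compute the flag once per model and fold it into the child's environment. An illustrative driver (build_env_for and model are stand-ins, not runner code):

import os

def build_env_for(model):
  # Copy rather than mutate, so the parent's os.environ stays untouched and
  # each experiment subprocess gets its own flag set.
  process_env = os.environ.copy()
  # For a torchbench model with DEFAULT_EVAL_CUDA_PRECISION == "bf16",
  # default_precision_flag yields 'XLA_USE_BF16' and extend_process_env
  # sets it to '1' in the copy.
  return model.extend_process_env(process_env)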
