[benchmarks] Fix verifier error handling and OOM. (#7777)
ysiraichi authored Aug 5, 2024
1 parent 7135a7e commit 5bbc4b3
Showing 2 changed files with 78 additions and 64 deletions.
102 changes: 50 additions & 52 deletions benchmarks/experiment_runner.py
@@ -12,6 +12,7 @@
import torch
import torch._dynamo.utils as dynamo_utils
import tiers
import traceback
import typing
from typing import Any, Dict, List, Optional, Sequence
import torch_xla.debug.metrics as met
@@ -22,7 +23,7 @@
import copy
from torch.autograd import DeviceType
from benchmark_model import ModelLoader
from verifier import VerificationCode, verify
from verifier import VerificationCode, VerificationException, verify
from enum import Enum
from torchbench_model import TorchBenchModelLoader
from benchmark_model import BenchmarkModel
@@ -183,7 +184,7 @@ def generate_and_run_all_configs(self):
benchmark_experiment.to_dict(),
benchmark_model.to_dict(),
metrics={"error": str(e)},
verification_code=VerificationCode.VERIFIER_SKIPPED_UNEXPECTEDLY,
verification_code=VerificationCode.VERIFIER_DIDNT_RUN,
)
except subprocess.CalledProcessError as e:
self._fwd_captured_stdout_stderr(e.stdout, e.stderr)
@@ -192,23 +193,23 @@ def generate_and_run_all_configs(self):
benchmark_experiment.to_dict(),
benchmark_model.to_dict(),
metrics={"error": e.stderr},
verification_code=VerificationCode.VERIFIER_SKIPPED_UNEXPECTEDLY,
verification_code=VerificationCode.VERIFIER_DIDNT_RUN,
)
except subprocess.SubprocessError as e:
logger.error("ERROR when launching child process")
self._save_results(
benchmark_experiment.to_dict(),
benchmark_model.to_dict(),
metrics={"error": str(e)},
verification_code=VerificationCode.VERIFIER_SKIPPED_UNEXPECTEDLY,
verification_code=VerificationCode.VERIFIER_DIDNT_RUN,
)
except ValueError as e:
logger.error(f"ERROR {e}")
self._save_results(
benchmark_experiment.to_dict(),
benchmark_model.to_dict(),
metrics={"error": str(e)},
verification_code=VerificationCode.VERIFIER_SKIPPED_UNEXPECTEDLY,
verification_code=VerificationCode.VERIFIER_DIDNT_RUN,
)

# TODO: Use `_unique_basename` instead.
@@ -262,59 +263,56 @@ def run_single_config(self):

# Load experiment and model.
experiment_config = json.loads(self._args.experiment_config)
experiment = self.experiment_loader.load_experiment(experiment_config)
experiment_dict = experiment.to_dict()

model_config = json.loads(self._args.model_config)
benchmark_experiment = self.experiment_loader.load_experiment(
experiment_config)
reset_rng_state(benchmark_experiment)
benchmark_model = self.model_loader.load_model(model_config,
benchmark_experiment)
model = self.model_loader.load_model(model_config, experiment, dummy=True)
model_dict = model.to_dict()

# Initialize output variables
accumulated_metrics = OrderedDict()
verification_code = VerificationCode.VERIFIER_SKIPPED

# Turn on CUDAGraphs if we are running inductor
if benchmark_experiment.is_inductor():
if experiment.is_inductor():
from torch._inductor import config as inductor_config
inductor_config.triton.cudagraphs = True

# Repeat the experiment and accumulate metrics.
with benchmark_model.pick_grad():
accumulated_metrics = OrderedDict()
for repeat_iteration in range(self._args.repeat):
metrics, _ = self.run_once_and_gather_metrics(benchmark_experiment,
benchmark_model,
experiment_config,
model_config,
repeat_iteration)
for k, v in metrics.items():
if k not in accumulated_metrics:
accumulated_metrics[k] = []
accumulated_metrics[k].append(v)

# Save the dict representation before deleting them.
# This will be used later for saving the results.
experiment_dict = benchmark_experiment.to_dict()
model_dict = benchmark_model.to_dict()

# Save other model-specific configuration: tolerance and whether
# to use cosine similarity on accuracy checks.
tolerance = benchmark_model.tolerance()
use_cosine_similarity = benchmark_model.use_cosine_similarity()
skip_verifier = benchmark_model.skip_verifier()

# Delete the instantiated BenchmarkModel, so we can save memory
# for verifying the result.
del benchmark_model
cleanup(benchmark_experiment.is_cuda())

# Run the verifier iff:
#
# 1. We are running this script with --verify flag
# 2. It should not be skipped
if self._args.verify and not skip_verifier:
res = verify(self, experiment_config, model_config, tolerance,
use_cosine_similarity)
else:
res = VerificationCode.VERIFIER_SKIPPED

self._save_results(experiment_dict, model_dict, accumulated_metrics, res)
# Only run the actual experiment first if the --verify flag is not
# specified. This is so we avoid using too much memory before running
# eager.
if not self._args.verify:
reset_rng_state(experiment)
model = self.model_loader.load_model(model_config, experiment)

# Repeat the experiment and accumulate metrics.
with model.pick_grad():
for repeat_iteration in range(self._args.repeat):
metrics, _ = self.run_once_and_gather_metrics(experiment, model,
experiment_config,
model_config,
repeat_iteration)
for k, v in metrics.items():
if k not in accumulated_metrics:
accumulated_metrics[k] = []
accumulated_metrics[k].append(v)
elif not model.skip_verifier():
try:
verification_code = verify(
self,
experiment_config,
model_config,
tolerance=model.tolerance(),
use_cosine_similarity=model.use_cosine_similarity())
except VerificationException as e:
verification_code = e.code
# Record the error in the metrics dictionary.
# Similar to what's done when the whole experiment fails.
accumulated_metrics["error"] = traceback.format_exc()

self._save_results(experiment_dict, model_dict, accumulated_metrics,
verification_code)

def run_once_and_gather_metrics(
self, benchmark_experiment: BenchmarkExperiment,
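Note on the rewritten run_single_config(): it now defaults the verification code to VERIFIER_SKIPPED, runs the timed experiment only when --verify is absent, and otherwise maps a VerificationException back to its code while recording the traceback in the metrics. A minimal, self-contained sketch of that calling pattern follows; the VerificationCode members and the verify() stub below are simplified stand-ins for the real benchmarks/verifier module, not its actual implementation.

import traceback
from collections import OrderedDict
from enum import Enum

class VerificationCode(str, Enum):
  # Reduced set of codes, for illustration only.
  PASS = 'PASS'
  VERIFIER_SKIPPED = 'VERIFIER_SKIPPED'
  EXCEPTION_RAISED = 'EXCEPTION_RAISED'

class VerificationException(Exception):
  def __init__(self, code: VerificationCode) -> None:
    super().__init__(f"verifier failed with code: {code}")
    self.code = code

def verify(should_fail: bool = False) -> VerificationCode:
  # Stand-in for benchmarks/verifier.verify(): any internal failure surfaces
  # as a VerificationException carrying a VerificationCode.
  if should_fail:
    raise VerificationException(VerificationCode.EXCEPTION_RAISED)
  return VerificationCode.PASS

accumulated_metrics = OrderedDict()
verification_code = VerificationCode.VERIFIER_SKIPPED
try:
  verification_code = verify(should_fail=True)
except VerificationException as e:
  # Keep the code for the results file and the traceback for debugging,
  # mirroring what run_single_config() does in the diff above.
  verification_code = e.code
  accumulated_metrics["error"] = traceback.format_exc()

print(verification_code.value, "error" in accumulated_metrics)  # EXCEPTION_RAISED True
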
40 changes: 28 additions & 12 deletions benchmarks/verifier.py
@@ -7,7 +7,7 @@
from benchmark_model import ModelLoader
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple
from util import cleanup, move_to_device, reset_rng_state, StrOrBool

logger = logging.getLogger(__name__)
@@ -20,15 +20,22 @@ class VerificationCode(str, Enum):
FAIL = 'FAIL',
# Eager execution failed.
EAGER_FAILED = 'EAGER_FAILED'
# Verifier failed, raising an exception.
VERIFIER_FAILED = 'VERIFIER_FAILED'
# An exception was raised when running the verifier.
EXCEPTION_RAISED = 'EXCEPTION_RAISED'
# Eager runs do not agree.
NONDETERMINISTIC_EAGER_RUN = 'NONDETERMINISTIC_EAGER_RUN'
# Verifier skipped.
VERIFIER_SKIPPED = 'VERIFIER_SKIPPED'
# Verifier did not run. It was skipped due to an unexpected reason:
# either an exception was raised or the process timed out.
VERIFIER_SKIPPED_UNEXPECTEDLY = 'VERIFIER_SKIPPED_UNEXPECTEDLY'
VERIFIER_DIDNT_RUN = 'VERIFIER_DIDNT_RUN'


class VerificationException(Exception):

def __init__(self, code: VerificationCode) -> None:
super().__init__(f"verifier failed with code: {code}")
self.code = code


def verify(
@@ -43,15 +50,15 @@ def verify(
Both `tolerance` and `use_cosine_similarity` will be used when checking whether the
accuracy of the actual experiment is close to that of eager.
"""

try:
# 1. Run eager twice, so as to make sure the model actually outputs deterministic results.
try:
eager_output = _run(runner, experiment_config, model_config, eager=True)
additional_eager_output = _run(
runner, experiment_config, model_config, eager=True)
except:
traceback.print_exc()
return VerificationCode.EAGER_FAILED
except Exception as e:
raise VerificationException(VerificationCode.EAGER_FAILED) from e

# If the results are not close, it might mean that this model is not deterministic.
# Therefore, we give up the verification process, entirely.
@@ -60,8 +67,14 @@

# 2. Compute the output using float64 precision for increased precision. This should
# help deciding whether the outputs of the actual experiment have acceptable accuracy.
eager_fp64_output = _run(
runner, experiment_config, model_config, force_fp64=True, eager=True)
try:
eager_fp64_output = _run(
runner, experiment_config, model_config, force_fp64=True, eager=True)
except:
logger.warning(
"failed running fp64 golden ref. Setting accuracy to cosine.")
eager_fp64_output = None
use_cosine_similarity = True

# 3. Compute the output of the actual experiment.
output = _run(runner, experiment_config, model_config)
@@ -76,9 +89,12 @@
tol=tolerance,
):
return VerificationCode.FAIL
except:
traceback.print_exc()
return VerificationCode.VERIFIER_FAILED
except VerificationException:
raise
except Exception as e:
# If anything went wrong (other than an explicit VerificationException), raise
# a VerificationException with EXCEPTION_RAISED code, while chaining the cause.
raise VerificationException(VerificationCode.EXCEPTION_RAISED) from e

return VerificationCode.PASS

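Note on the new fp64 fallback in verify(): if the float64 golden reference cannot be computed (for example, it runs out of memory), verify() drops it (eager_fp64_output = None) and forces the cosine-similarity check instead of failing outright. A rough sketch of what such a comparison could look like for single-tensor outputs follows; the outputs_match helper, its arguments, and the 0.99 threshold are illustrative assumptions, not the comparison routine the benchmarks actually use.

import torch
import torch.nn.functional as F

def outputs_match(output: torch.Tensor,
                  golden_fp64,
                  golden_eager: torch.Tensor,
                  tolerance: float = 1e-4,
                  use_cosine_similarity: bool = False) -> bool:
  # Prefer the high-precision reference; fall back to the eager output when
  # the fp64 run was skipped (golden_fp64 is None), as in the diff above.
  reference = golden_eager if golden_fp64 is None else golden_fp64
  out = output.flatten().double()
  ref = reference.flatten().double()
  if use_cosine_similarity:
    # Looser check: only require the outputs to point in the same direction.
    return bool(F.cosine_similarity(out, ref, dim=0) > 0.99)
  return torch.allclose(out, ref, rtol=tolerance, atol=tolerance)

# Example: fp64 reference unavailable, so the cosine-similarity path is used.
eager = torch.randn(1024)
actual = eager + 1e-3 * torch.randn(1024)
print(outputs_match(actual, golden_fp64=None, golden_eager=eager,
                    use_cosine_similarity=True))
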
