Type annotation for benchmarks/ (#7289)
zpcore authored Jun 20, 2024
1 parent cb6549a, commit 98dd99e
Showing 6 changed files with 177 additions and 129 deletions.
benchmarks/benchmark_experiment.py (23 changes: 15 additions & 8 deletions)
@@ -1,17 +1,19 @@
import argparse
from collections import OrderedDict
import logging
import os
from typing import Any, List, Dict, Optional
import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm
from util import parse_none_str, is_xla_device_available, get_accelerator_model
from util import parse_none_str, is_xla_device_available, get_accelerator_model, StrOrBool

logger = logging.getLogger(__name__)


class ExperimentLoader:

def __init__(self, args):
def __init__(self, args: argparse.Namespace):
self._args = args

def list_experiment_configs(self):
@@ -58,7 +60,7 @@ def list_experiment_configs(self):
experiment_configs.append(cfg)
return experiment_configs

def _expand_config_choices(self, config_choices):
def _expand_config_choices(self, config_choices: Dict[str, List[Any]]):
configs = [{}]
for k, choices in config_choices.items():
new_configs = []
@@ -70,7 +72,9 @@ def _expand_config_choices(self, config_choices):
configs = new_configs
return configs
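
The loop body elided above appears to implement the standard choice-expansion pattern: starting from a single empty config, each key in config_choices multiplies the current list of configs by that key's choices. A standalone sketch of that pattern, assuming this reading of the elided code and using hypothetical choice values (this is not the repository's implementation):

from itertools import product
from typing import Any, Dict, List

def expand_config_choices(config_choices: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
  # One output dict per combination of choices, mirroring what the method's
  # signature and the configs = [{}] / configs = new_configs structure imply.
  keys = list(config_choices)
  return [dict(zip(keys, combo)) for combo in product(*(config_choices[k] for k in keys))]

# Hypothetical choices: 2 x 2 = 4 experiment configs.
print(expand_config_choices({"dynamo": [None, "openxla"], "test": ["eval", "train"]}))
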

def _is_available(self, experiment_config):
def _is_available(self,
experiment_config: List[Dict[str,
List[Optional[StrOrBool]]]]):
cfg_dynamo = experiment_config["dynamo"]
cfg_accelerator = experiment_config["accelerator"]
cfg_xla = experiment_config["xla"]
@@ -123,7 +127,9 @@ def _is_available(self, experiment_config):

return True

def load_experiment(self, experiment_config):
def load_experiment(self,
experiment_config: List[Dict[str,
List[Optional[StrOrBool]]]]):
accelerator = experiment_config["accelerator"].lower()
xla = experiment_config["xla"]
xla_flags = experiment_config["xla_flags"]
@@ -145,8 +151,9 @@

class BenchmarkExperiment:

def __init__(self, accelerator, xla, xla_flags, dynamo, torch_xla2,
keep_model_data_on_cuda: bool, test, batch_size):
def __init__(self, accelerator: str, xla: Optional[str],
xla_flags: Optional[str], dynamo: str, torch_xla2: bool,
keep_model_data_on_cuda: bool, test: str, batch_size: str):
self.accelerator = accelerator
self.xla = xla
self.xla_flags = xla_flags
@@ -157,7 +164,7 @@ def __init__(self, accelerator, xla, xla_flags, dynamo, torch_xla2,
self.batch_size = batch_size
self.accelerator_model = get_accelerator_model(self.accelerator)

def update_process_env(self, process_env):
def update_process_env(self, process_env: Dict[str, str]):

# Remove env vars that would interfere with the subprocess.
if self.xla is not None:
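The new annotations rely on a StrOrBool alias imported from util; its definition is not part of the hunks shown here. A hedged guess at its shape, based only on how it is used above (experiment config values that may be strings, booleans, or None):

from typing import Union

# Assumed definition; the real alias lives in benchmarks/util.py and may differ.
StrOrBool = Union[str, bool]
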
benchmarks/benchmark_model.py (120 changes: 64 additions & 56 deletions)
@@ -1,3 +1,4 @@
import argparse
from collections import OrderedDict
import contextlib
import logging
@@ -7,64 +8,16 @@
from torch._dynamo.testing import collect_results
from torch.utils import _pytree as pytree
from util import cast_to_dtype, move_to_device
from benchmark_experiment import BenchmarkExperiment
from typing import Dict, Any, Sequence

logger = logging.getLogger(__name__)


class ModelLoader:

def __init__(self, args):
self._args = args
self.suite_name = self._args.suite_name
self.benchmark_model_class = BenchmarkModel
self._dynamo_compile_opts = dict()
if self._args.filter_by_single_graph:
self._dynamo_compile_opts['fullgraph'] = True

def list_model_configs(self):
model_configs = [
{
"model_name": "dummy"
},
]

return model_configs

def is_compatible(self, dummy_benchmark_model, benchmark_experiment):
return True

def get_benchmark_indices(self, length):
start = self._args.partition_id * (length // self._args.total_partitions)
end = ((self._args.partition_id + 1) *
(length // self._args.total_partitions)
if self._args.partition_id < self._args.total_partitions - 1 else
length)
return start, end

def skip_model(self, model_name):
return (not re.search("|".join(self._args.filter), model_name, re.I) or
re.search("|".join(self._args.exclude), model_name, re.I))

def load_model(self, model_config, benchmark_experiment, dummy=False):
suite_name = self.suite_name
model_name = model_config["model_name"]
benchmark_model = self.benchmark_model_class(
suite_name=suite_name,
model_name=model_name,
benchmark_experiment=benchmark_experiment,
)

if not dummy:
benchmark_model.set_up()
benchmark_model.prepare_for_experiment(
dynamo_compilation_opts=self._dynamo_compile_opts)

return benchmark_model


class BenchmarkModel:

def __init__(self, suite_name, model_name, benchmark_experiment):
def __init__(self, suite_name: str, model_name: str,
benchmark_experiment: BenchmarkExperiment):
self.suite_name = suite_name
self.model_name = model_name
self.benchmark_experiment = benchmark_experiment
@@ -108,7 +61,7 @@ def _prepare_for_train(self):
def conversion_dtype(self):
return None

def prepare_for_experiment(self, dynamo_compilation_opts):
def prepare_for_experiment(self, dynamo_compilation_opts: Dict[str, str]):
self.device = self.benchmark_experiment.get_device()
self.dtype = self.conversion_dtype()
if self.dtype is not None:
@@ -184,7 +137,7 @@ def _optimizer_step(self):
def compute_loss(self, pred):
raise NotImplementedError

def train(self, inputs, collect_full_output=False):
def train(self, inputs: Sequence[Any], collect_full_output: bool = False):
self._optimizer_zero_grad()
with self.autocast(**self.autocast_kwargs):
pred = self.module(*inputs)
@@ -197,7 +150,7 @@ def train(self, inputs, collect_full_output=False):
# TODO: dynamo inductor would fail if .detach() is used
return None

def eval(self, inputs, collect_full_output=False):
def eval(self, inputs: Sequence[Any], collect_full_output: bool = False):
with self.autocast(**self.autocast_kwargs):
pred = self.module(*inputs)
return pred
@@ -216,5 +169,60 @@ def to_dict(self):
def default_precision_flag(self):
return None

def update_process_env(self, process_env):
def update_process_env(self, process_env: Dict[str, str]):
pass


class ModelLoader:

def __init__(self, args: argparse.Namespace):
self._args = args
self.suite_name = self._args.suite_name
self.benchmark_model_class = BenchmarkModel
self._dynamo_compile_opts = dict()
if self._args.filter_by_single_graph:
self._dynamo_compile_opts['fullgraph'] = True

def list_model_configs(self):
model_configs = [
{
"model_name": "dummy"
},
]

return model_configs

def is_compatible(self, dummy_benchmark_model: BenchmarkModel,
benchmark_experiment: BenchmarkExperiment):
return True

def get_benchmark_indices(self, length: int):
start = self._args.partition_id * (length // self._args.total_partitions)
end = ((self._args.partition_id + 1) *
(length // self._args.total_partitions)
if self._args.partition_id < self._args.total_partitions - 1 else
length)
return start, end
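
As a concrete check of the arithmetic above: with total_partitions=3 and length=10, each partition gets length // total_partitions = 3 models and the last partition absorbs the remainder. A standalone reproduction of the index computation, for illustration only:

def benchmark_indices(partition_id: int, total_partitions: int, length: int):
  # Same arithmetic as ModelLoader.get_benchmark_indices, pulled out of the class.
  start = partition_id * (length // total_partitions)
  end = ((partition_id + 1) * (length // total_partitions)
         if partition_id < total_partitions - 1 else length)
  return start, end

assert benchmark_indices(0, 3, 10) == (0, 3)
assert benchmark_indices(1, 3, 10) == (3, 6)
assert benchmark_indices(2, 3, 10) == (6, 10)  # last partition takes the remainder
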

def skip_model(self, model_name: str):
return (not re.search("|".join(self._args.filter), model_name, re.I) or
re.search("|".join(self._args.exclude), model_name, re.I))
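
The filter/exclude semantics above read as: skip a model unless its name matches at least one filter pattern, and also skip it if it matches any exclude pattern, both case-insensitive. A small illustration with hypothetical pattern lists (not repository code):

import re

filters, excludes = ["bert", "resnet"], ["quantized"]  # hypothetical CLI values

def skip(model_name: str) -> bool:
  # Mirrors skip_model: must match some filter pattern and no exclude pattern.
  return (not re.search("|".join(filters), model_name, re.I) or
          bool(re.search("|".join(excludes), model_name, re.I)))

assert skip("hf_Bert") is False          # matches a filter, not excluded
assert skip("timm_vit") is True          # matches no filter
assert skip("quantized_resnet") is True  # matches an exclude pattern
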

def load_model(self,
model_config: Dict[str, Any],
benchmark_experiment: BenchmarkExperiment,
dummy: bool = False) -> BenchmarkModel:
suite_name = self.suite_name
model_name = model_config["model_name"]
benchmark_model = self.benchmark_model_class(
suite_name=suite_name,
model_name=model_name,
benchmark_experiment=benchmark_experiment,
)

if not dummy:
benchmark_model.set_up()
benchmark_model.prepare_for_experiment(
dynamo_compilation_opts=self._dynamo_compile_opts)

return benchmark_model
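
For context, a hedged sketch of how the annotated classes fit together at a call site. The Namespace fields below are the attributes the code above reads, the constructor arguments follow the annotated BenchmarkExperiment signature, and every concrete value is a placeholder; the real runner builds these from its own argparser, so this is illustrative only:

import argparse
from benchmark_experiment import BenchmarkExperiment
from benchmark_model import ModelLoader

# Placeholder arguments; the real values come from the benchmark runner's CLI.
args = argparse.Namespace(
    suite_name="dummy",
    filter_by_single_graph=False,
    filter=["dummy"],      # keep only the dummy model
    exclude=[r"^$"],       # exclude nothing
    partition_id=0,
    total_partitions=1,
)

experiment = BenchmarkExperiment(
    accelerator="cpu", xla=None, xla_flags=None, dynamo="inductor",
    torch_xla2=False, keep_model_data_on_cuda=False, test="eval",
    batch_size="1")

loader = ModelLoader(args)
for cfg in loader.list_model_configs():
  if loader.skip_model(cfg["model_name"]):
    continue
  model = loader.load_model(cfg, experiment, dummy=True)  # dummy=True skips set_up()
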
