Update bulk_runner with improved filtering options for benchmarking / val runs
rwightman committed Nov 20, 2023
1 parent dfb8658 commit 25cf2c2
Showing 1 changed file with 34 additions and 9 deletions.
43 changes: 34 additions & 9 deletions bulk_runner.py
@@ -21,7 +21,7 @@
from typing import Callable, List, Tuple, Union


from timm.models import is_model, list_models
from timm.models import is_model, list_models, get_pretrained_cfg


parser = argparse.ArgumentParser(description='Per-model process launcher')
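
For reference, get_pretrained_cfg (newly imported above) looks up a model's pretrained config in the timm registry. A minimal sketch of the fields the new branches below read from it; the model name and printed values are illustrative assumptions, not registry output:

from timm.models import get_pretrained_cfg

cfg = get_pretrained_cfg('resnet50')  # illustrative model name
print(cfg.num_classes)      # e.g. 1000 for an ImageNet-1k head
print(cfg.input_size)       # e.g. (3, 224, 224); the code below uses input_size[-1]
print(cfg.test_input_size)  # may be None, or e.g. (3, 288, 288)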
@@ -98,16 +98,32 @@ def main():
    cmd, cmd_args = cmd_from_args(args)

    model_cfgs = []
    model_names = []
    if args.model_list == 'all':
        # NOTE should make this config, for validation / benchmark runs the focus is 1k models,
        # so we filter out 21/22k and some other unusable heads. This will change in the future...
        exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
        model_names = list_models(
            pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
            exclude_filters=exclude_model_filters
        )
        model_cfgs = [(n, None) for n in model_names]
    elif args.model_list == 'all_in1k':
        model_names = list_models(pretrained=True)
        model_cfgs = []
        for n in model_names:
            pt_cfg = get_pretrained_cfg(n)
            if getattr(pt_cfg, 'num_classes', 0) == 1000:
                print(n, pt_cfg.num_classes)
                model_cfgs.append((n, None))
    elif args.model_list == 'all_res':
        model_names = list_models()
        model_names += [n.split('.')[0] for n in list_models(pretrained=True)]
        model_cfgs = set()
        for n in model_names:
            pt_cfg = get_pretrained_cfg(n)
            if pt_cfg is None:
                print(f'Model {n} is missing pretrained cfg, skipping.')
                continue
            model_cfgs.add((n, pt_cfg.input_size[-1]))
            if pt_cfg.test_input_size is not None:
                model_cfgs.add((n, pt_cfg.test_input_size[-1]))
        model_cfgs = [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
    elif not is_model(args.model_list):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model_list)
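
Net effect of the three list modes above: 'all' and 'all_in1k' pair each name with None (no extra per-run args), while 'all_res' pairs each name with an img-size dict, one entry per distinct train/test resolution. A sketch with made-up entries:

# 'all' / 'all_in1k' produce entries like:
model_cfgs = [('resnet50', None), ('convnext_base', None)]
# 'all_res' produces entries like (one per distinct resolution):
model_cfgs = [('resnet50', {'img-size': 224}), ('resnet50', {'img-size': 288})]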
@@ -122,7 +138,8 @@ def main():
    results_file = args.results_file or './results.csv'
    results = []
    errors = []
    print('Running script on these models: {}'.format(', '.join(model_names)))
    model_strings = '\n'.join([f'{x[0]}, {x[1]}' for x in model_cfgs])
    print(f"Running script on these models:\n {model_strings}")
    if not args.sort_key:
        if 'benchmark' in args.script:
            if any(['train' in a for a in args.script_args]):
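
The reworked listing prints one name, extra-args pair per line rather than a single comma-joined string; with the hypothetical entries above, the output would look roughly like:

Running script on these models:
 resnet50, None
resnet50, {'img-size': 288}

(The single leading space comes from the "\n " in the f-string and so applies only to the first entry.)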
@@ -136,10 +153,14 @@ def main():
    print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')

    try:
        for m, _ in model_cfgs:
        for m, ax in model_cfgs:
            if not m:
                continue
            args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
            if ax is not None:
                extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]
                extra_args = [i for t in extra_args for i in t]
                args_str += tuple(extra_args)
            try:
                o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
                r = json.loads(o)
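
The two-step flatten above turns each per-model dict into interleaved --key value tokens before they are appended to the subprocess command; the same transform in isolation, with illustrative values:

ax = {'img-size': 288}
extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]  # [('--img-size', '288')]
extra_args = [i for t in extra_args for i in t]           # ['--img-size', '288']
# the command tuple then ends: (..., '--model', 'resnet50', '--img-size', '288')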
@@ -157,7 +178,11 @@ def main():
    if errors:
        print(f'{len(errors)} models had errors during run.')
        for e in errors:
            print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
            if 'model' in e:
                print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
            else:
                print(e)

    results = list(filter(lambda x: 'error' not in x, results))

    no_sortkey = list(filter(lambda x: sort_key not in x, results))
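
Downstream, errored runs are dropped and results missing the sort key are split off before sorting; equivalent standalone behavior on made-up records:

results = [{'model': 'a', 'top1': 80.1}, {'model': 'b', 'error': 'OOM'}]
sort_key = 'top1'
results = list(filter(lambda x: 'error' not in x, results))      # drops model 'b'
no_sortkey = list(filter(lambda x: sort_key not in x, results))  # [] here; 'a' has top1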