From 25cf2c2dbb299fe0f840bb6037daf005a791eacf Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 20 Nov 2023 14:08:51 -0800
Subject: [PATCH] Update bulk_runner with improved filtering options for
 benchmarking / val runs

---
 bulk_runner.py | 43 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 34 insertions(+), 9 deletions(-)

diff --git a/bulk_runner.py b/bulk_runner.py
index b71d0bb686..5f5db6407e 100755
--- a/bulk_runner.py
+++ b/bulk_runner.py
@@ -21,7 +21,7 @@
 
 from typing import Callable, List, Tuple, Union
 
-from timm.models import is_model, list_models
+from timm.models import is_model, list_models, get_pretrained_cfg
 
 
 parser = argparse.ArgumentParser(description='Per-model process launcher')
@@ -98,16 +98,32 @@ def main():
     cmd, cmd_args = cmd_from_args(args)
 
     model_cfgs = []
-    model_names = []
     if args.model_list == 'all':
-        # NOTE should make this config, for validation / benchmark runs the focus is 1k models,
-        # so we filter out 21/22k and some other unusable heads. This will change in the future...
-        exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
         model_names = list_models(
             pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
-            exclude_filters=exclude_model_filters
         )
         model_cfgs = [(n, None) for n in model_names]
+    elif args.model_list == 'all_in1k':
+        model_names = list_models(pretrained=True)
+        model_cfgs = []
+        for n in model_names:
+            pt_cfg = get_pretrained_cfg(n)
+            if getattr(pt_cfg, 'num_classes', 0) == 1000:
+                print(n, pt_cfg.num_classes)
+                model_cfgs.append((n, None))
+    elif args.model_list == 'all_res':
+        model_names = list_models()
+        model_names += [n.split('.')[0] for n in list_models(pretrained=True)]
+        model_cfgs = set()
+        for n in model_names:
+            pt_cfg = get_pretrained_cfg(n)
+            if pt_cfg is None:
+                print(f'Model {n} is missing pretrained cfg, skipping.')
+                continue
+            model_cfgs.add((n, pt_cfg.input_size[-1]))
+            if pt_cfg.test_input_size is not None:
+                model_cfgs.add((n, pt_cfg.test_input_size[-1]))
+        model_cfgs = [(n, {'img-size': r}) for n, r in sorted(model_cfgs)]
     elif not is_model(args.model_list):
         # model name doesn't exist, try as wildcard filter
         model_names = list_models(args.model_list)
@@ -122,7 +138,8 @@ def main():
     results_file = args.results_file or './results.csv'
     results = []
     errors = []
-    print('Running script on these models: {}'.format(', '.join(model_names)))
+    model_strings = '\n'.join([f'{x[0]}, {x[1]}' for x in model_cfgs])
+    print(f"Running script on these models:\n {model_strings}")
     if not args.sort_key:
         if 'benchmark' in args.script:
             if any(['train' in a for a in args.script_args]):
@@ -136,10 +153,14 @@ def main():
     print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')
 
     try:
-        for m, _ in model_cfgs:
+        for m, ax in model_cfgs:
             if not m:
                 continue
             args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
+            if ax is not None:
+                extra_args = [(f'--{k}', str(v)) for k, v in ax.items()]
+                extra_args = [i for t in extra_args for i in t]
+                args_str += tuple(extra_args)
             try:
                 o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
                 r = json.loads(o)
@@ -157,7 +178,11 @@ def main():
     if errors:
         print(f'{len(errors)} models had errors during run.')
        for e in errors:
-            print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
+            if 'model' in e:
+                print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
+            else:
+                print(e)
+
     results = list(filter(lambda x: 'error' not in x, results))
 
     no_sortkey = list(filter(lambda x: sort_key not in x, results))
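
Usage sketch (not part of the patch): the new branches add two special values for the model
list selector alongside 'all' and wildcard filters. 'all_in1k' restricts the run to pretrained
models with 1000-class heads, while 'all_res' enumerates every model at both its train and test
input resolution and appends an --img-size argument to each child invocation. The commands below
are illustrative only; they assume the selector is exposed as --model-list, that results are
written via --results-file, and that the child script (e.g. benchmark.py) accepts the
--amp / -b flags shown.

    # run the benchmark script over all pretrained 1k-class models
    python bulk_runner.py --model-list all_in1k --results-file bench_in1k.csv benchmark.py --amp -b 256

    # run every model at both train and test resolution (adds --img-size per entry)
    python bulk_runner.py --model-list all_res --results-file bench_res.csv benchmark.py --amp -b 256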