diff --git a/slippi_ai/eval_lib.py b/slippi_ai/eval_lib.py index bb19695..c874ca0 100644 --- a/slippi_ai/eval_lib.py +++ b/slippi_ai/eval_lib.py @@ -457,17 +457,29 @@ def build_delayed_agent( policy = saving.load_policy_from_state(state) rl_name = get_name_from_rl_state(state) + if rl_name is not None: - name = rl_name - if isinstance(name, list): - if len(name) > 1: - logging.warning('Agent trained with multiple names, using first.') - # TODO: cycle through names? - name = name[0] + if isinstance(rl_name, str): + rl_name = [rl_name] + + override = False + + if name is None: + override = True + elif isinstance(name, str) and name not in rl_name: + logging.warning(f'Agent trained with name(s) "{rl_name}", got "{name}"') + override = True + elif isinstance(name, list): + for n in name: + if n not in rl_name: + raise ValueError(f'Agent trained with name(s) {rl_name}, got "{n}"') - logging.info('Setting agent name to "%s" from RL', name) + if override: + logging.info('Setting agent name to "%s" from RL', name) + name = rl_name[0] if name is None: + # TODO: just pick from the name_map? raise ValueError('Must specify an agent name.') if isinstance(name, str): diff --git a/slippi_ai/flag_utils.py b/slippi_ai/flag_utils.py index d53addb..42bc822 100644 --- a/slippi_ai/flag_utils.py +++ b/slippi_ai/flag_utils.py @@ -129,9 +129,9 @@ def dataclass_from_dict(cls: tp.Type[T], nest: dict) -> T: def override_dict( base: dict, - overrides: flags.FlagHolder, + overrides: flags.FlagHolder[dict], prefix: tp.Sequence[str], -) -> str: +) -> dict: """Override a base config value from another dictionary.""" def maybe_update(path: tp.Sequence[str], base_value):