Skip to content

Commit

Permalink
Allow setting any name that an RL agent was trained with.
Browse files Browse the repository at this point in the history
  • Loading branch information
vladfi1 committed Nov 11, 2024
1 parent b6d2776 commit 8857a3c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
26 changes: 19 additions & 7 deletions slippi_ai/eval_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,17 +457,29 @@ def build_delayed_agent(
policy = saving.load_policy_from_state(state)

rl_name = get_name_from_rl_state(state)

if rl_name is not None:
name = rl_name
if isinstance(name, list):
if len(name) > 1:
logging.warning('Agent trained with multiple names, using first.')
# TODO: cycle through names?
name = name[0]
if isinstance(rl_name, str):
rl_name = [rl_name]

override = False

if name is None:
override = True
elif isinstance(name, str) and name not in rl_name:
logging.warning(f'Agent trained with name(s) "{rl_name}", got "{name}"')
override = True
elif isinstance(name, list):
for n in name:
if n not in rl_name:
raise ValueError(f'Agent trained with name(s) {rl_name}, got "{n}"')

logging.info('Setting agent name to "%s" from RL', name)
if override:
logging.info('Setting agent name to "%s" from RL', name)
name = rl_name[0]

if name is None:
# TODO: just pick from the name_map?
raise ValueError('Must specify an agent name.')

if isinstance(name, str):
Expand Down
4 changes: 2 additions & 2 deletions slippi_ai/flag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ def dataclass_from_dict(cls: tp.Type[T], nest: dict) -> T:

def override_dict(
base: dict,
overrides: flags.FlagHolder,
overrides: flags.FlagHolder[dict],
prefix: tp.Sequence[str],
) -> str:
) -> dict:
"""Override a base config value from another dictionary."""

def maybe_update(path: tp.Sequence[str], base_value):
Expand Down

0 comments on commit 8857a3c

Please sign in to comment.