diff --git a/parlai/agents/local_human/agents.py b/parlai/agents/local_human/agents.py new file mode 100644 index 00000000000..969adc8045f --- /dev/null +++ b/parlai/agents/local_human/agents.py @@ -0,0 +1,27 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. An additional grant +# of patent rights can be found in the PATENTS file in the same directory. +"""Agent does gets the local keyboard input in the act() function. + Example: python examples/eval_model.py -m local_human -t babi:Task1k:1 -dt valid +""" + +from parlai.core.agents import Agent +from parlai.core.worlds import display_messages + +class LocalHumanAgent(Agent): + + def __init__(self, opt, shared=None): + super().__init__(opt) + self.id = 'localHuman' + + def observe(self, msg): + print(display_messages([msg])) + + def act(self): + obs = self.observation + reply = {} + reply['id'] = self.getID() + reply['text'] = input("Enter Your Reply: ") + return reply diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index 6523b2d0b58..15bae0c0cc4 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -61,6 +61,46 @@ def validate(observation): else: raise RuntimeError('Must return dictionary from act().') +def display_messages(msgs): + """Returns a string describing the set of messages provided""" + lines = [] + episode_done = False + for index, msg in enumerate(msgs): + if msg is None: + continue + if msg.get('episode_done', False): + episode_done = True + # Possibly indent the text (for the second speaker, if two). + space = '' + if len(msgs) == 2 and index == 1: + space = ' ' + if msg.get('reward', None) is not None: + lines.append(space + '[reward: {r}]'.format(r=msg['reward'])) + if msg.get('text', ''): + ID = '[' + msg['id'] + ']: ' if 'id' in msg else '' + lines.append(space + ID + msg['text']) + if msg.get('labels', False): + lines.append(space + ('[labels: {}]'.format( + '|'.join(msg['labels'])))) + if msg.get('label_candidates', False): + cand_len = len(msg['label_candidates']) + if cand_len <= 10: + lines.append(space + ('[cands: {}]'.format( + '|'.join(msg['label_candidates'])))) + else: + # select five label_candidates from the candidate set, + # can't slice in because it's a set + cand_iter = iter(msg['label_candidates']) + display_cands = (next(cand_iter) for _ in range(5)) + # print those cands plus how many cands remain + lines.append(space + ('[cands: {}{}]'.format( + '|'.join(display_cands), + '| ...and {} more'.format(cand_len - 5) + ))) + if episode_done: + lines.append('- - - - - - - - - - - - - - - - - - - - -') + return '\n'.join(lines) + class World(object): """Empty parent providing null definitions of API functions for Worlds. @@ -91,40 +131,7 @@ def display(self): By default, display the messages between the agents.""" if not hasattr(self, 'acts'): return '' - lines = [] - for index, msg in enumerate(self.acts): - if msg is None: - continue - # Possibly indent the text (for the second speaker, if two). - space = '' - if len(self.acts) == 2 and index == 1: - space = ' ' - if msg.get('reward', None) is not None: - lines.append(space + '[reward: {r}]'.format(r=msg['reward'])) - if msg.get('text', ''): - ID = '[' + msg['id'] + ']: ' if 'id' in msg else '' - lines.append(space + ID + msg['text']) - if msg.get('labels', False): - lines.append(space + ('[labels: {}]'.format( - '|'.join(msg['labels'])))) - if msg.get('label_candidates', False): - cand_len = len(msg['label_candidates']) - if cand_len <= 10: - lines.append(space + ('[cands: {}]'.format( - '|'.join(msg['label_candidates'])))) - else: - # select five label_candidates from the candidate set, - # can't slice in because it's a set - cand_iter = iter(msg['label_candidates']) - display_cands = (next(cand_iter) for _ in range(5)) - # print those cands plus how many cands remain - lines.append(space + ('[cands: {}{}]'.format( - '|'.join(display_cands), - '| ...and {} more'.format(cand_len - 5) - ))) - if self.episode_done(): - lines.append('- - - - - - - - - - - - - - - - - - - - -') - return '\n'.join(lines) + return display_messages(self.acts) def episode_done(self): """Whether the episode is done or not. """