Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

local human agent #110

Merged
merged 2 commits into from
May 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions parlai/agents/local_human/agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
"""Agent does gets the local keyboard input in the act() function.
Example: python examples/eval_model.py -m local_human -t babi:Task1k:1 -dt valid
"""

from parlai.core.agents import Agent
from parlai.core.worlds import display_messages

class LocalHumanAgent(Agent):

def __init__(self, opt, shared=None):
super().__init__(opt)
self.id = 'localHuman'

def observe(self, msg):
print(display_messages([msg]))

def act(self):
obs = self.observation
reply = {}
reply['id'] = self.getID()
reply['text'] = input("Enter Your Reply: ")
return reply
75 changes: 41 additions & 34 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,46 @@ def validate(observation):
else:
raise RuntimeError('Must return dictionary from act().')

def display_messages(msgs):
"""Returns a string describing the set of messages provided"""
lines = []
episode_done = False
for index, msg in enumerate(msgs):
if msg is None:
continue
if msg.get('episode_done', False):
episode_done = True
# Possibly indent the text (for the second speaker, if two).
space = ''
if len(msgs) == 2 and index == 1:
space = ' '
if msg.get('reward', None) is not None:
lines.append(space + '[reward: {r}]'.format(r=msg['reward']))
if msg.get('text', ''):
ID = '[' + msg['id'] + ']: ' if 'id' in msg else ''
lines.append(space + ID + msg['text'])
if msg.get('labels', False):
lines.append(space + ('[labels: {}]'.format(
'|'.join(msg['labels']))))
if msg.get('label_candidates', False):
cand_len = len(msg['label_candidates'])
if cand_len <= 10:
lines.append(space + ('[cands: {}]'.format(
'|'.join(msg['label_candidates']))))
else:
# select five label_candidates from the candidate set,
# can't slice in because it's a set
cand_iter = iter(msg['label_candidates'])
display_cands = (next(cand_iter) for _ in range(5))
# print those cands plus how many cands remain
lines.append(space + ('[cands: {}{}]'.format(
'|'.join(display_cands),
'| ...and {} more'.format(cand_len - 5)
)))
if episode_done:
lines.append('- - - - - - - - - - - - - - - - - - - - -')
return '\n'.join(lines)


class World(object):
"""Empty parent providing null definitions of API functions for Worlds.
Expand Down Expand Up @@ -91,40 +131,7 @@ def display(self):
By default, display the messages between the agents."""
if not hasattr(self, 'acts'):
return ''
lines = []
for index, msg in enumerate(self.acts):
if msg is None:
continue
# Possibly indent the text (for the second speaker, if two).
space = ''
if len(self.acts) == 2 and index == 1:
space = ' '
if msg.get('reward', None) is not None:
lines.append(space + '[reward: {r}]'.format(r=msg['reward']))
if msg.get('text', ''):
ID = '[' + msg['id'] + ']: ' if 'id' in msg else ''
lines.append(space + ID + msg['text'])
if msg.get('labels', False):
lines.append(space + ('[labels: {}]'.format(
'|'.join(msg['labels']))))
if msg.get('label_candidates', False):
cand_len = len(msg['label_candidates'])
if cand_len <= 10:
lines.append(space + ('[cands: {}]'.format(
'|'.join(msg['label_candidates']))))
else:
# select five label_candidates from the candidate set,
# can't slice in because it's a set
cand_iter = iter(msg['label_candidates'])
display_cands = (next(cand_iter) for _ in range(5))
# print those cands plus how many cands remain
lines.append(space + ('[cands: {}{}]'.format(
'|'.join(display_cands),
'| ...and {} more'.format(cand_len - 5)
)))
if self.episode_done():
lines.append('- - - - - - - - - - - - - - - - - - - - -')
return '\n'.join(lines)
return display_messages(self.acts)

def episode_done(self):
"""Whether the episode is done or not. """
Expand Down