#!/usr/bin/env python
# alfworld-play-tw -- executable script (forked from alfworld/alfworld).
import os
import json
import glob
import random
import argparse
from os.path import join as pjoin
import textworld
from textworld.agents import HumanAgent
import textworld.gym
from alfworld.info import ALFWORLD_DATA
from alfworld.agents.utils.misc import add_task_to_grammar
from alfworld.agents.environment.alfred_tw_env import AlfredExpert, AlfredDemangler, AlfredExpertType
def main(args):
    """Build a TextWorld game file for one ALFRED problem and play it by hand.

    The PDDL domain (``args.domain``) and the TWL2 grammar (``args.grammar``)
    are combined with ``initial_state.pddl`` and ``traj_data.json`` found in
    the ``args.problem`` folder. The result is written as ``game.tw-pddl``
    in that same folder, registered as a Gym environment, and then driven by
    a ``HumanAgent`` until the environment reports the episode as done.

    Args:
        args: Parsed command-line namespace providing the ``problem``,
            ``domain`` and ``grammar`` attributes (all filesystem paths).
    """
    print(f"Playing '{args.problem}'.")

    # Read the shared game logic; context managers make sure every handle is
    # closed instead of relying on garbage collection.
    with open(args.domain) as domain_fh:
        pddl_domain = domain_fh.read()
    with open(args.grammar) as grammar_fh:
        grammar = grammar_fh.read()

    GAME_LOGIC = {
        "pddl_domain": pddl_domain,
        "grammar": grammar,
    }

    # load state and trajectory files
    pddl_file = os.path.join(args.problem, 'initial_state.pddl')
    json_file = os.path.join(args.problem, 'traj_data.json')
    with open(json_file, 'r') as f:
        traj_data = json.load(f)

    GAME_LOGIC['grammar'] = add_task_to_grammar(GAME_LOGIC['grammar'], traj_data)

    # dump game file
    with open(pddl_file) as problem_fh:
        gamedata = dict(**GAME_LOGIC, pddl_problem=problem_fh.read())

    gamefile = os.path.join(os.path.dirname(pddl_file), 'game.tw-pddl')
    # Closing the file here guarantees the JSON is fully flushed to disk
    # before the environment below is asked to load it.
    with open(gamefile, "w") as game_fh:
        json.dump(gamedata, game_fh)

    expert = AlfredExpert(expert_type=AlfredExpertType.PLANNER)
    # expert = AlfredExpert(expert_type=AlfredExpertType.HANDCODED)

    # register a new Gym environment.
    request_infos = textworld.EnvInfos(won=True, admissible_commands=True, score=True, max_score=True, intermediate_reward=True, extras=["expert_plan"])
    env_id = textworld.gym.register_game(gamefile, request_infos,
                                         max_episode_steps=1000000,
                                         wrappers=[AlfredDemangler(), expert])

    # reset env
    env = textworld.gym.make(env_id)
    obs, infos = env.reset()

    # human agent; the oracle flag is only enabled for the planner expert.
    agent = HumanAgent(autocompletion=True, oracle=(expert.expert_type==AlfredExpertType.PLANNER))
    agent.reset(env)

    while True:
        print(obs)
        # Reward and done are passed as constants (0, False) to the agent.
        cmd = agent.act(infos, 0, False)

        # Typing "ipdb" opens a debugger instead of sending a game command.
        if cmd == "ipdb":
            from ipdb import set_trace; set_trace()
            continue

        obs, score, done, infos = env.step(cmd)

        if done:
            # NOTE(review): printed whenever the env reports `done`; confirm
            # that `done` can only mean a win for these games.
            print("You won!")
            break
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, pick a problem folder
    # (at random when none is given), reject unsupported problems, then play.
    description = "Play the abstract text version of an ALFRED environment."
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("problem", nargs="?", default=None,
                        help="Path to a folder containing PDDL and traj_data files."
                             f" Default: pick one at random found in {ALFWORLD_DATA}")
    parser.add_argument("--domain",
                        default=pjoin(ALFWORLD_DATA, "logic", "alfred.pddl"),
                        help="Path to a PDDL file describing the domain."
                             " Default: `%(default)s`.")
    parser.add_argument("--grammar",
                        default=pjoin(ALFWORLD_DATA, "logic", "alfred.twl2"),
                        help="Path to a TWL2 file defining the grammar used to generated text feedbacks."
                             " Default: `%(default)s`.")
    args = parser.parse_args()

    if args.problem is None:
        problems = glob.glob(pjoin(ALFWORLD_DATA, "**", "initial_state.pddl"), recursive=True)
        # Remove problem which contains movable receptacles.
        problems = [p for p in problems if "movable_recep" not in p]
        if not problems:
            # Fail with a readable message instead of the IndexError that
            # random.choice() raises on an empty list.
            parser.error(f"No usable problem found under '{ALFWORLD_DATA}'; "
                         "pass a problem folder explicitly.")
        args.problem = os.path.dirname(random.choice(problems))

    # An explicitly given problem may still be an unsupported one.
    if "movable_recep" in args.problem:
        raise ValueError("This problem contains movable receptacles, which is not supported by ALFWorld.")

    main(args)