# gen_action_sets.py
# Generates train/test action-set splits and saves them as set_train.npy / set_test.npy.
from arguments import get_args
import random
import os.path as osp
import os
import numpy as np
from envs.gym_minigrid.action_sets import create_action_bank
from envs.create_game.tool_gen import ToolGenerator
from envs.block_stack.poly_gen import *
from envs.recogym.action_set import get_rnd_action_sets
# Example usage: python scripts/gen_action_sets.py --action-seg-loc envs/action_segs_new --env-name MiniGrid
def _save_split(out_dir, train, test):
    """Save the train/test splits as ``set_train.npy`` / ``set_test.npy`` in ``out_dir``.

    ``out_dir`` (and any missing parents) is created if it does not exist.
    """
    # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
    os.makedirs(out_dir, exist_ok=True)
    for split_name, split in (('train', train), ('test', test)):
        with open(osp.join(out_dir, 'set_%s.npy' % split_name), 'wb') as f:
            np.save(f, split)


def _print_split_sizes(train, test):
    """Print how many entries each split contains."""
    print('Training set: ', len(train))
    print('Test set: ', len(test))


# Build a train/test split of actions for the environment family selected by
# --env-name (MiniGrid, Create, Stack or Reco) and save it under
# <action_seg_loc>/<family>_<exp_type>/.
args = get_args()

if args.env_name.startswith('MiniGrid'):
    # create_action_bank is expected to fill in args.action_bank, which is
    # read immediately below.
    create_action_bank(args)
    n_skills = len(args.action_bank)
    if args.exp_type == 'rnd':
        train, test = get_rnd_action_sets(n_skills)
    elif args.exp_type == 'all':
        # Both splits contain every skill, each in an independent random order.
        train = np.arange(n_skills)
        test = np.arange(n_skills)
        random.shuffle(train)
        random.shuffle(test)
    else:
        raise ValueError('Invalid exp type')
    _save_split(osp.join(args.action_seg_loc, 'grid_%s' % args.exp_type),
                train, test)
    _print_split_sizes(train, test)

elif args.env_name.startswith('Create'):
    tool_gen = ToolGenerator(args.gran_factor)
    train_tools, test_tools = tool_gen.get_train_test_split(args)
    # Randomize the order of tools within each split.
    np.random.shuffle(train_tools)
    np.random.shuffle(test_tools)
    # The split type only becomes part of the directory name for 'New*'
    # experiment types.
    if args.split_type is not None and 'New' in args.exp_type:
        add_str = '_' + args.split_type
    else:
        add_str = ''
    _save_split(osp.join(args.action_seg_loc, 'create_' + args.exp_type + add_str),
                train_tools, test_tools)
    _print_split_sizes(train_tools, test_tools)

elif args.env_name.startswith('Stack'):
    all_polys, _, polygon_types = gen_polys('envs/block_stack/assets/stl/')
    if args.exp_type == 'rnd':
        train, test = rnd_train_test_split(all_polys, polygon_types)
    elif args.exp_type == 'full':
        train, test = full_train_test_split(all_polys)
    elif args.exp_type == 'all':
        # Both splits contain every polygon, each in an independent random order.
        train = np.arange(len(all_polys))
        test = np.arange(len(all_polys))
        random.shuffle(train)
        random.shuffle(test)
    else:
        raise ValueError('Invalid exp type')
    # Summarize which polygon types ended up in each split.
    train_types = {all_polys[i].type for i in train}
    test_types = {all_polys[i].type for i in test}
    print('')
    print('In train (%i)' % len(train))
    for t in train_types:
        print(' - %s' % (t))
    print('')
    print('In test (%i)' % len(test))
    for t in test_types:
        print(' - %s' % (t))
    _save_split(osp.join(args.action_seg_loc, 'stack_%s' % args.exp_type),
                train, test)

elif args.env_name.startswith('Reco'):
    if args.exp_type == 'rnd':
        train, test = get_rnd_action_sets(args.reco_n_prods)
    else:
        raise ValueError('Invalid exp type')
    _save_split(osp.join(args.action_seg_loc, 'reco_%s' % args.exp_type),
                train, test)
    _print_split_sizes(train, test)

else:
    # Fail loudly (non-zero exit) instead of printing and exiting successfully
    # without generating anything; matches the invalid-exp-type handling above.
    raise ValueError('Unspecified Environment! (env_name=%s)' % args.env_name)