-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathembed_test.py
144 lines (116 loc) · 4.15 KB
/
embed_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from unittest import TestCase
import dispatch_embed
import experiment
import observer
import unittest
from feeder import LstmFeeder, MlpFeeder
from loop import (
OpenLoop,
ClosedLoop,
ClosedLoopMultiSample,
ClosedLoopOverwriteLatents,
uniform_noise,
invert_noise,
clamp_noise,
)
from system import System
def set_up_experiment(params):
    """Build an ``experiment.Experiment`` from one test-parameter entry.

    Args:
        params: A single parameter mapping (one entry of the value returned by
            ``dispatch_embed.build_params``). Keys read here: ``ref_path``,
            ``model_dir``, ``dataset``, ``stac_params``, ``offset_path``,
            ``ref_steps``, ``torque_actuators``, ``latent_noise``,
            ``noise_gain``, ``lstm`` and ``save_dir``.

    Returns:
        An ``experiment.Experiment`` combining:
        - a ``System`` starting at step 0;
        - an LSTM or MLP observer and feeder, selected by ``params["lstm"]``;
        - a ``ClosedLoop`` starting at step 0, with ``video_length=2500`` and
          ``action_noise=False``.
    """
    system = System(
        ref_path=params["ref_path"],
        model_dir=params["model_dir"],
        dataset=params["dataset"],
        stac_params=params["stac_params"],
        offset_path=params["offset_path"],
        # Use start_step=357500 here instead to exercise the end-of-loop
        # handling.
        start_step=0,
        ref_steps=tuple(params["ref_steps"]),
        torque_actuators=params["torque_actuators"],
        latent_noise=params["latent_noise"],
        noise_gain=params["noise_gain"],
    )

    # The observer and feeder must agree on the network type.
    if params["lstm"]:
        obs = observer.LstmObserver(system.environment, params["save_dir"])
        feeder = LstmFeeder()
    else:
        obs = observer.MlpObserver(system.environment, params["save_dir"])
        feeder = MlpFeeder()

    # Alternative looper, for overwriting latents:
    #   ClosedLoopOverwriteLatents(system.environment, feeder, start_step=0,
    #                              video_length=2500)
    loop = ClosedLoop(
        system.environment,
        feeder,
        start_step=0,
        video_length=2500,
        action_noise=False,
    )
    return experiment.Experiment(system, obs, loop)
def change_exp_model(exp):
    """Point ``exp`` at the other (MLP <-> LSTM) model configuration, in place.

    If ``exp.observer`` is an ``observer.MlpObserver``:
    - the model directory is taken from ``params[1]``;
    - the observer is set up with ``observer.LSTM_NETWORK_FEATURES``;
    - the looper receives a new ``LstmFeeder``.

    Otherwise:
    - the model directory comes from ``params[0]``;
    - the observer is set up with ``observer.MLP_NETWORK_FEATURES``;
    - the looper receives a new ``MlpFeeder``.

    ``params`` here is the module-level value built from
    ``test_params.yaml`` via ``dispatch_embed.build_params``; it must already
    exist when this function is called.

    NOTE(review): the branch is chosen from the observer's class, and this
    function never reassigns ``exp.observer``. Unless
    ``setup_model_ovservables`` changes the observer's class, a second call
    takes the same branch again rather than switching back.

    Args:
        exp: The experiment to modify.

    Returns:
        The same ``exp`` object, after modification.
    """
    if not isinstance(exp.observer, observer.MlpObserver):
        # Non-MLP observer: switch everything to the MLP configuration.
        exp.system.model_dir = params[0]["model_dir"]
        exp.observer.setup_model_ovservables(observer.MLP_NETWORK_FEATURES)
        exp.looper.feeder = MlpFeeder()
        return exp

    # MLP observer: switch everything to the LSTM configuration.
    exp.system.model_dir = params[1]["model_dir"]
    exp.observer.setup_model_ovservables(observer.LSTM_NETWORK_FEATURES)
    exp.looper.feeder = LstmFeeder()
    return exp
# TODO: figure out how to cleanly set up the dm_control environment
# without camera namespace conflicts.
# HACK: to avoid problems with overlapping camera namespaces, a single
# experiment (EXP) is built once, at module import time, and is reused by the
# tests below.
params = dispatch_embed.build_params("test_params.yaml")
# Test MLP
# NOTE(review): index 3 is labelled "MLP" here, while change_exp_model treats
# params[0] as the MLP entry and params[1] as the LSTM entry. Confirm which
# configuration test_params.yaml lists at index 3.
EXP = set_up_experiment(params[3])
# Test LSTM (swap in instead of the line above):
# EXP = set_up_experiment(params[1])
# class ExperimentTest(absltest.TestCase):
# def test_setup(self):
# self.assertTrue(isinstance(EXP, experiment.Experiment))
# def test_run_mlp(self):
# EXP.run()
# class ObserverTest(absltest.TestCase):
# def clear_observations(self):
# EXP.observer.cam_list = []
# def setUp(self):
# self.clear_observations()
# def tearDown(self):
# self.clear_observations()
# def test_grab_frame_no_segmentation_mlp(self):
# self.grab_frame(EXP, False)
# def test_grab_frame_segmentation_mlp(self):
# self.grab_frame(EXP, True)
# def grab_frame(self, exp, seg_frames):
# exp.observer.seg_frames = seg_frames
# exp.observer.grab_frame()
# self.assertEqual(exp.observer.cam_list[0].shape, tuple(observer.IMAGE_SIZE))
class LoopTest(TestCase):
    """Run the shared module-level experiment ``EXP`` under different loops.

    Each test replaces ``EXP.looper`` with a freshly built loop of the
    requested type and then calls ``EXP.run()``. There are no assertions: a
    test passes as long as the run completes without raising.
    """

    def loop(self, loop_fn, exp):
        """Rebuild ``exp.looper`` with ``loop_fn`` and run the experiment.

        The new loop is created positionally as
        ``loop_fn(environment, feeder, start_step, video_length)``. It reuses
        the environment of ``exp.system`` and the feeder, start step and video
        length of the looper being replaced.

        NOTE(review): only those four values are forwarded. Any other option
        the previous looper was built with (e.g. ``action_noise``) takes
        whatever value ``loop_fn`` uses when it is omitted.

        Args:
            loop_fn: Loop class (or factory) to instantiate.
            exp: Experiment whose ``looper`` is replaced in place.
        """
        environment = exp.system.environment
        previous = exp.looper
        exp.looper = loop_fn(
            environment,
            previous.feeder,
            previous.start_step,
            previous.video_length,
        )
        exp.run()

    def test_closed(self):
        """Run EXP through a newly built ClosedLoop."""
        self.loop(ClosedLoop, EXP)

    # ------------------------------------------------------------------
    # Disabled tests, kept for reference.
    # ------------------------------------------------------------------
    #
    # def test_open(self):
    #     self.loop(OpenLoop, EXP)
    #
    # def test_closed_loop_overwrite_latents(self):
    #     EXP.looper = ClosedLoopOverwriteLatents(
    #         EXP.system.environment,
    #         EXP.looper.feeder,
    #         EXP.looper.start_step,
    #         EXP.looper.video_length,
    #         lambda sess, feed_dict: clamp_noise(sess, feed_dict, "standard"),
    #         action_noise=True,
    #     )
    #     EXP.run()
    #
    # def test_closed_multi_sample(self):
    #     self.loop(ClosedLoopMultiSample, EXP)
    #
    # def test_end_loop(self):
    #     (no body was written)
    #
    # NOTE(review): the two tests below refer to `lstm_exp`, which is not
    # defined anywhere in this file. They also use `experiment.OpenLoop` /
    # `experiment.ClosedLoop`, whereas this file imports OpenLoop and
    # ClosedLoop from `loop`.
    # def test_open_lstm(self):
    #     self.loop(experiment.OpenLoop, lstm_exp)
    #
    # def test_closed_lstm(self):
    #     self.loop(experiment.ClosedLoop, lstm_exp)
# When this file is executed directly, unittest.main() discovers and runs the
# TestCase classes defined in this module (currently LoopTest).
if __name__ == "__main__":
    unittest.main()