Skip to content

Commit

Permalink
[rllib] Better error message for unsupported non-atari image observat…
Browse files Browse the repository at this point in the history
…ion sizes (#3444)
  • Loading branch information
ericl authored Dec 3, 2018
1 parent 4abafd7 commit 7abfbfd
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
17 changes: 10 additions & 7 deletions python/ray/rllib/models/visionnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _build_layers_v2(self, input_dict, num_outputs, options):
inputs = input_dict["obs"]
filters = options.get("conv_filters")
if not filters:
filters = get_filter_config(options)
filters = get_filter_config(inputs)

activation = get_activation_fn(options.get("conv_activation"))

Expand Down Expand Up @@ -47,7 +47,7 @@ def _build_layers_v2(self, input_dict, num_outputs, options):
return flatten(fc2), flatten(fc1)


def get_filter_config(options):
def get_filter_config(inputs):
filters_84x84 = [
[16, [8, 8], 4],
[32, [4, 4], 2],
Expand All @@ -58,12 +58,15 @@ def get_filter_config(options):
[32, [4, 4], 2],
[256, [11, 11], 1],
]
dim = options.get("dim")
if dim == 84:
shape = inputs.shape.as_list()[1:]
if len(shape) == 3 and shape[:2] == [84, 84]:
return filters_84x84
elif dim == 42:
elif len(shape) == 3 and shape[:2] == [42, 42]:
return filters_42x42
else:
raise ValueError(
"No default configuration for image size={}".format(dim) +
", you must specify `conv_filters` manually as a model option.")
"No default configuration for obs input {}".format(inputs) +
", you must specify `conv_filters` manually as a model option. "
"Default configurations are only available for inputs of size "
"[?, 42, 42, K] and [?, 84, 84, K]. You may alternatively want "
"to use a custom model or preprocessor.")
4 changes: 2 additions & 2 deletions python/ray/rllib/test/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ def testDefaultModels(self):

with tf.variable_scope("test1"):
p1 = ModelCatalog.get_model({
"obs": np.zeros((10, 3), dtype=np.float32)
"obs": tf.zeros((10, 3), dtype=tf.float32)
}, Box(0, 1, shape=(3, ), dtype=np.float32), 5, {})
self.assertEqual(type(p1), FullyConnectedNetwork)

with tf.variable_scope("test2"):
p2 = ModelCatalog.get_model({
"obs": np.zeros((10, 84, 84, 3), dtype=np.float32)
"obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32)
}, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), 5, {})
self.assertEqual(type(p2), VisionNetwork)

Expand Down

0 comments on commit 7abfbfd

Please sign in to comment.