Skip to content

Commit

Permalink
Added variation flag to dataset_generator.
Browse files Browse the repository at this point in the history
  • Loading branch information
stepjam committed Jun 21, 2020
1 parent 0f48889 commit 5a82cef
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tools/dataset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
'The number of parallel processes during collection.')
flags.DEFINE_integer('episodes_per_task', 10,
'The number of episodes to collect per task.')
flags.DEFINE_integer('variations', -1,
'Number of variations to collect per task. -1 for all.')


def check_and_make(dir):
Expand Down Expand Up @@ -187,7 +189,10 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
my_variation_count = variation_count.value
t = tasks[task_index.value]
task_env = rlbench_env.get_task(t)
if my_variation_count >= task_env.variation_count():
var_target = task_env.variation_count()
if FLAGS.variations >= 0:
var_target = np.minimum(FLAGS.variations, var_target)
if my_variation_count >= var_target:
# If we have reached the required number of variations for this
# task, then move on to the next task.
variation_count.value = my_variation_count = 0
Expand Down Expand Up @@ -250,7 +255,6 @@ def run(i, lock, task_index, variation_count, results, file_lock, tasks):
break

results[i] = tasks_with_problems

rlbench_env.shutdown()


Expand Down

0 comments on commit 5a82cef

Please sign in to comment.