Skip to content

Commit

Permalink
Fix batch size issue with BC (#2965)
Browse files: browse the repository at this point in the history
  • Loading branch information
Ervin T authored and Chris Elion committed Nov 25, 2019
1 parent 2287c06 commit 5d5fe57
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
7 changes: 4 additions & 3 deletions ml-agents/mlagents/trainers/bc/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -122,13 +122,14 @@ def update_policy(self):
"""
self.demonstration_buffer.update_buffer.shuffle(self.policy.sequence_length)
batch_losses = []
+ batch_size = self.n_sequences * self.policy.sequence_length
+ # We either divide the entire buffer into num_batches batches, or limit the number
+ # of batches to batches_per_epoch.
num_batches = min(
- len(self.demonstration_buffer.update_buffer["actions"]) // self.n_sequences,
+ len(self.demonstration_buffer.update_buffer["actions"]) // batch_size,
self.batches_per_epoch,
)

- batch_size = self.n_sequences * self.policy.sequence_length

for i in range(0, num_batches * batch_size, batch_size):
update_buffer = self.demonstration_buffer.update_buffer
mini_batch = update_buffer.make_mini_batch(i, i + batch_size)
Expand Down
10 changes: 6 additions & 4 deletions ml-agents/mlagents/trainers/tests/test_bc.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,15 +25,15 @@ def dummy_config():
use_recurrent: false
sequence_length: 32
memory_size: 32
- batches_per_epoch: 1
+ batches_per_epoch: 100 # Force code to use all possible batches
batch_size: 32
summary_freq: 2000
max_steps: 4000
"""
)


- def create_bc_trainer(dummy_config, is_discrete=False):
+ def create_bc_trainer(dummy_config, is_discrete=False, use_recurrent=False):
mock_env = mock.Mock()
if is_discrete:
mock_brain = mb.create_mock_pushblock_brain()
Expand All @@ -54,15 +54,17 @@ def create_bc_trainer(dummy_config, is_discrete=False):
trainer_parameters["demo_path"] = (
os.path.dirname(os.path.abspath(__file__)) + "/test.demo"
)
+ trainer_parameters["use_recurrent"] = use_recurrent
trainer = BCTrainer(
mock_brain, trainer_parameters, training=True, load=False, seed=0, run_id=0
)
trainer.demonstration_buffer = mb.simulate_rollout(env, trainer.policy, 100)
return trainer, env


- def test_bc_trainer_step(dummy_config):
- trainer, env = create_bc_trainer(dummy_config)
+ @pytest.mark.parametrize("use_recurrent", [True, False])
+ def test_bc_trainer_step(dummy_config, use_recurrent):
+ trainer, env = create_bc_trainer(dummy_config, use_recurrent=use_recurrent)
# Test get_step
assert trainer.get_step == 0
# Test update policy
Expand Down

0 comments on commit 5d5fe57

Please sign in to comment.