diff --git a/ml-agents/mlagents/trainers/bc/trainer.py b/ml-agents/mlagents/trainers/bc/trainer.py index a0251b3b64..757cdaef2f 100644 --- a/ml-agents/mlagents/trainers/bc/trainer.py +++ b/ml-agents/mlagents/trainers/bc/trainer.py @@ -122,13 +122,14 @@ def update_policy(self): """ self.demonstration_buffer.update_buffer.shuffle(self.policy.sequence_length) batch_losses = [] + batch_size = self.n_sequences * self.policy.sequence_length + # We either divide the entire buffer into num_batches batches, or limit the number + # of batches to batches_per_epoch. num_batches = min( - len(self.demonstration_buffer.update_buffer["actions"]) // self.n_sequences, + len(self.demonstration_buffer.update_buffer["actions"]) // batch_size, self.batches_per_epoch, ) - batch_size = self.n_sequences * self.policy.sequence_length - for i in range(0, num_batches * batch_size, batch_size): update_buffer = self.demonstration_buffer.update_buffer mini_batch = update_buffer.make_mini_batch(i, i + batch_size) diff --git a/ml-agents/mlagents/trainers/tests/test_bc.py b/ml-agents/mlagents/trainers/tests/test_bc.py index a5ed13e6e0..19c2a73187 100644 --- a/ml-agents/mlagents/trainers/tests/test_bc.py +++ b/ml-agents/mlagents/trainers/tests/test_bc.py @@ -25,7 +25,7 @@ def dummy_config(): use_recurrent: false sequence_length: 32 memory_size: 32 - batches_per_epoch: 1 + batches_per_epoch: 100 # Force code to use all possible batches batch_size: 32 summary_freq: 2000 max_steps: 4000 @@ -33,7 +33,7 @@ def dummy_config(): ) -def create_bc_trainer(dummy_config, is_discrete=False): +def create_bc_trainer(dummy_config, is_discrete=False, use_recurrent=False): mock_env = mock.Mock() if is_discrete: mock_brain = mb.create_mock_pushblock_brain() @@ -54,6 +54,7 @@ def create_bc_trainer(dummy_config, is_discrete=False): trainer_parameters["demo_path"] = ( os.path.dirname(os.path.abspath(__file__)) + "/test.demo" ) + trainer_parameters["use_recurrent"] = use_recurrent trainer = BCTrainer( mock_brain, trainer_parameters, training=True, load=False, seed=0, run_id=0 ) @@ -61,8 +62,9 @@ def create_bc_trainer(dummy_config, is_discrete=False): return trainer, env -def test_bc_trainer_step(dummy_config): - trainer, env = create_bc_trainer(dummy_config) +@pytest.mark.parametrize("use_recurrent", [True, False]) +def test_bc_trainer_step(dummy_config, use_recurrent): + trainer, env = create_bc_trainer(dummy_config, use_recurrent=use_recurrent) # Test get_step assert trainer.get_step == 0 # Test update policy