Skip to content

Commit

Permalink
Add a test for trainer metrics (#399)
Browse files Browse the repository at this point in the history
This patch adds a test for the metrics in the trainer class to ensure that they are actually set. This patch just directly inspects the state rather than trying to ensure that we log to tensorboard too, but this should be good enough, and is definitely better than what we had before (nothing).
  • Loading branch information
boomanaiden154 authored Dec 19, 2024
1 parent e314df5 commit fffde33
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion compiler_opt/rl/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@


def _create_test_data(batch_size, sequence_length):
# Use the value zero, which signals the beginning of a sequence, which
# allows us to test the num_trajectories metric.
test_trajectory = trajectory.Trajectory(
step_type=tf.fill([batch_size, sequence_length], 1),
step_type=tf.fill([batch_size, sequence_length], 0),
observation={
'callee_users':
tf.fill([batch_size, sequence_length],
Expand Down Expand Up @@ -131,6 +133,27 @@ def test_training_with_multiple_times(self):
test_trainer.train(dataset_iter, monitor_dict, num_iterations=10)
self.assertEqual(20, test_trainer._global_step.numpy())

def test_training_metrics(self):
test_agent = behavioral_cloning_agent.BehavioralCloningAgent(
self._time_step_spec,
self._action_spec,
self._network,
tf.compat.v1.train.AdadeltaOptimizer(),
num_outer_dims=2)
test_trainer = trainer.Trainer(
root_dir=self.get_temp_dir(), agent=test_agent, summary_log_interval=1)
self.assertEqual(0, test_trainer._data_action_mean.result().numpy())
self.assertEqual(0, test_trainer._data_reward_mean.result().numpy())
self.assertEqual(0, test_trainer._num_trajectories.result().numpy())

dataset_iter = _create_test_data(batch_size=3, sequence_length=3)
monitor_dict = {'default': {'test': 1}}
test_trainer.train(dataset_iter, monitor_dict, num_iterations=10)

self.assertEqual(1, test_trainer._data_action_mean.result().numpy())
self.assertEqual(2, test_trainer._data_reward_mean.result().numpy())
self.assertEqual(90, test_trainer._num_trajectories.result().numpy())

def test_inference(self):
test_agent = behavioral_cloning_agent.BehavioralCloningAgent(
self._time_step_spec,
Expand Down

0 comments on commit fffde33

Please sign in to comment.