diff --git a/slippi_ai/train_lib.py b/slippi_ai/train_lib.py index b51a8c0..1a56c3e 100644 --- a/slippi_ai/train_lib.py +++ b/slippi_ai/train_lib.py @@ -53,7 +53,6 @@ def __init__( self.data_source = data_source self.hidden_state = learner.initial_state(data_source.batch_size) self.step_kwargs = step_kwargs - self.total_frames = 0 self.data_profiler = utils.Profiler() self.step_profiler = utils.Profiler() @@ -63,12 +62,8 @@ def step(self, compiled: bool = True) -> tuple[dict, data_lib.Batch]: with self.step_profiler: stats, self.hidden_state = self.learner.step( batch, self.hidden_state, compile=compiled, **self.step_kwargs) - num_frames = batch.frames.state_action.state.stage.size - self.total_frames += num_frames stats.update( epoch=epoch, - num_frames=num_frames, - total_frames=self.total_frames, ) return stats, batch @@ -352,6 +347,7 @@ def save(): logging.info('loss post-restore: %f', train_loss) FRAMES_PER_MINUTE = 60 * 60 + FRAMES_PER_STEP = config.data.batch_size * config.data.unroll_length step_tracker = utils.Tracker(step.numpy()) epoch_tracker = utils.Tracker(train_stats['epoch']) @@ -365,8 +361,7 @@ def maybe_log(train_stats: dict): elapsed_time = log_tracker.update(time.time()) total_steps = step.numpy() steps = step_tracker.update(total_steps) - # assume num_frames is constant per step - num_frames = steps * train_stats['num_frames'] + num_frames = steps * FRAMES_PER_STEP epoch = train_stats['epoch'] delta_epoch = epoch_tracker.update(epoch) diff --git a/slippi_ai/train_q_lib.py b/slippi_ai/train_q_lib.py index 38f6d6f..ced05b1 100644 --- a/slippi_ai/train_q_lib.py +++ b/slippi_ai/train_q_lib.py @@ -318,6 +318,7 @@ def save(): logging.info('loss post-restore: %f', _get_loss(test_manager.step()[0])) FRAMES_PER_MINUTE = 60 * 60 + FRAMES_PER_STEP = config.data.batch_size * config.data.unroll_length step_tracker = utils.Tracker(step.numpy()) epoch_tracker = utils.Tracker(0) @@ -331,8 +332,7 @@ def maybe_log(train_stats: dict): elapsed_time = log_tracker.update(time.time()) total_steps = step.numpy() steps = step_tracker.update(total_steps) - # assume num_frames is constant per step - num_frames = steps * train_stats['num_frames'] + num_frames = steps * FRAMES_PER_STEP epoch = train_stats['epoch'] delta_epoch = epoch_tracker.update(epoch)