Skip to content

Commit

Permalink
Correct fps/mps estimate for imitation/q-learning.
Browse files: browse the repository at this point in the history
  • Loading branch information
vladfi1 committed Oct 23, 2024
1 parent 053032f commit 2f43684
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
9 changes: 2 additions & 7 deletions slippi_ai/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def __init__(
self.data_source = data_source
self.hidden_state = learner.initial_state(data_source.batch_size)
self.step_kwargs = step_kwargs
self.total_frames = 0
self.data_profiler = utils.Profiler()
self.step_profiler = utils.Profiler()

Expand All @@ -63,12 +62,8 @@ def step(self, compiled: bool = True) -> tuple[dict, data_lib.Batch]:
with self.step_profiler:
stats, self.hidden_state = self.learner.step(
batch, self.hidden_state, compile=compiled, **self.step_kwargs)
num_frames = batch.frames.state_action.state.stage.size
self.total_frames += num_frames
stats.update(
epoch=epoch,
num_frames=num_frames,
total_frames=self.total_frames,
)
return stats, batch

Expand Down Expand Up @@ -352,6 +347,7 @@ def save():
logging.info('loss post-restore: %f', train_loss)

FRAMES_PER_MINUTE = 60 * 60
FRAMES_PER_STEP = config.data.batch_size * config.data.unroll_length

step_tracker = utils.Tracker(step.numpy())
epoch_tracker = utils.Tracker(train_stats['epoch'])
Expand All @@ -365,8 +361,7 @@ def maybe_log(train_stats: dict):
elapsed_time = log_tracker.update(time.time())
total_steps = step.numpy()
steps = step_tracker.update(total_steps)
# assume num_frames is constant per step
num_frames = steps * train_stats['num_frames']
num_frames = steps * FRAMES_PER_STEP

epoch = train_stats['epoch']
delta_epoch = epoch_tracker.update(epoch)
Expand Down
4 changes: 2 additions & 2 deletions slippi_ai/train_q_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def save():
logging.info('loss post-restore: %f', _get_loss(test_manager.step()[0]))

FRAMES_PER_MINUTE = 60 * 60
FRAMES_PER_STEP = config.data.batch_size * config.data.unroll_length

step_tracker = utils.Tracker(step.numpy())
epoch_tracker = utils.Tracker(0)
Expand All @@ -331,8 +332,7 @@ def maybe_log(train_stats: dict):
elapsed_time = log_tracker.update(time.time())
total_steps = step.numpy()
steps = step_tracker.update(total_steps)
# assume num_frames is constant per step
num_frames = steps * train_stats['num_frames']
num_frames = steps * FRAMES_PER_STEP

epoch = train_stats['epoch']
delta_epoch = epoch_tracker.update(epoch)
Expand Down

0 comments on commit 2f43684

Please sign in to comment.