Skip to content

Commit

Permalink
Fix scaling: prune RCPs by mean epochs
Browse files Browse the repository at this point in the history
  • Loading branch information
pgmpablo157321 committed May 22, 2024
1 parent 4054425 commit f3057c7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mlperf_logging/rcp_checker/rcp_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _prune_rcps(self):
# Step 1
# Find point with fastest convergence and prune all point with smaller batch size
# In that way the min batch size point will have the fastest convergenece
fastest_conv = min(min_epochs, key=lambda rc: rc['Min Epochs'])
fastest_conv = min(min_epochs, key=lambda rc: rc['RCP Mean'])
min_epochs = list(filter(lambda rc: rc['BS'] >= fastest_conv['BS'], min_epochs))

# Step 2
Expand All @@ -249,7 +249,7 @@ def _prune_rcps(self):
rcp_max = min_epochs[i+1]
bs = min_epochs[i]['BS']
name, rcp = self._create_interp_rcp(bs, rcp_min, rcp_max)
if min_epochs[i]['Min Epochs'] > rcp['Min Epochs']:
if min_epochs[i]['RCP Mean'] > rcp['RCP Mean']:
del min_epochs[i]
i = i-1
list_len = list_len - 1
Expand Down

0 comments on commit f3057c7

Please sign in to comment.