Skip to content

Commit

Permalink
fix: end of training metrics computation
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Oct 7, 2023
1 parent a0209a4 commit e1f213c
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,18 +358,27 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal):

###Let's compute final FID
if rank_0 and opt.train_compute_metrics_test:
cur_fid = model.compute_metrics_test()
with torch.no_grad():
if use_temporal:
dataloaders_test = zip(dataloader_test, dataloader_test_temporal)
else:
dataloaders_test = zip(dataloader_test)
model.compute_metrics_test(
dataloaders_test, opt.train_epoch_count - 1, total_iters
)
cur_metrics = model.get_current_metrics()
path_json = os.path.join(opt.checkpoints_dir, opt.name, "eval_results.json")

if os.path.exists(path_json):
with open(path_json, "r") as loadfile:
data = json.load(loadfile)

with open(path_json, "w+") as outfile:
data = {}
data["fid_%s_img_%s_epochs" % (opt.data_max_dataset_size, epoch)] = float(
cur_fid.item()
)
for key, value in cur_metrics.items():
data[
"%s_%s_img_%s"
% (key, opt.data_max_dataset_size, opt.train_epoch_count)
] = float(value)
json.dump(data, outfile)

if rank_0:
Expand Down

0 comments on commit e1f213c

Please sign in to comment.