From ad2319ddaf007f52e72201d102388c7f76291db7 Mon Sep 17 00:00:00 2001 From: misko Date: Wed, 3 Jul 2024 15:59:47 +0000 Subject: [PATCH] update val subsetting; change mse to wrap for pi/2 --- .../models/beamsegnet.py | 8 +++---- spf/notebooks/simple_train_filter.py | 23 ++++++++++++++++--- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/spf/model_training_and_inference/models/beamsegnet.py b/spf/model_training_and_inference/models/beamsegnet.py index 83ecd7ed..eeb2524c 100644 --- a/spf/model_training_and_inference/models/beamsegnet.py +++ b/spf/model_training_and_inference/models/beamsegnet.py @@ -619,11 +619,9 @@ def likelihood(self, x, y, sigma_eps=0.01, smoothing_prob=0.0001): def mse(self, x, y): # not sure why we cant wrap around for torch.pi/2.... # assert np.isclose(self.max_angle, torch.pi, atol=0.05) - if self.max_angle == torch.pi: - return ( - torch_pi_norm(x[:, 0] - y[:, 0], max_angle=self.max_angle) ** 2 - ).mean() - return ((x[:, 0] - y[:, 0]) ** 2).mean() + # if self.max_angle == torch.pi: + return (torch_pi_norm(x[:, 0] - y[:, 0], max_angle=self.max_angle) ** 2).mean() + # return ((x[:, 0] - y[:, 0]) ** 2).mean() def loglikelihood(self, x, y, log_eps=0.000000001): return torch.log(self.likelihood(x, y) + log_eps) diff --git a/spf/notebooks/simple_train_filter.py b/spf/notebooks/simple_train_filter.py index ad2b51c0..6dd01320 100644 --- a/spf/notebooks/simple_train_filter.py +++ b/spf/notebooks/simple_train_filter.py @@ -4,7 +4,7 @@ import torch from matplotlib import pyplot as plt from tqdm import tqdm - +from random import shuffle import wandb from spf.dataset.spf_dataset import ( v5_collate_beamsegnet, @@ -172,6 +172,7 @@ def loss(self, output, y_rad, craft_y_rad, seg_mask): beamnet_mse = self.beam_m.mse( output["pred_theta"], y_rad_reduced.reshape(-1, 1) ) + breakpoint() loss = transformer_loss + beamnet_loss return { "loss": loss, @@ -216,7 +217,10 @@ def simple_train(args): else: n = len(complete_ds) train_idxs = range(int((1.0 - args.val_holdout_fraction) * n)) - val_idxs = range(train_idxs[-1] + 1, n) + val_idxs = list(range(train_idxs[-1] + 1, n)) + + shuffle(val_idxs) + val_idxs = val_idxs[: max(1, int(len(val_idxs) * args.val_subsample_fraction))] train_ds = torch.utils.data.Subset(complete_ds, train_idxs) val_ds = torch.utils.data.Subset(complete_ds, val_idxs) @@ -304,6 +308,9 @@ def new_log(): for step, batch_data in tqdm( enumerate(train_dataloader), total=len(train_dataloader) ): + # if step > 0: + # return + # continue if step % args.val_every == 0: m.eval() save_everything( @@ -499,6 +506,12 @@ def get_parser(): required=False, default=0.2, ) + parser.add_argument( + "--val-subsample-fraction", + type=float, + required=False, + default=0.05, + ) parser.add_argument( "--hidden", type=int, @@ -598,9 +611,13 @@ def get_parser(): return parser +from pyinstrument import Profiler + if __name__ == "__main__": parser = get_parser() args = parser.parse_args() # with Profile() as profile: + # with Profiler(interval=0.1) as profiler: simple_train(args) - # (Stats(profile).strip_dirs().sort_stats(SortKey.TIME).print_stats(200)) + # # (Stats(profile).strip_dirs().sort_stats(SortKey.TIME).print_stats(200)) + # profiler.print()