Commit

update val subsetting; change mse to wrap for pi/2
misko committed Jul 3, 2024
1 parent a4f68bb commit ad2319d
Showing 2 changed files with 23 additions and 8 deletions.
8 changes: 3 additions & 5 deletions spf/model_training_and_inference/models/beamsegnet.py
@@ -619,11 +619,9 @@ def likelihood(self, x, y, sigma_eps=0.01, smoothing_prob=0.0001):
     def mse(self, x, y):
         # not sure why we cant wrap around for torch.pi/2....
         # assert np.isclose(self.max_angle, torch.pi, atol=0.05)
-        if self.max_angle == torch.pi:
-            return (
-                torch_pi_norm(x[:, 0] - y[:, 0], max_angle=self.max_angle) ** 2
-            ).mean()
-        return ((x[:, 0] - y[:, 0]) ** 2).mean()
+        # if self.max_angle == torch.pi:
+        return (torch_pi_norm(x[:, 0] - y[:, 0], max_angle=self.max_angle) ** 2).mean()
+        # return ((x[:, 0] - y[:, 0]) ** 2).mean()

     def loglikelihood(self, x, y, log_eps=0.000000001):
         return torch.log(self.likelihood(x, y) + log_eps)
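For context, torch_pi_norm wraps an angular difference back into the [-max_angle, max_angle) range before squaring, so two angles on opposite sides of the wrap-around point are scored as close rather than far apart; after this change the wrap is applied for every max_angle (e.g. pi/2), not just pi. A minimal sketch of the idea with a stand-in helper (not the repository's torch_pi_norm):

```python
import torch

def wrap_angle(diff: torch.Tensor, max_angle: float = torch.pi) -> torch.Tensor:
    # Map an angular difference into [-max_angle, max_angle).
    return ((diff + max_angle) % (2 * max_angle)) - max_angle

# Predictions near +pi and targets near -pi describe nearly the same angle.
pred = torch.tensor([[3.10]])
target = torch.tensor([[-3.10]])

naive_mse = ((pred[:, 0] - target[:, 0]) ** 2).mean()              # ~38.4, overstates the error
wrapped_mse = (wrap_angle(pred[:, 0] - target[:, 0]) ** 2).mean()  # ~0.007, the true angular gap
```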
23 changes: 20 additions & 3 deletions spf/notebooks/simple_train_filter.py
@@ -4,7 +4,7 @@
 import torch
 from matplotlib import pyplot as plt
 from tqdm import tqdm
-
+from random import shuffle
 import wandb
 from spf.dataset.spf_dataset import (
     v5_collate_beamsegnet,
@@ -172,6 +172,7 @@ def loss(self, output, y_rad, craft_y_rad, seg_mask):
         beamnet_mse = self.beam_m.mse(
             output["pred_theta"], y_rad_reduced.reshape(-1, 1)
         )
+        breakpoint()
         loss = transformer_loss + beamnet_loss
         return {
             "loss": loss,
@@ -216,7 +217,10 @@ def simple_train(args):
     else:
         n = len(complete_ds)
         train_idxs = range(int((1.0 - args.val_holdout_fraction) * n))
-        val_idxs = range(train_idxs[-1] + 1, n)
+        val_idxs = list(range(train_idxs[-1] + 1, n))
+
+        shuffle(val_idxs)
+        val_idxs = val_idxs[: max(1, int(len(val_idxs) * args.val_subsample_fraction))]

     train_ds = torch.utils.data.Subset(complete_ds, train_idxs)
     val_ds = torch.utils.data.Subset(complete_ds, val_idxs)
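The net effect of the change above: instead of validating on the whole tail of the dataset, the tail indices are shuffled and truncated to a fraction of their original count (never fewer than one), so each validation pass is much cheaper. A self-contained sketch of the same logic, assuming a toy dataset and the parser defaults from this commit (0.2 holdout, 0.05 subsample):

```python
from random import shuffle

import torch

complete_ds = torch.utils.data.TensorDataset(torch.arange(100))  # toy stand-in dataset
val_holdout_fraction, val_subsample_fraction = 0.2, 0.05          # parser defaults

n = len(complete_ds)
train_idxs = range(int((1.0 - val_holdout_fraction) * n))    # indices 0..79
val_idxs = list(range(train_idxs[-1] + 1, n))                # tail indices 80..99

shuffle(val_idxs)                                            # randomize the tail
val_idxs = val_idxs[: max(1, int(len(val_idxs) * val_subsample_fraction))]  # keep 1 of 20

train_ds = torch.utils.data.Subset(complete_ds, train_idxs)
val_ds = torch.utils.data.Subset(complete_ds, val_idxs)
```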
@@ -304,6 +308,9 @@ def new_log():
     for step, batch_data in tqdm(
         enumerate(train_dataloader), total=len(train_dataloader)
     ):
+        # if step > 0:
+        #     return
+        #     continue
         if step % args.val_every == 0:
             m.eval()
             save_everything(
@@ -499,6 +506,12 @@ def get_parser():
         required=False,
         default=0.2,
     )
+    parser.add_argument(
+        "--val-subsample-fraction",
+        type=float,
+        required=False,
+        default=0.05,
+    )
     parser.add_argument(
         "--hidden",
         type=int,
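The new flag wires the subsample fraction into the CLI. As a quick illustration of how argparse exposes it (a standalone sketch, not an import of the repository's get_parser; only this one argument is reproduced):

```python
import argparse

# Standalone sketch mirroring the add_argument call from the diff above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--val-subsample-fraction",
    type=float,
    required=False,
    default=0.05,
)

args = parser.parse_args(["--val-subsample-fraction", "0.1"])
print(args.val_subsample_fraction)  # 0.1; omitting the flag falls back to 0.05
```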
@@ -598,9 +611,13 @@ def get_parser():
     return parser


+from pyinstrument import Profiler
+
 if __name__ == "__main__":
     parser = get_parser()
     args = parser.parse_args()
-    # with Profile() as profile:
+    # with Profiler(interval=0.1) as profiler:
     simple_train(args)
-    # (Stats(profile).strip_dirs().sort_stats(SortKey.TIME).print_stats(200))
+    # # (Stats(profile).strip_dirs().sort_stats(SortKey.TIME).print_stats(200))
+    # profiler.print()
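The commented-out lines above sketch how the run would be profiled with pyinstrument; an explicit start/stop version, which should behave the same as the commented context-manager form, might look like this (it assumes args and simple_train are in scope as in the script):

```python
from pyinstrument import Profiler

profiler = Profiler(interval=0.1)  # sample the call stack every 0.1 s
profiler.start()
simple_train(args)                 # the training entry point from this script
profiler.stop()
profiler.print()                   # print the sampled call tree to stdout
```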
