From a4f68bbec90ac16b838102b43ca9f1a6c4c2261b Mon Sep 17 00:00:00 2001 From: misko Date: Tue, 2 Jul 2024 14:54:03 +0000 Subject: [PATCH] fix segmask bool to float --- spf/model_training_and_inference/models/beamsegnet.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/spf/model_training_and_inference/models/beamsegnet.py b/spf/model_training_and_inference/models/beamsegnet.py index 23eed0c1..83ecd7ed 100644 --- a/spf/model_training_and_inference/models/beamsegnet.py +++ b/spf/model_training_and_inference/models/beamsegnet.py @@ -619,7 +619,11 @@ def likelihood(self, x, y, sigma_eps=0.01, smoothing_prob=0.0001): def mse(self, x, y): # not sure why we cant wrap around for torch.pi/2.... # assert np.isclose(self.max_angle, torch.pi, atol=0.05) - return (torch_pi_norm(x[:, 0] - y[:, 0], max_angle=self.max_angle) ** 2).mean() + if self.max_angle == torch.pi: + return ( + torch_pi_norm(x[:, 0] - y[:, 0], max_angle=self.max_angle) ** 2 + ).mean() + return ((x[:, 0] - y[:, 0]) ** 2).mean() def loglikelihood(self, x, y, log_eps=0.000000001): return torch.log(self.likelihood(x, y) + log_eps) @@ -809,7 +813,9 @@ def loss(self, output, y_rad, craft_y_rad, seg_mask): # segmentation loss segmentation_loss = 0 if not self.skip_segmentation: - segmentation_loss = ((output["segmentation"] - seg_mask) ** 2).mean() + segmentation_loss = ( + (output["segmentation"] - seg_mask.to(float)) ** 2 + ).mean() mse_loss = self.beamnet.mse(output["pred_theta"], y_rad_reduced)