Skip to content

Commit

Permalink
fix segmask bool to float
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Jul 2, 2024
1 parent 615c695 commit a4f68bb
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions spf/model_training_and_inference/models/beamsegnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,11 @@ def likelihood(self, x, y, sigma_eps=0.01, smoothing_prob=0.0001):
def mse(self, x, y):
# not sure why we cant wrap around for torch.pi/2....
# assert np.isclose(self.max_angle, torch.pi, atol=0.05)
return (torch_pi_norm(x[:, 0] - y[:, 0], max_angle=self.max_angle) ** 2).mean()
if self.max_angle == torch.pi:
return (
torch_pi_norm(x[:, 0] - y[:, 0], max_angle=self.max_angle) ** 2
).mean()
return ((x[:, 0] - y[:, 0]) ** 2).mean()

def loglikelihood(self, x, y, log_eps=0.000000001):
return torch.log(self.likelihood(x, y) + log_eps)
Expand Down Expand Up @@ -809,7 +813,9 @@ def loss(self, output, y_rad, craft_y_rad, seg_mask):
# segmentation loss
segmentation_loss = 0
if not self.skip_segmentation:
segmentation_loss = ((output["segmentation"] - seg_mask) ** 2).mean()
segmentation_loss = (
(output["segmentation"] - seg_mask.to(float)) ** 2
).mean()

mse_loss = self.beamnet.mse(output["pred_theta"], y_rad_reduced)

Expand Down

0 comments on commit a4f68bb

Please sign in to comment.