Skip to content

Commit

Permalink
fix tests by moving to float32
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Jun 30, 2024
1 parent 07c2857 commit c15745f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions spf/dataset/spf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,9 @@ def render_session(self, receiver_idx, session_idx):
pd = torch_get_phase_diff(data["signal_matrix"]).to(torch.float32)
data["x"] = torch.vstack([abs_signal[0], abs_signal[1], pd])[None]

data["y_rad"] = data["ground_truth_theta"][None]
data["y_phi"] = data["ground_truth_phi"][None]
data["craft_y_rad"] = data["craft_ground_truth_theta"][None]
data["y_rad"] = data["ground_truth_theta"][None].to(torch.float32)
data["y_phi"] = data["ground_truth_phi"][None].to(torch.float32)
data["craft_y_rad"] = data["craft_ground_truth_theta"][None].to(torch.float32)

# data["y_discrete"] = v5_thetas_to_targets(data["y_rad"], self.nthetas)

Expand Down

0 comments on commit c15745f

Please sign in to comment.