Skip to content

Commit

Permalink
fix test and dtype for cpu tests
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Jul 17, 2024
1 parent bcf96cf commit 6ac36b5
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
35 changes: 24 additions & 11 deletions spf/notebooks/simple_train_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(
token_dropout=0.0,
):
super(DebugFunkyNet, self).__init__()
self.l = torch.nn.Linear(3, 1).to(torch.float16)
self.l = torch.nn.Linear(3, 1).to(torch.float32)

def forward(self, x, seg_mask, rx_spacing, y_rad, windowed_beam_former, rx_pos):
return {
Expand Down Expand Up @@ -358,6 +358,10 @@ def simple_train_filter(args):

scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

dtype = torch.float16
if args.dtype == "float32":
dtype = torch.float32

if args.act == "relu":
act = torch.nn.ReLU
elif args.act == "selu":
Expand All @@ -370,7 +374,7 @@ def simple_train_filter(args):
# init model here
#######
if args.debug_model:
m = DebugFunkyNet().to(torch_device)
m = DebugFunkyNet().to(torch_device, dtype=dtype)
else:
m = FunkyNet(
d_hid=args.tformer_dhid,
Expand All @@ -379,7 +383,7 @@ def simple_train_filter(args):
token_dropout=args.tformer_snapshot_dropout,
n_layers=args.tformer_layers,
latent=args.beamnet_latent,
).to(torch_device)
).to(torch_device, dtype=dtype)
########

if args.wandb_project:
Expand Down Expand Up @@ -407,14 +411,18 @@ def simple_train_filter(args):

def batch_data_to_x_y_seg(batch_data):
# x ~ # trimmed_cm, trimmed_stddev, abs_signal_median
x = batch_data["all_windows_stats"].to(torch_device)
rx_pos = batch_data["rx_pos_xy"].to(torch_device)
seg_mask = batch_data["downsampled_segmentation_mask"].to(torch_device)
rx_spacing = batch_data["rx_spacing"].to(torch_device)
windowed_beamformer = batch_data["windowed_beamformer"].to(torch_device)
y_rad = batch_data["y_rad"].to(torch_device)
craft_y_rad = batch_data["craft_y_rad"].to(torch_device)
y_phi = batch_data["y_phi"].to(torch_device)
x = batch_data["all_windows_stats"].to(torch_device, dtype=dtype)
rx_pos = batch_data["rx_pos_xy"].to(torch_device, dtype=dtype)
seg_mask = batch_data["downsampled_segmentation_mask"].to(
torch_device, dtype=dtype
)
rx_spacing = batch_data["rx_spacing"].to(torch_device, dtype=dtype)
windowed_beamformer = batch_data["windowed_beamformer"].to(
torch_device, dtype=dtype
)
y_rad = batch_data["y_rad"].to(torch_device, dtype=dtype)
craft_y_rad = batch_data["craft_y_rad"].to(torch_device, dtype=dtype)
y_phi = batch_data["y_phi"].to(torch_device, dtype=dtype)
assert seg_mask.ndim == 4 and seg_mask.shape[2] == 1
return (
x,
Expand Down Expand Up @@ -873,6 +881,11 @@ def get_parser_filter():
type=int,
default=512,
)
parser.add_argument(
"--dtype",
type=str,
default="float16",
)
parser.add_argument(
"--tformer-dropout",
type=int,
Expand Down
5 changes: 5 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def test_simple_filter_save_load(perfect_circle_dataset_n33):
"1",
"--save-prefix",
save_prefix,
"--device",
"cpu",
"--no-amp",
"--dtype",
"float32",
]
chkpnt_fn = save_prefix + "_step0.chkpnt"
save_args = base_args + [
Expand Down

0 comments on commit 6ac36b5

Please sign in to comment.