bump torch; optimize dataloading
misko committed Jul 28, 2024
1 parent e6e3a3a commit fba0992
Showing 4 changed files with 44 additions and 34 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -20,4 +20,6 @@ test_data.txt
 **/testdata*
 **/temp
 **/sessions*
-wandb
+**/*.chkpnt
+**/*.mdb
+wandb
4 changes: 2 additions & 2 deletions requirements.txt
@@ -74,8 +74,8 @@ six==1.16.0
 sympy==1.12
 tensordict==0.1.2
 tomli==2.0.1
-torch==2.3.1
-torchvision==0.18.1
+torch==2.4.0
+torchvision==0.19.0
 tqdm==4.66.1
 typing_extensions==4.9.0
 urllib3==2.1.0
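
The dependency bump moves torch 2.3.1 → 2.4.0 and torchvision 0.18.1 → 0.19.0; the two pins track each other, since torchvision 0.19 is the release built against torch 2.4. A minimal sanity check after reinstalling, as a sketch (not part of the repo):

import torch
import torchvision

# Fail fast if the environment still carries the old pins.
assert torch.__version__.startswith("2.4"), torch.__version__
assert torchvision.__version__.startswith("0.19"), torchvision.__version__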
16 changes: 13 additions & 3 deletions spf/dataset/spf_dataset.py
@@ -232,9 +232,19 @@ def v5_thetas_to_targets(
 def v5_collate_keys_fast(keys: List[str], batch: Dict[str, torch.Tensor]):
     d = {}
     for key in keys:
-        d[key] = torch.vstack(
-            [x[key] for paired_sample in batch for x in paired_sample]
-        )
+        if key == "windowed_beamformer" or key == "all_windows_stats":
+            d[key] = torch.vstack(
+                [
+                    x[key].to(torch.float32)
+                    for paired_sample in batch
+                    for x in paired_sample
+                ]
+            )
+        else:
+            d[key] = torch.vstack(
+                [x[key] for paired_sample in batch for x in paired_sample]
+            )
 
     return TensorDict(d, batch_size=d["y_rad"].shape)
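
The collate change casts only "windowed_beamformer" and "all_windows_stats" to float32 when a batch is assembled, which implies the dataset now keeps these bulky per-window arrays in a compact dtype and promotes them only at batch time; for uint8 storage that is a 4x saving in memory and copy bandwidth over float32. A self-contained sketch of the pattern (toy shapes and the uint8 storage dtype are assumptions, not taken from the repo):

import torch

def collate_cast(keys, batch):
    # batch is a list of paired samples, one dict per radio in each pair
    d = {}
    for key in keys:
        if key in ("windowed_beamformer", "all_windows_stats"):
            # stored compactly; promote to float32 only at batch time
            d[key] = torch.vstack(
                [x[key].to(torch.float32) for pair in batch for x in pair]
            )
        else:
            d[key] = torch.vstack([x[key] for pair in batch for x in pair])
    return d

sample = {
    "windowed_beamformer": torch.zeros(1, 8, dtype=torch.uint8),
    "y_rad": torch.zeros(1, 1),
}
batch = [(sample, sample), (sample, sample)]
out = collate_cast(["windowed_beamformer", "y_rad"], batch)
assert out["windowed_beamformer"].dtype == torch.float32
assert out["windowed_beamformer"].shape == (4, 8)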
54 changes: 26 additions & 28 deletions spf/notebooks/simple_train_filter.py
@@ -390,6 +390,30 @@ def loss(self, output, y_rad, craft_y_rad, seg_mask):
         }
 
 
+def batch_data_to_x_y_seg(batch_data, torch_device, dtype):
+    # x ~ # trimmed_cm, trimmed_stddev, abs_signal_median
+    batch_data = batch_data.to(torch_device)
+    x = batch_data["all_windows_stats"].to(dtype=dtype)
+    rx_pos = batch_data["rx_pos_xy"].to(dtype=dtype)
+    seg_mask = batch_data["downsampled_segmentation_mask"].to(dtype=dtype)
+    rx_spacing = batch_data["rx_spacing"].to(dtype=dtype)
+    windowed_beamformer = batch_data["windowed_beamformer"].to(dtype=dtype)
+    y_rad = batch_data["y_rad"].to(dtype=dtype)
+    craft_y_rad = batch_data["craft_y_rad"].to(dtype=dtype)
+    y_phi = batch_data["y_phi"].to(dtype=dtype)
+    # assert seg_mask.ndim == 4 and seg_mask.shape[2] == 1
+    return (
+        x,
+        y_rad,
+        craft_y_rad,
+        y_phi,
+        seg_mask,
+        rx_spacing,
+        windowed_beamformer,
+        rx_pos,
+    )
+
+
 def simple_train_filter(args):
     # torch.autograd.detect_anomaly()
     assert args.n_radios == 2
@@ -457,32 +481,6 @@ def simple_train_filter(args):
         m.parameters(), lr=args.lr, weight_decay=args.weight_decay
     )
 
-    def batch_data_to_x_y_seg(batch_data):
-        # x ~ # trimmed_cm, trimmed_stddev, abs_signal_median
-        x = batch_data["all_windows_stats"].to(torch_device, dtype=dtype)
-        rx_pos = batch_data["rx_pos_xy"].to(torch_device, dtype=dtype)
-        seg_mask = batch_data["downsampled_segmentation_mask"].to(
-            torch_device, dtype=dtype
-        )
-        rx_spacing = batch_data["rx_spacing"].to(torch_device, dtype=dtype)
-        windowed_beamformer = batch_data["windowed_beamformer"].to(
-            torch_device, dtype=dtype
-        )
-        y_rad = batch_data["y_rad"].to(torch_device, dtype=dtype)
-        craft_y_rad = batch_data["craft_y_rad"].to(torch_device, dtype=dtype)
-        y_phi = batch_data["y_phi"].to(torch_device, dtype=dtype)
-        assert seg_mask.ndim == 4 and seg_mask.shape[2] == 1
-        return (
-            x,
-            y_rad,
-            craft_y_rad,
-            y_phi,
-            seg_mask,
-            rx_spacing,
-            windowed_beamformer,
-            rx_pos,
-        )
-
     step = 0
     losses = []
 
@@ -618,7 +616,7 @@ def params_for_ds(ds):
                     rx_spacing,
                     windowed_beamformer,
                     rx_pos,
-                ) = batch_data_to_x_y_seg(val_batch_data)
+                ) = batch_data_to_x_y_seg(val_batch_data, torch_device, dtype)
 
                 # run beamformer and segmentation
                 output = m(
@@ -666,7 +664,7 @@ def params_for_ds(ds):
             rx_spacing,
             windowed_beamformer,
             rx_pos,
-        ) = batch_data_to_x_y_seg(batch_data)
+        ) = batch_data_to_x_y_seg(batch_data, torch_device, dtype)
 
         with torch.autocast(
             device_type=args.device, dtype=torch.float16, enabled=args.amp
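
Two things happen in this file: batch_data_to_x_y_seg is hoisted from a closure inside simple_train_filter to module level (taking torch_device and dtype as explicit arguments, so the train and validation loops can share it), and the per-key .to(torch_device, dtype=dtype) calls are split so that a single batch_data.to(torch_device) moves the whole TensorDict at once, leaving only the dtype casts per key. The seg_mask shape assert is carried along commented out. A minimal sketch of the device-transfer pattern (toy keys, not the repo's data):

import torch
from tensordict import TensorDict

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

batch_data = TensorDict(
    {"y_rad": torch.zeros(4, 1), "y_phi": torch.zeros(4, 1)}, batch_size=[4]
)
batch_data = batch_data.to(torch_device)      # one bulk transfer for every key
y_rad = batch_data["y_rad"].to(dtype=dtype)   # casts then run on-device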
