Commit

optimize dataloading

misko committed Jul 28, 2024
1 parent e9424ea commit 469cc87
Showing 3 changed files with 60 additions and 24 deletions.
30 changes: 16 additions & 14 deletions spf/model_training_and_inference/models/beamsegnet.py
@@ -11,7 +11,9 @@
     torch_circular_mean,
     torch_pi_norm,
 )
+from torch.nn.functional import sigmoid
 
+from math import sqrt
 
 import torch
 import math
@@ -340,7 +342,7 @@ def __init__(
         other=False,
         no_sigmoid=False,
         positional_encoding=False,
-        max_angle=np.pi / 2,
+        max_angle=torch.pi / 2,
     ):
         super(BeamNetDiscrete, self).__init__()
         self.nthetas = nthetas
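The np.pi → torch.pi swaps line up with the TorchScript changes below: scripted functions cannot reference numpy, while torch.pi is a plain Python float with the same value that the commit's own @torch.jit.script functions use freely. A minimal standalone check (hypothetical, not from the commit; the function name is illustrative):

import math

import numpy as np
import torch

# torch.pi is a plain Python float equal to math.pi (and np.pi),
# and unlike np.pi it can be referenced from TorchScript code.
assert torch.pi == math.pi == np.pi

@torch.jit.script
def half_max_angle() -> float:
    return torch.pi / 2  # np.pi here would not resolve under torch.jit.script

print(half_max_angle())  # 1.5707963267948966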
@@ -452,32 +454,32 @@ def render_discrete_x(self, x):
 
     # this is discrete its already rendered
     def render_discrete_y(self, y):
-        assert y.abs().max() <= np.pi / 2
+        assert y.abs().max() <= torch.pi / 2
         return v5_thetas_to_targets(y, self.nthetas, range_in_rad=1, sigma=0.1)
 
 
-def cdf(mean, sigma, value):
-    return 0.5 * (1 + torch.erf((value - mean) * sigma.reciprocal() / math.sqrt(2)))
+@torch.jit.script
+def cdf(mean: torch.Tensor, sigma: torch.Tensor, value: float):
+    return 0.5 * (1 + torch.erf((value - mean) * sigma.reciprocal() / sqrt(2)))
 
 
-def normal_correction_for_bounded_range(mean, sigma, max_y):
+@torch.jit.script
+def normal_correction_for_bounded_range(
+    mean: torch.Tensor, sigma: torch.Tensor, max_y: float
+):
     assert max_y > 0
     left_p = cdf(mean, sigma, -max_y)
     right_p = cdf(mean, sigma, max_y)
     return (right_p - left_p).reciprocal()
 
 
-def normal_dist_d(sigma, d):
-    assert sigma.ndim == 1
-    d = d / sigma
-    return (1 / (sigma * np.sqrt(2 * np.pi))) * torch.exp(-0.5 * d**2)
-
-
-def normal_dist(x, y, sigma, d=None):
+@torch.jit.script
+def normal_dist(x, y, sigma):
     assert x.ndim == 1
     assert y.ndim == 1
     assert sigma.ndim == 1
-    return normal_dist_d(sigma, (x - y))
+    d = (x - y) / sigma
+    return (1 / (sigma * sqrt(2 * torch.pi))) * torch.exp(-0.5 * d**2)
 
 
 def FFN_to_Normal(
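Taken together, these helpers evaluate a normal density truncated to a bounded angle range: cdf is the Gaussian CDF, normal_correction_for_bounded_range rescales by 1/(Φ(max_y) − Φ(−max_y)) so the density integrates to 1 over [-max_y, max_y], and normal_dist is the standard Gaussian pdf. The new @torch.jit.script decorators (with the type annotations TorchScript requires for non-tensor arguments) compile these hot per-batch functions, in line with the commit's throughput focus. A standalone sketch (hypothetical values, plain-Python clones of the committed helpers) checking the normalization numerically:

import torch
from math import sqrt

# plain clones of the committed helpers, without the jit decorators
def cdf(mean, sigma, value):
    return 0.5 * (1 + torch.erf((value - mean) * sigma.reciprocal() / sqrt(2)))

def normal_dist(x, y, sigma):
    d = (x - y) / sigma
    return (1 / (sigma * sqrt(2 * torch.pi))) * torch.exp(-0.5 * d**2)

mean = torch.tensor([0.3])
sigma = torch.tensor([0.5])
max_y = torch.pi / 2  # the default max_angle above

# correction = 1 / (CDF(max_y) - CDF(-max_y)),
# as computed by normal_correction_for_bounded_range
correction = (cdf(mean, sigma, max_y) - cdf(mean, sigma, -max_y)).reciprocal()

# the corrected density integrates to ~1 over [-max_y, max_y]
thetas = torch.linspace(-max_y, max_y, 10001)
density = correction * normal_dist(thetas, mean.expand_as(thetas), sigma.expand_as(thetas))
print(torch.trapz(density, thetas))  # ~= 1.0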
@@ -667,7 +669,7 @@ def __init__(
         norm="batch",
         # angle specific
         nthetas=65,
-        max_angle=np.pi / 2,
+        max_angle=torch.pi / 2,
         # normal net params
         other=True,
         no_sigmoid=False,
28 changes: 18 additions & 10 deletions spf/notebooks/simple_train_filter.py
@@ -19,7 +19,11 @@
 from cProfile import Profile
 from pstats import SortKey, Stats
 
-from spf.rf import reduce_theta_to_positive_y, torch_pi_norm
+from spf.rf import (
+    reduce_theta_to_positive_y,
+    torch_pi_norm,
+    torch_reduce_theta_to_positive_y,
+)
 
 import random
@@ -301,7 +305,7 @@ def forward(self, x, seg_mask, rx_spacing, y_rad, windowed_beam_former, rx_pos):
         # if self.train() randomly inject
         # if self.eval() never inject!
         if y_rad is not None and self.training and self.paired_drop_in_gt > 0.0:
-            y_rad_reduced = reduce_theta_to_positive_y(y_rad).reshape(-1, 1)
+            y_rad_reduced = torch_reduce_theta_to_positive_y(y_rad).reshape(-1, 1)
             mask = torch.rand(detached_pred_theta.shape[0]) < self.paired_drop_in_gt
             detached_pred_theta[mask, 0] = y_rad_reduced[mask, 0]
             detached_pred_theta[mask, 1:3] = 0
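The only change here swaps in the scripted theta reducer, but the surrounding block is worth a gloss: during training, a random fraction (paired_drop_in_gt) of rows has the detached predicted angle overwritten with the reduced ground truth, and the remaining prediction features zeroed, a scheduled-sampling-style trick that exposes downstream layers to perfect angle inputs. A standalone sketch with hypothetical shapes and values:

import torch

paired_drop_in_gt = 0.25  # fraction of rows that see ground truth

detached_pred_theta = torch.randn(8, 3)  # [theta, extra prediction features]
y_rad_reduced = torch.rand(8, 1) * torch.pi - torch.pi / 2

# pick ~25% of the batch, splice in the ground-truth angle,
# and zero the other prediction features for those rows
mask = torch.rand(detached_pred_theta.shape[0]) < paired_drop_in_gt
detached_pred_theta[mask, 0] = y_rad_reduced[mask, 0]
detached_pred_theta[mask, 1:3] = 0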
@@ -361,7 +365,7 @@ def loss(self, output, y_rad, craft_y_rad, seg_mask):
         )
         transformer_random_loss = (torch_pi_norm(target - random_target) ** 2).mean()
 
-        y_rad_reduced = reduce_theta_to_positive_y(y_rad).reshape(-1, 1)
+        y_rad_reduced = torch_reduce_theta_to_positive_y(y_rad).reshape(-1, 1)
         # x to beamformer loss (indirectly including segmentation)
         beamnet_loss = -self.beam_m.loglikelihood(
             output["pred_theta"], y_rad_reduced
@@ -541,9 +545,9 @@ def new_log():
    val_ds = torch.utils.data.Subset(complete_ds, val_idxs)
    print(f"Train-dataset size {len(train_ds)}, Val dataset size {len(val_ds)}")
 
-    def params_for_ds(ds):
+    def params_for_ds(ds, batch_size):
         sampler = StatefulBatchsampler(
-            ds, shuffle=args.shuffle, seed=args.seed, batch_size=args.batch
+            ds, shuffle=args.shuffle, seed=args.seed, batch_size=batch_size
         )
         sampler.set_epoch_and_start_iteration(epoch=epoch, start_iteration=step)
         return {
@@ -569,8 +573,12 @@ def params_for_ds(ds):
             "batch_sampler": sampler,
         }
 
-    train_dataloader = torch.utils.data.DataLoader(train_ds, **params_for_ds(train_ds))
-    val_dataloader = torch.utils.data.DataLoader(val_ds, **params_for_ds(val_ds))
+    train_dataloader = torch.utils.data.DataLoader(
+        train_ds, **params_for_ds(train_ds, batch_size=args.batch)
+    )
+    val_dataloader = torch.utils.data.DataLoader(
+        val_ds, **params_for_ds(val_ds, batch_size=args.batch)
+    )
 
     for epoch in range(args.epochs):
         # breakpoint()
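With batch_size now a parameter of params_for_ds rather than read from args inside it, the two loaders no longer have to share a batch size. The commit passes args.batch for both; a hypothetical variant (not in the commit) could use a larger validation batch, since validation runs no backward pass:

# hypothetical: validation often fits a larger batch than training
val_dataloader = torch.utils.data.DataLoader(
    val_ds, **params_for_ds(val_ds, batch_size=2 * args.batch)
)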
@@ -584,8 +592,8 @@ def params_for_ds(ds):
             # return
             if args.steps >= 0 and step >= args.steps:
                 break
-            if torch.rand(1).item() < 0.02:
-                gc.collect()
+            # if torch.rand(1).item() < 0.002:
+            #     gc.collect()
             if step % args.save_every == 0:
                 m.eval()
                 save_everything(
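The old loop forced a full garbage collection on roughly 2% of steps; the committed version disables it entirely (the commented-out replacement would have fired on only 0.2%). A full gc.collect() walks the whole Python heap while holding the GIL, stalling the loop that consumes batches, so dropping it is a direct dataloading win. A quick way to gauge the cost (hypothetical measurement, not from the repo):

import gc
import time

# a full collection walks the entire Python heap; in a large training
# process a single call can take tens of milliseconds
t0 = time.perf_counter()
gc.collect()
print(f"gc.collect() took {time.perf_counter() - t0:.4f}s")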
@@ -834,7 +842,7 @@ def get_parser_filter():
    parser.add_argument(
        "--val-every",
        type=int,
-        default=1000,
+        default=2500,
    )
    parser.add_argument(
        "--save-every",
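Raising the --val-every default from 1000 to 2500 runs the validation pass 2.5x less often, cutting time spent outside the training dataloader. The old cadence is still one flag away; a sketch assuming the module is importable as shown and no other flags are required:

# hypothetical: restore the old validation cadence
from spf.notebooks.simple_train_filter import get_parser_filter

args = get_parser_filter().parse_args(["--val-every", "1000"])
assert args.val_every == 1000  # overrides the new default of 2500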
26 changes: 26 additions & 0 deletions spf/rf.py
@@ -85,6 +85,32 @@ def circular_stddev(v, u, trim=50.0):
     return stddev, trimmed_stddev
 
 
+@torch.jit.script
+def torch_reduce_theta_to_positive_y(ground_truth_thetas):
+    reduced_thetas = ground_truth_thetas.clone()
+
+    # |theta|>np.pi/2 means its on the y<0
+    reduced_ground_truth_thetas_mask = abs(reduced_thetas) > np.pi / 2
+    reduced_ground_truth_thetas_at_mask = reduced_thetas[
+        reduced_ground_truth_thetas_mask
+    ]
+    # reduced_thetas[reduced_ground_truth_thetas_mask] = (
+    #     np.sign(reduced_ground_truth_thetas_at_mask) * np.pi
+    #     - reduced_ground_truth_thetas_at_mask
+    # )
+    if isinstance(ground_truth_thetas, torch.Tensor):
+        reduced_thetas[reduced_ground_truth_thetas_mask] = (
+            torch.sign(reduced_ground_truth_thetas_at_mask) * torch.pi
+            - reduced_ground_truth_thetas_at_mask
+        )
+    else:
+        reduced_thetas[reduced_ground_truth_thetas_mask] = (
+            np.sign(reduced_ground_truth_thetas_at_mask) * np.pi
+            - reduced_ground_truth_thetas_at_mask
+        )
+    return reduced_thetas
+
+
 def reduce_theta_to_positive_y(ground_truth_thetas):
     if isinstance(ground_truth_thetas, torch.Tensor):
         reduced_thetas = ground_truth_thetas.clone()
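torch_reduce_theta_to_positive_y is a TorchScript twin of the existing reduce_theta_to_positive_y called from the hot paths above: any angle pointing into the y<0 half-plane (|theta| > pi/2) is mirrored to sign(theta)*pi - theta. Note the numpy else-branch is dead once the function is scripted, since TorchScript types the input as a Tensor; it presumably survives from the generic original. A quick standalone check of the mapping (hypothetical values):

import torch

def reduce_to_positive_y(theta):
    out = theta.clone()
    mask = theta.abs() > torch.pi / 2
    out[mask] = torch.sign(theta[mask]) * torch.pi - theta[mask]
    return out

theta = torch.tensor([0.0, 1.0, 2.0, -2.0, 3.0])
reduced = reduce_to_positive_y(theta)
print(reduced)  # tensor([ 0.0000,  1.0000,  1.1416, -1.1416,  0.1416])

# the mirror keeps every angle within [-pi/2, pi/2] and preserves sin(theta)
assert reduced.abs().max() <= torch.pi / 2
assert torch.allclose(torch.sin(theta), torch.sin(reduced))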
