optimizations and partial loading support
misko committed Jul 28, 2024
1 parent 6da922b commit 650d5e6
Showing 6 changed files with 118 additions and 50 deletions.
2 changes: 1 addition & 1 deletion spf/dataset/fake_dataset.py
@@ -187,7 +187,7 @@ def phi_to_theta(phi, antenna_spacing_m, _lambda, limit=False):
"rx_pos_y_mm": 0,
"tx_pos_x_mm": np.sin(thetas[record_idx]) * radius,
"tx_pos_y_mm": np.cos(thetas[record_idx]) * radius,
"system_timestamp": record_idx * 5.0,
"system_timestamp": 1.0 + record_idx * 5.0,
"rx_theta_in_pis": yaml_config["receivers"][receiver_idx][
"theta-in-pis"
],
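The added 1.0 offset is not cosmetic: the segmentation code below counts an entry as written only when its system_timestamp is strictly positive, so a first fake record stamped exactly 0 would be miscounted as unwritten. A minimal sketch of that predicate (values illustrative):

    import numpy as np

    timestamps = 1.0 + np.arange(4) * 5.0  # [1., 6., 11., 16.]
    assert int((timestamps > 0).sum()) == 4  # all four records count as valid

    old_timestamps = np.arange(4) * 5.0  # [0., 5., 10., 15.]
    assert int((old_timestamps > 0).sum()) == 3  # record 0 would be dropped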
47 changes: 37 additions & 10 deletions spf/dataset/spf_dataset.py
@@ -124,8 +124,16 @@ def segment_single_session(


def mp_segment_zarr(zarr_fn, results_fn, steering_vectors_for_all_receivers, gpu=False):
print("Segmenting file", zarr_fn)

z = zarr_open_from_lmdb_store(zarr_fn)
+    valid_entries = min(
+        [
+            (z[f"receivers/r{ridx}/system_timestamp"][:] > 0).sum()
+            for ridx in range(len(z["receivers"]))
+        ]
+    )
+
+    print("Segmenting file", valid_entries, zarr_fn)
assert len(z["receivers"]) == 2

n_sessions, _, _ = z.receivers["r0"].signal_matrix.shape
@@ -142,7 +150,7 @@ def mp_segment_zarr(zarr_fn, results_fn, steering_vectors_for_all_receivers, gpu
"gpu": gpu,
**default_segment_args,
}
-        for idx in range(n_sessions)
+        for idx in range(valid_entries)
]

with Pool(min(cpu_count(), 20)) as pool: # cpu_count()) # cpu_count() // 4)
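This is the core of the partial-loading support: the zarr arrays are preallocated for n_sessions records, but a capture still in progress (or cut short) has only written a prefix of them, and each receiver may be a step ahead of the other. Taking the min of the per-receiver counts restricts segmentation to records that are complete on every receiver. A standalone sketch of the same logic, assuming the receivers/r{i}/system_timestamp layout used above:

    def count_valid_entries(z) -> int:
        # a record is considered written once its timestamp is > 0
        return min(
            int((z[f"receivers/r{ridx}/system_timestamp"][:] > 0).sum())
            for ridx in range(len(z["receivers"]))
        )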
@@ -170,32 +178,35 @@ def mp_segment_zarr(zarr_fn, results_fn, steering_vectors_for_all_receivers, gpu

for r_idx in [0, 1]:
# collect all windows stats
z[f"r{r_idx}/all_windows_stats"][:] = np.vstack(
z[f"r{r_idx}/all_windows_stats"][:valid_entries] = np.vstack(
[x["all_windows_stats"][None] for x in results_by_receiver[f"r{r_idx}"]]
)
# collect windowed beamformer
z[f"r{r_idx}/windowed_beamformer"][:] = np.vstack(
z[f"r{r_idx}/windowed_beamformer"][:valid_entries] = np.vstack(
[x["windowed_beamformer"][None] for x in results_by_receiver[f"r{r_idx}"]]
)
# collect downsampled segmentation mask
z[f"r{r_idx}/downsampled_segmentation_mask"][:] = np.vstack(
z[f"r{r_idx}/downsampled_segmentation_mask"][:valid_entries] = np.vstack(
[
x["downsampled_segmentation_mask"][None]
for x in results_by_receiver[f"r{r_idx}"]
]
)
-        # remove from dictionary to prevent it from going into pkl file later
-        for x in results_by_receiver[f"r{r_idx}"]:
-            x.pop("all_windows_stats")
-            x.pop("windowed_beamformer")

+    simple_segmentation = {}
+    for r_idx in [0, 1]:
+        simple_segmentation[f"r{r_idx}"] = [
+            {"simple_segmentation": x["simple_segmentation"]}
+            for x in results_by_receiver[f"r{r_idx}"]
+        ] + [{"simple_segmentation": []} for _ in range(n_sessions - valid_entries)]

z.store.close()
z = None
zarr_shrink(segmentation_zarr_fn)
pickle.dump(
{
"version": SEGMENTATION_VERSION,
"segmentation_by_receiver": results_by_receiver,
"segmentation_by_receiver": simple_segmentation,
},
open(results_fn, "wb"),
)
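Two things changed in what gets pickled. The heavy per-session arrays (window stats, windowed beamformer) no longer ride along at all, since they are written into the zarr above, sliced to :valid_entries; and the pickle keeps only the lightweight simple_segmentation lists, padded with empty placeholders so each list always has n_sessions entries and positional indexing stays valid for sessions not yet recorded. A consumer could rely on roughly this (results_fn as passed in above):

    import pickle

    with open(results_fn, "rb") as f:
        seg = pickle.load(f)
    r0 = seg["segmentation_by_receiver"]["r0"]
    # entries past valid_entries hold {"simple_segmentation": []}
    assert all(set(d) == {"simple_segmentation"} for d in r0)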
@@ -505,6 +516,11 @@ def __init__(
.to(torch.float32)
.mean()
)
+        else:
+            # if we are a temp version
+            self.mean_phase = {}
+            for ridx in range(self.n_receivers):
+                self.mean_phase[f"r{ridx}"] = torch.ones(len(self)) * torch.inf

if not ignore_qc:
assert not temp_file
@@ -523,6 +539,10 @@

# self.close()

+    def get_valid_entries(self):
+        self.refresh()
+        return self.valid_entries

def refresh(self):
# get how many entries are in the underlying storage
# recompute if we need to
@@ -547,6 +567,13 @@ def refresh(self):
return True
return False

+    def get_mean_phase(self, ridx, idx):
+        v = self.mean_phase[f"r{ridx}"][idx]
+        if torch.isfinite(v):
+            return v
+        else:
+            print("NOT VALID")

def close(self):
# let workers open their own
self.z.store.close()
4 changes: 3 additions & 1 deletion spf/dataset/zarr_rechunk.py
@@ -37,7 +37,9 @@ def compare_and_copy(prefix, src, dst):
dst[:] = src[:]
else:
for x in range(src.shape[0]):
-            dst[x] = src[x]
+            dst[x] = src[
+                x
+            ]  # TODO why can't we just copy the whole thing at once? # too big?
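One plausible answer to the TODO, not something this commit does: whole-array assignment materializes the full source in memory at once, while element-wise copying bounds memory but pays per-row overhead. Copying in fixed-size chunks splits the difference:

    def copy_chunked(src, dst, chunk: int = 256):
        # bounded peak memory, far fewer assignments than one per row
        for start in range(0, src.shape[0], chunk):
            dst[start : start + chunk] = src[start : start + chunk]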


if __name__ == "__main__":
29 changes: 20 additions & 9 deletions spf/model_training_and_inference/models/beamsegnet.py
@@ -474,7 +474,7 @@ def normal_correction_for_bounded_range(


@torch.jit.script
-def normal_dist(x, y, sigma):
+def normal_dist(x: torch.Tensor, y: torch.Tensor, sigma: torch.Tensor):
assert x.ndim == 1
assert y.ndim == 1
assert sigma.ndim == 1
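TorchScript already treats unannotated arguments as Tensors, so the explicit annotations here are mostly documentation plus groundwork for the @torch.compile markers this commit adds elsewhere. A minimal scripted function in the same shape (the real normal_dist body is not shown in this hunk; this Gaussian density is a stand-in):

    import math

    import torch

    @torch.jit.script
    def gaussian_pdf(x: torch.Tensor, y: torch.Tensor, sigma: torch.Tensor):
        # density of N(y, sigma^2) evaluated at x
        return torch.exp(-((x - y) ** 2) / (2 * sigma * sigma)) / (
            sigma * math.sqrt(2.0 * math.pi)
        )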
@@ -534,7 +534,8 @@ def __init__(

self.beam_net = beam_net

-    def fixify(self, _y, sign):
+    # @torch.compile
+    def fixify(self, _y: torch.Tensor, sign: float):
_y_sig = self.sigmoid(_y) # in [0,1]
if self.no_sigmoid:
mean_values = sign * _y[:, [0]]
@@ -556,10 +557,17 @@
]
)

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
return self.fixify(self.beam_net(x), sign=1)

-    def likelihood(self, x, y, sigma_eps=0.01, smoothing_prob=0.0001):
+    # @torch.compile
+    def likelihood(
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor,
+    ):
+        sigma_eps: float = 0.01
+        smoothing_prob: float = 0.0001
assert y.ndim == 2 and y.shape[1] == 1
assert x.ndim == 2 and x.shape[1] >= 5
### EXTREMELY IMPORTANT!!! x[:,[0]] NOT x[:,0]
@@ -618,18 +626,20 @@ def likelihood(self, x, y, sigma_eps=0.01, smoothing_prob=0.0001):
assert likelihood.shape == (x.shape[0], 1)
return likelihood + smoothing_prob
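Moving sigma_eps and smoothing_prob out of the signature and into the body makes them true constants: torch.compile guards on Python-number arguments and can re-trace when a caller passes a new value, whereas a tensor-only signature sidesteps that entirely. The trade-off is that callers can no longer override these values.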

-    def mse(self, x, y):
+    # @torch.compile
+    def mse(self, x: torch.Tensor, y: torch.Tensor):
        # not sure why we can't wrap around for torch.pi/2....
# assert np.isclose(self.max_angle, torch.pi, atol=0.05)
# if self.max_angle == torch.pi:
return (torch_pi_norm(x[:, 0] - y[:, 0], max_angle=self.max_angle) ** 2).mean()
# return ((x[:, 0] - y[:, 0]) ** 2).mean()

-    def loglikelihood(self, x, y, log_eps=0.000000001):
-        return torch.log(self.likelihood(x, y) + log_eps)
+    # @torch.compile
+    def loglikelihood(self, x: torch.Tensor, y: torch.Tensor):  # , log_eps: float =
+        return torch.log(self.likelihood(x, y) + 0.000000001)

    # this is discrete, it's already rendered
-    def render_discrete_x(self, x):
+    def render_discrete_x(self, x: torch.Tensor):

thetas = torch.linspace(
-self.max_angle,
@@ -650,7 +660,7 @@
return likelihoods

    # this is discrete, it's already rendered
-    def render_discrete_y(self, y):
+    def render_discrete_y(self, y: torch.Tensor):
assert y.abs().max() <= self.max_angle
return v5_thetas_to_targets(y, self.nthetas, sigma=0.5, range_in_rad=1)

@@ -721,6 +731,7 @@ def __init__(
self.rx_spacing_track = rx_spacing_track
self.symmetry = symmetry

+    # @torch.compile
def forward(self, x):
# split into pd>=0 and pd<0

63 changes: 43 additions & 20 deletions spf/notebooks/simple_train_filter.py
@@ -2,6 +2,7 @@
from functools import cache, partial

import numpy as np
+import tensordict
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm
@@ -22,6 +23,7 @@
from spf.rf import (
reduce_theta_to_positive_y,
torch_pi_norm,
+    torch_pi_norm_pi,
torch_reduce_theta_to_positive_y,
)

@@ -35,6 +37,8 @@

from torch.utils.data import DistributedSampler, Sampler, BatchSampler

torch.set_float32_matmul_precision("high")


# from fair-chem repo
class StatefulDistributedSampler(DistributedSampler):
@@ -175,6 +179,25 @@ def loss(self, output, y_rad, craft_y_rad, seg_mask):
}


+# @torch.no_grad
+# @torch.compile
+def random_loss(target: torch.Tensor, y_rad_reduced: torch.Tensor):
+    random_target = (torch.rand(target.shape, device=target.device) - 0.5) * 2 * np.pi
+    beamnet_mse_random = (
+        torch_pi_norm(
+            y_rad_reduced
+            - (torch.rand(y_rad_reduced.shape, device=target.device) - 0.5)
+            * 2
+            * np.pi
+            / 2,
+            max_angle=torch.pi / 2,
+        )
+        ** 2
+    ).mean()
+    transformer_random_loss = (torch_pi_norm_pi(target - random_target) ** 2).mean()
+    return beamnet_mse_random, transformer_random_loss
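Factoring the random baselines into a free function with tensor-typed arguments keeps them out of the loss method proper (handy if loss is ever compiled, as the commented-out decorators suggest). The two values are chance floors for the training curves: with targets at zero and uniform random guesses, the wrapped squared error converges to (pi/2)^2 / 3 ≈ 0.82 for the beamformer head and pi^2 / 3 ≈ 3.29 for the transformer head. A quick check one could run with the function above in scope:

    import torch

    target = torch.zeros(100_000, 1)
    y_rad_reduced = torch.zeros(100_000, 1)
    beam_rand, tr_rand = random_loss(target, y_rad_reduced)
    assert abs(beam_rand - torch.pi**2 / 12) < 0.02
    assert abs(tr_rand - torch.pi**2 / 3) < 0.05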


class FunkyNet(torch.nn.Module):
def __init__(
self,
@@ -268,6 +291,7 @@ def __init__(
self.paired_drop_in_gt = 0.00
self.token_dropout = token_dropout

+    # @torch.compile
def forward(self, x, seg_mask, rx_spacing, y_rad, windowed_beam_former, rx_pos):
rx_pos = rx_pos.detach().clone() / 4000

@@ -354,36 +378,30 @@ def forward(self, x, seg_mask, rx_spacing, y_rad, windowed_beam_former, rx_pos):
"pred_theta": pred_theta,
}

-    def loss(self, output, y_rad, craft_y_rad, seg_mask):
+    # @torch.compile
+    def loss(
+        self,
+        output: torch.Tensor,
+        y_rad: torch.Tensor,
+        craft_y_rad: torch.Tensor,
+        seg_mask: torch.Tensor,
+    ):
target = craft_y_rad[::2, [-1]]
transformer_loss = (
-            torch_pi_norm(target - output["transformer_output"]) ** 2
+            torch_pi_norm_pi(target - output["transformer_output"]) ** 2
).mean()

-        random_target = (
-            (torch.rand(target.shape, device=target.device) - 0.5) * 2 * np.pi
-        )
-        transformer_random_loss = (torch_pi_norm(target - random_target) ** 2).mean()

y_rad_reduced = torch_reduce_theta_to_positive_y(y_rad).reshape(-1, 1)
# x to beamformer loss (indirectly including segmentation)
beamnet_loss = -self.beam_m.loglikelihood(
output["pred_theta"], y_rad_reduced
).mean()

beamnet_mse = self.beam_m.mse(output["pred_theta"], y_rad_reduced)
-        beamnet_mse_random = (
-            torch_pi_norm(
-                y_rad_reduced
-                - (torch.rand(y_rad_reduced.shape, device=target.device) - 0.5)
-                * 2
-                * np.pi
-                / 2,
-                max_angle=torch.pi / 2,
-            )
-            ** 2
-        ).mean()

loss = transformer_loss + beamnet_loss

+        beamnet_mse_random, transformer_random_loss = random_loss(target, y_rad_reduced)
return {
"loss": loss,
"transformer_mse_loss": transformer_loss,
Expand All @@ -394,7 +412,12 @@ def loss(self, output, y_rad, craft_y_rad, seg_mask):
}


-def batch_data_to_x_y_seg(batch_data, torch_device, dtype):
+# @torch.compile
+def batch_data_to_x_y_seg(
+    batch_data: tensordict.tensordict.TensorDict,
+    torch_device: torch.device,
+    dtype: torch.dtype,
+):
# x ~ # trimmed_cm, trimmed_stddev, abs_signal_median
batch_data = batch_data.to(torch_device)
x = batch_data["all_windows_stats"].to(dtype=dtype)
@@ -521,7 +544,7 @@ def new_log():
ignore_qc=args.skip_qc,
gpu=args.device == "cuda",
snapshots_per_session=args.snapshots_per_session,
-                readahead=False,
+                readahead=True,
skip_simple_segmentations=True,
)
for prefix in args.datasets
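The readahead flag presumably reaches the underlying LMDB store: enabling OS readahead favors the mostly sequential scans a training epoch performs over these datasets, with the usual caveat that it can hurt random access to stores much larger than RAM.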