use tensordict; fixes;
misko committed Jul 14, 2024
1 parent 0bdbccc commit 1a525c0
Showing 8 changed files with 395 additions and 96 deletions.
84 changes: 84 additions & 0 deletions spf/dataset/benchmark_dataset.py
@@ -0,0 +1,84 @@
import random
import time
from functools import partial

import torch
from tqdm import tqdm

from spf.dataset.spf_dataset import (
    v5_collate_keys_fast,
    v5spfdataset,
)

ds_fn = "/mnt/4tb_ssd/june_fix/wallarrayv3_2024_06_15_11_44_13_nRX2_bounce.zarr"


nthetas = 65
ds = v5spfdataset(
ds_fn,
nthetas=nthetas,
ignore_qc=True,
precompute_cache="/home/mouse9911/precompute_cache_chunk16_fresh",
paired=True,
skip_signal_matrix=True,
snapshots_per_session=500,
skip_simple_segmentations=True,
)
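# ignore_qc presumably bypasses quality-control checks; skip_signal_matrix and
# skip_simple_segmentations skip loading the raw signal matrix and per-window
# segmentations, which this throughput benchmark does not need.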

random.seed(10)  # seed before shuffling so the 3000-sample subset is reproducible
idxs = torch.arange(len(ds)).tolist()
random.shuffle(idxs)

ds = torch.utils.data.Subset(ds, idxs[:3000])
print("Getting")

if False:  # flip to True to time raw per-sample access without a DataLoader
start_time = time.time()
    count = 0
    for idx in range(len(ds)):
        ds[idx]
        count += 1
        if count % 1000 == 0:
            print((time.time() - start_time) / count, "seconds per sample")
print((time.time() - start_time) / len(ds), "seconds per sample")

# x = batch_data["all_windows_stats"].to(torch_device).to(torch.float32)
# rx_pos = batch_data["rx_pos_xy"].to(torch_device)
# seg_mask = batch_data["downsampled_segmentation_mask"].to(torch_device)
# rx_spacing = batch_data["rx_spacing"].to(torch_device)
# windowed_beamformer = batch_data["windowed_beamformer"].to(torch_device)
# y_rad = batch_data["y_rad"].to(torch_device)
# craft_y_rad = batch_data["craft_y_rad"].to(torch_device)
# y_phi = batch_data["y_phi"].to(torch_device)
workers = 0
dataloader_params = {
"batch_size": 8,
"shuffle": True,
"num_workers": workers,
"collate_fn": partial(
v5_collate_keys_fast,
[
"all_windows_stats",
"rx_pos_xy",
"downsampled_segmentation_mask",
"rx_spacing",
# "windowed_beamformer",
"y_rad",
"craft_y_rad",
"y_phi",
],
),
"pin_memory": True,
"prefetch_factor": 1 if workers > 0 else None,
}
train_dataloader = torch.utils.data.DataLoader(ds, **dataloader_params)

for step, batch_data in enumerate(tqdm(train_dataloader)):
pass
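The commented-out per-key transfers above are exactly what the TensorDict
returned by v5_collate_keys_fast makes unnecessary: a single .to(device) call
moves every key at once. A minimal sketch, with made-up keys and shapes rather
than anything from this dataset:

from tensordict import TensorDict
import torch

batch = TensorDict(
    {"y_rad": torch.zeros(8, 1), "y_phi": torch.zeros(8, 1)},
    batch_size=[8],
)
device = "cuda" if torch.cuda.is_available() else "cpu"
batch = batch.to(device)  # one call moves every tensor in the batch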
163 changes: 119 additions & 44 deletions spf/dataset/spf_dataset.py
@@ -8,8 +8,9 @@
import time
from multiprocessing import Pool, cpu_count
from typing import List

import gc
import numpy as np
from tensordict import TensorDict
import torch
import tqdm
import yaml
@@ -204,6 +205,24 @@ def v5_thetas_to_targets(target_thetas, nthetas, range_in_rad, sigma=1):
# return torch.nn.functional.normalize(p, p=1, dim=1)


def v5_collate_keys_fast(keys, batch):
d = {}
for key in keys:
d[key] = torch.vstack(
[x[key] for paired_sample in batch for x in paired_sample]
)
return TensorDict(d, batch_size=d["y_rad"].shape)
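# Note: v5_collate_keys_fast reads batch_size from "y_rad", so "y_rad" must be
# among the requested keys. Hypothetical usage (shapes assumed, not taken from
# the dataset):
#   td = v5_collate_keys_fast(["y_rad", "y_phi"], batch)
#   td["y_rad"]  # every sample of every pair stacked along dim 0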


def v5_collate_all_fast(batch):
d = {}
for key in batch[0][0].keys():
d[key] = torch.vstack(
[x[key] for paired_sample in batch for x in paired_sample]
)
return d


def v5_collate_beamsegnet(batch):
n_windows = batch[0][0]["all_windows_stats"].shape[1]
y_rad_list = []
@@ -225,15 +244,15 @@ def v5_collate_beamsegnet(batch):
rx_pos_list.append(x["rx_pos_xy"])
craft_y_rad_list.append(x["craft_y_rad"])
rx_spacing_list.append(x["rx_spacing"].reshape(-1, 1))
simple_segmentation_list += x["simple_segmentations"]
# simple_segmentation_list += x["simple_segmentations"]
all_window_stats_list.append(x["all_windows_stats"]) # .astype(np.float32)
windowed_beamformers_list.append(
x["windowed_beamformer"] # .astype(np.float32)
)
downsampled_segmentation_mask_list.append(
x["downsampled_segmentation_mask"]
)
receiver_idx_list.append(x["receiver_idx"].repeat(x["y_rad"].shape[1]))
receiver_idx_list.append(x["receiver_idx"].expand_as(x["y_rad"]))
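            # expand_as returns a broadcast view (no copy), unlike the
            # .repeat(...) it replaces, which allocated a new tensor per sample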

if "x" in batch[0][0]:
x_list.append(x["x"])
@@ -245,7 +264,7 @@
"receiver_idx": torch.vstack(receiver_idx_list),
"craft_y_rad": torch.vstack(craft_y_rad_list),
"rx_spacing": torch.vstack(rx_spacing_list),
"simple_segmentation": simple_segmentation_list,
# "simple_segmentation": simple_segmentation_list,
"all_windows_stats": torch.from_numpy(np.vstack(all_window_stats_list)),
"windowed_beamformer": torch.from_numpy(np.vstack(windowed_beamformers_list)),
"downsampled_segmentation_mask": torch.vstack(
@@ -255,7 +274,6 @@
if "x" in batch[0][0]:
d["x"] = torch.vstack(x_list)
d["segmentation_mask"] = torch.vstack(segmentation_mask_list)

return d


@@ -299,16 +317,20 @@ def __init__(
gpu=False,
snapshots_per_session=1,
tiled_sessions=True,
readahead=False,
skip_simple_segmentations=False,
):
# print("Open", prefix)
self.readahead = readahead
self.precompute_cache = precompute_cache
prefix = prefix.replace(".zarr", "")
self.nthetas = nthetas
self.prefix = prefix
self.skip_signal_matrix = skip_signal_matrix
self.skip_simple_segmentations = skip_simple_segmentations
self.zarr_fn = f"{prefix}.zarr"
self.yaml_fn = f"{prefix}.yaml"
self.z = zarr_open_from_lmdb_store(self.zarr_fn)
self.z = zarr_open_from_lmdb_store(self.zarr_fn, readahead=self.readahead)
self.yaml_config = yaml.safe_load(open(self.yaml_fn, "r"))
self.paired = paired
self.n_receivers = len(self.yaml_config["receivers"])
@@ -430,6 +452,16 @@ def __init__(
raise ValueError(
"It looks like too few windows have a valid segmentation"
)
self.close()

def close(self):
# let workers open their own
self.z.store.close()
self.z = None
self.receiver_data = None
        # try to release the segmentation handles as well
self.segmentation = None
self.precomputed_zarr = None

def estimate_phi(self, data):
x = torch.tensor(data["all_windows_stats"])
@@ -442,19 +474,38 @@ def __len__(self):
return self.n_sessions
return self.n_sessions * self.n_receivers

def reinit(self):
if self.z is None:
# worker_info = torch.utils.data.get_worker_info()
# print(worker_info)
self.z = zarr_open_from_lmdb_store(self.zarr_fn, readahead=self.readahead)
self.receiver_data = [
self.z.receivers[f"r{ridx}"] for ridx in range(self.n_receivers)
]
if self.precomputed_zarr is None:
self.get_segmentation()
self.precomputed_zarr = zarr_open_from_lmdb_store(
self.results_fn().replace(".pkl", ".yarr"), mode="r"
)

def render_session(self, receiver_idx, session_idx):
self.reinit()
if self.tiled_sessions:
snapshot_start_idx = session_idx
snapshot_end_idx = session_idx + self.snapshots_per_session
else:
snapshot_start_idx = session_idx * self.snapshots_per_session
snapshot_end_idx = (session_idx + 1) * self.snapshots_per_session

r = self.receiver_data[receiver_idx]

data = {
key: r[key][snapshot_start_idx:snapshot_end_idx]
for key in self.keys_per_session
}
data["receiver_idx"] = np.array(receiver_idx)
data["receiver_idx"] = torch.tensor(receiver_idx).expand(
1, self.snapshots_per_session
)
data["ground_truth_theta"] = self.ground_truth_thetas[receiver_idx][
snapshot_start_idx:snapshot_end_idx
]
@@ -466,66 +517,88 @@ def render_session(self, receiver_idx, session_idx):
]
data = {
k: (
torch.from_numpy(v)
if type(v) not in (np.float64, float)
else torch.Tensor([v])
torch.from_numpy(v.astype(np.float32)).unsqueeze(0)
if type(v) not in (np.float64, float, torch.Tensor)
else v
)
for k, v in data.items()
}
if not self.skip_signal_matrix:
abs_signal = data["signal_matrix"].abs().to(torch.float32)
pd = torch_get_phase_diff(data["signal_matrix"]).to(torch.float32)
data["x"] = torch.vstack([abs_signal[0], abs_signal[1], pd])[None]

data["y_rad"] = data["ground_truth_theta"][None].to(torch.float32)
data["y_phi"] = data["ground_truth_phi"][None].to(torch.float32)
data["craft_y_rad"] = data["craft_ground_truth_theta"][None].to(torch.float32)
data["x"] = torch.concatenate(
[abs_signal[:, [0]], abs_signal[:, [1]], pd[:, None]], dim=1
)
data["y_rad"] = data["ground_truth_theta"].to(torch.float32)
data["y_phi"] = data["ground_truth_phi"].to(torch.float32)
data["craft_y_rad"] = data["craft_ground_truth_theta"].to(torch.float32)

# data["y_discrete"] = v5_thetas_to_targets(data["y_rad"], self.nthetas)

data["windowed_beamformer"] = self.precomputed_zarr[
f"r{receiver_idx}/windowed_beamformer"
][snapshot_start_idx:snapshot_end_idx]

data["simple_segmentations"] = [
d["simple_segmentation"]
for d in self.segmentation["segmentation_by_receiver"][f"r{receiver_idx}"][
data["windowed_beamformer"] = torch.tensor(
self.precomputed_zarr[f"r{receiver_idx}/windowed_beamformer"][
snapshot_start_idx:snapshot_end_idx
]
]
).unsqueeze(0)

if not self.skip_simple_segmentations:
data["simple_segmentations"] = [
d["simple_segmentation"]
for d in self.segmentation["segmentation_by_receiver"][
f"r{receiver_idx}"
][snapshot_start_idx:snapshot_end_idx]
]

# sessions x 3 x n_windows
data["all_windows_stats"] = self.precomputed_zarr[
f"r{receiver_idx}/all_windows_stats"
][snapshot_start_idx:snapshot_end_idx]
data["all_windows_stats"] = torch.tensor(
self.precomputed_zarr[f"r{receiver_idx}/all_windows_stats"][
snapshot_start_idx:snapshot_end_idx
]
).unsqueeze(0)

# n_windows = data["all_windows_stats"].shape[2]
# data["downsampled_segmentation_mask"] = v5_downsampled_segmentation_mask(
# data, n_windows=n_windows
# )
data["downsampled_segmentation_mask"] = torch.tensor(
self.precomputed_zarr[f"r{receiver_idx}"]["downsampled_segmentation_mask"][
snapshot_start_idx:snapshot_end_idx
]
).unsqueeze(1)
data["downsampled_segmentation_mask"] = (
torch.tensor(
self.precomputed_zarr[f"r{receiver_idx}"][
"downsampled_segmentation_mask"
][snapshot_start_idx:snapshot_end_idx]
)
.unsqueeze(1)
.unsqueeze(0)
)

# breakpoint()
data["mean_phase_segmentation"] = self.mean_phase[f"r{receiver_idx}"][
snapshot_start_idx:snapshot_end_idx
]
data["mean_phase_segmentation"] = torch.tensor(
self.mean_phase[f"r{receiver_idx}"][
snapshot_start_idx:snapshot_end_idx
].astype(np.float32)
).unsqueeze(0)
data["rx_pos_xy"] = torch.tensor(
np.array(
np.vstack(
[
self.receiver_data[receiver_idx]["rx_pos_x_mm"][
snapshot_start_idx:snapshot_end_idx
],
].astype(np.float32),
self.receiver_data[receiver_idx]["rx_pos_y_mm"][
snapshot_start_idx:snapshot_end_idx
],
].astype(np.float32),
]
)
).T.to(torch.float32)
).T.unsqueeze(0)
# for key in data.keys():
# if isinstance(data[key], list):
# print(key)
# else:
# print(key, data[key].shape, data[key].dtype)
# breakpoint()
# trimmed_cm, trimmed_stddev, abs_signal_median
# self.close()
# self.close()
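        # roughly 1 in 200 samples, force a GC pass so long-lived loader
        # processes do not accumulate garbage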
if torch.rand(1).item() < 0.005:
gc.collect()
return data

def get_ground_truth_phis(self):
Expand Down Expand Up @@ -639,13 +712,15 @@ def get_estimated_thetas(self):
)
return estimated_thetas

def results_fn(self):
return os.path.join(
self.precompute_cache,
os.path.basename(self.prefix) + f"_segmentation_nthetas{self.nthetas}.pkl",
)

def get_segmentation(self):
if not hasattr(self, "segmentation"):
results_fn = os.path.join(
self.precompute_cache,
os.path.basename(self.prefix)
+ f"_segmentation_nthetas{self.nthetas}.pkl",
)
if not hasattr(self, "segmentation") or self.segmentation is None:
results_fn = self.results_fn()

if not os.path.exists(results_fn):
mp_segment_zarr(
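
The close()/reinit() pair added above follows the usual rule for LMDB-backed
stores and forked DataLoader workers: the parent closes its handles after
construction, and each worker lazily reopens its own on first access. A minimal
sketch of the pattern, with a plain file handle standing in for the zarr/LMDB
store (all names here are illustrative, not from this repo):

import torch
from torch.utils.data import Dataset

class LazyHandleDataset(Dataset):
    def __init__(self, path, n):
        self.path = path
        self.n = n
        self.handle = None  # parent never ships an open handle to workers

    def _reinit(self):
        # each worker process opens its own handle on first access
        if self.handle is None:
            self.handle = open(self.path, "rb")

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        self._reinit()
        self.handle.seek(idx)
        return torch.tensor(self.handle.read(1)[0])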
