Commit

Add dataset loader for wall array v2; clean up dataset a bit
misko committed Jan 7, 2024
1 parent eb3afb7 commit e41fe1f
Showing 10 changed files with 687 additions and 246 deletions.
249 changes: 219 additions & 30 deletions spf/dataset/spf_dataset.py
@@ -7,6 +7,7 @@

import numpy as np
import torch
import yaml
from compress_pickle import load
from torch.utils.data import Dataset

@@ -16,25 +17,24 @@
labels_to_source_images,
radio_to_image,
)
from spf.rf import ULADetector
from spf.wall_array_v1 import v1_beamformer_start_idx, v1_time_idx, v1_tx_pos_idxs
from spf.wall_array_v2 import (
v2_beamformer_start_idx,
v2_column_names,
v2_rx_pos_idxs,
v2_rx_theta_idx,
v2_time_idx,
v2_tx_pos_idxs,
)


def pos_to_rel(p, width):
return 2 * (p / width - 0.5)


output_cols = { # maybe this should get moved to the dataset part...
"src_pos": [0, 1],
"src_theta": [2],
"src_dist": [3],
"det_delta": [4, 5],
"det_theta": [6],
"det_space": [7],
"src_v": [8, 9],
}

input_cols = {
"det_pos": [0, 1],
"time": [2],
"space_delta": [3, 4],
"space_theta": [5],
"space_dist": [6],
"det_theta2": [7],
}
def rel_to_pos(r, width):
return ((r / 2) + 0.5) * width
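# sanity check: pos_to_rel and rel_to_pos are inverses, mapping absolute
# positions in [0, width] to [-1, 1] and back, e.g. pos_to_rel(1500, 3000) == 0.0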


# from stackoverflow
@@ -46,7 +46,7 @@ class dotdict(dict):
__delattr__ = dict.__delitem__


class SessionsDataset(Dataset):
class SessionsDatasetSimulated(Dataset):
def __init__(self, root_dir, snapshots_in_sample=5):
"""
Arguments:
@@ -95,7 +95,7 @@ def __getitem__(self, idx):
return {k: session[k][start_idx:end_idx] for k in session.keys()}


class SessionsDatasetReal(Dataset):
class SessionsDatasetRealV1(Dataset):
def get_m(self, filename, bypass=False):
if bypass:
return np.memmap(
@@ -230,7 +230,7 @@ def __getitem__(self, idx):
self.detector_position, (m.shape[0], 2)
)
detector_orientation_at_t = self.zeros[:, [0]]
source_positions_at_t = m[:, 1:3][:, None]
source_positions_at_t = m[:, v1_tx_pos_idxs()][:, None]
diffs = (
source_positions_at_t - detector_position_at_t[:, None]
) # broadcast over nsource dimension
@@ -247,11 +247,11 @@ def __getitem__(self, idx):
"receiver_positions_at_t": np.broadcast_to(
self.receiver_pos[None], (m.shape[0], 2, 2)
),
"beam_former_outputs_at_t": m[:, 5:],
"beam_former_outputs_at_t": m[:, v1_beamformer_start_idx() :],
"thetas_at_t": np.broadcast_to(
self.thetas[None], (m.shape[0], self.thetas.shape[0])
),
"time_stamps": m[:, [0]],
"time_stamps": m[:, [v1_time_idx()]],
"width_at_t": self.widths,
"detector_orientation_at_t": detector_orientation_at_t, # self.halfpis*0,#np.arctan2(1,0)=np.pi/2
"detector_position_at_t": detector_position_at_t,
@@ -260,15 +260,204 @@ def __getitem__(self, idx):
}


def pos_to_rel(p, width):
return 2 * (p / width - 0.5)
class SessionsDatasetRealV2(Dataset):
def get_m(self, filename, bypass=False):
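# the capture memmap is laid out as (n_receivers=2, snapshots_in_file, n_columns);
# bypass=True skips the per-filename cache so callers can probe a file once
# without keeping its mapping around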
if bypass:
return np.memmap(
filename,
dtype="float32",
mode="r",
shape=(2, self.snapshots_in_file, len(self.column_names)),
)

if filename not in self.m_cache:
self.m_cache[filename] = np.memmap(
filename,
dtype="float32",
mode="r",
shape=(2, self.snapshots_in_file, len(self.column_names)),
)
return self.m_cache[filename]

def rel_to_pos(r, width):
return ((r / 2) + 0.5) * width
def check_file(self, filename):
try:
m = self.get_m(filename, bypass=True)
except Exception:
print(
"SessionsDatasetRealV2: Dropping file from loading because memmap failed",
filename,
)
return False
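# heuristic: if any column's mean magnitude (per receiver) is exactly zero
# across all snapshots, the capture is likely truncated or corrupt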
if self.check_files:
status = not (np.abs(m).mean(axis=1) == 0).any()
if not status:
print(
"SessionsDatasetRealV2: Dropping file from loading because it looks like all zeros",
filename,
)
return False
return True

def get_all_valid_files(self):
return sorted(
filter(
self.check_file,
filter(
lambda x: ".npy" in x,
["%s/%s" % (self.root_dir, x) for x in os.listdir(self.root_dir)],
),
)
)

def __init__(
self,
root_dir,
yaml_config_fn,
snapshots_in_session=128,
nsources=1,
width=3000,
step_size=1,
seed=1337,
check_files=True,
snapshots_in_file=300000,
):
"""
Arguments:
root_dir (string): Directory containing the .npy session capture files.
yaml_config_fn (string): Path to the receiver YAML config.
"""
self.check_files = check_files
# read YAML
with open(yaml_config_fn, "r") as stream:
yaml_config = yaml.safe_load(stream)

assert nsources == 1 # TODO implement more
self.snapshots_in_file = snapshots_in_file
self.snapshots_in_session = snapshots_in_session
self.step_size = step_size

self.root_dir = root_dir
self.nthetas = yaml_config["n-thetas"]
self.thetas = np.linspace(-np.pi, np.pi, self.nthetas)

self.column_names = v2_column_names(nthetas=self.nthetas)

self.args = dotdict(
{
"width": width,
}
)
self.m_cache = {}

# in case we need them later generate the offsets from the
# RX center for antenna 0 and antenna 1
self.rx_antenna_offsets = []
for receiver in yaml_config["receivers"]:
self.rx_antenna_offsets.append(
ULADetector(
sampling_frequency=None,
n_elements=2,
spacing=receiver["antenna-spacing-m"] * 1000,
orientation=receiver["theta-in-pis"] * np.pi,
).all_receiver_pos()
)
self.rx_antenna_offsets = np.array(self.rx_antenna_offsets)
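# shape (n_receivers, 2, 2): for each receiver, the two antenna offsets from
# the RX center, oriented per the receiver's theta-in-pis setting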

self.filenames = self.get_all_valid_files()

if len(self.filenames) == 0:
raise ValueError("SessionsDatasetRealV2: No valid files to load from")
self.rng = np.random.default_rng(seed)
self.rng.shuffle(self.filenames)

self.sessions_per_file = [
(
self.get_m(filename, bypass=True).shape[1]
- (self.snapshots_in_session * self.step_size)
)
* 2 # TODO should use n-receivers
for filename in self.filenames
]
self.cumsum_sessions_per_file = np.cumsum([0] + self.sessions_per_file)
self.len = sum(self.sessions_per_file)

# optimizations : cached constants values
self.zeros = np.zeros((self.snapshots_in_session, 5))
self.ones = np.ones((self.snapshots_in_session, 5))
self.widths = (
np.ones((self.snapshots_in_session, 1), dtype=np.int32) * self.args.width
)
self.halfpis = -np.ones((self.snapshots_in_session, 1)) * np.pi / 2

def idx_to_fileidx_and_startidx(self, idx):
file_idx = bisect.bisect_right(self.cumsum_sessions_per_file, idx) - 1
if file_idx >= len(self.sessions_per_file):
return None
start_idx = (
idx - self.cumsum_sessions_per_file[file_idx]
) # *(self.snapshots_in_sample*self.step_size)
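# worked example: with cumsum_sessions_per_file = [0, 1000, 2400] and idx = 1203,
# bisect_right gives file_idx = 1 and start_idx = 1203 - 1000 = 203; __getitem__
# then folds start_idx over the file's two receivers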
return file_idx, start_idx

def __len__(self):
return self.len

def __getitem__(self, idx):
if idx < 0 or idx >= self.len:
raise IndexError
fileidx, unadjusted_startidx = self.idx_to_fileidx_and_startidx(idx)
# need to figure out which receiver A/B we are using here
sessions_per_receiver = self.sessions_per_file[fileidx] // 2

receiver_idx = 0 # assume A
startidx = unadjusted_startidx
if unadjusted_startidx >= sessions_per_receiver: # use B
receiver_idx = 1
startidx = unadjusted_startidx - sessions_per_receiver

m = self.get_m(self.filenames[fileidx])[
receiver_idx,
startidx : startidx
+ self.snapshots_in_session * self.step_size : self.step_size,
]

rx_position_at_t = m[:, v2_rx_pos_idxs()]
rx_orientation_at_t = m[:, v2_rx_theta_idx()]
rx_antenna_positions_at_t = (
rx_position_at_t[:, None] + self.rx_antenna_offsets[receiver_idx]
)

tx_positions_at_t = m[:, v2_tx_pos_idxs()][:, None]
diffs = (
tx_positions_at_t - rx_position_at_t[:, None]
) # broadcast over nsource dimension
# diffs=(batchsize, nsources, 2)
source_theta_at_t = np.arctan2(
diffs[:, 0, [0]], diffs[:, 0, [1]]
) # rotation to the right around x=0, y+

return {
"broadcasting_positions_at_t": self.ones[:, [0]][
:, None
], # TODO multi source
"source_positions_at_t": tx_positions_at_t,
"source_velocities_at_t": self.zeros[:, :2][:, None], # TODO calc velocity,
"receiver_positions_at_t": rx_antenna_positions_at_t,
"beam_former_outputs_at_t": m[:, v2_beamformer_start_idx() :],
"thetas_at_t": np.broadcast_to(
self.thetas[None], (m.shape[0], self.thetas.shape[0])
),
"time_stamps": m[:, [v2_time_idx()]],
"width_at_t": self.widths,
"detector_orientation_at_t": rx_orientation_at_t, # self.halfpis*0,#np.arctan2(1,0)=np.pi/2
"detector_position_at_t": rx_position_at_t,
"source_theta_at_t": source_theta_at_t,
"source_distance_at_t": self.zeros[:, [0]][:, None],
}


class SessionsDatasetRealTask2(SessionsDatasetReal):
class SessionsDatasetRealTask2(SessionsDatasetRealV1):
def __getitem__(self, idx):
d = super().__getitem__(idx)
# normalize these before heading out
@@ -287,7 +476,7 @@ def __getitem__(self, idx):
return d # ,d['source_positions_at_t']


class SessionsDatasetTask2(SessionsDataset):
class SessionsDatasetTask2(SessionsDatasetSimulated):
def __getitem__(self, idx):
d = super().__getitem__(idx)
# normalize these before heading out
@@ -306,7 +495,7 @@ def __getitem__(self, idx):
return d # ,d['source_positions_at_t']


class SessionsDatasetTask2WithImages(SessionsDataset):
class SessionsDatasetTask2WithImages(SessionsDatasetSimulated):
def __getitem__(self, idx):
d = super().__getitem__(idx)
# normalize these before heading out
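A minimal usage sketch for the new v2 loader (paths and config values are hypothetical; the YAML must supply the n-thetas and receivers keys read in __init__):

from spf.dataset.spf_dataset import SessionsDatasetRealV2

ds = SessionsDatasetRealV2(
    root_dir="./captures",              # directory of .npy memmap capture files
    yaml_config_fn="./receivers.yaml",  # defines n-thetas and receivers
    snapshots_in_session=128,
)
sample = ds[0]  # dict of per-snapshot numpy arrays
print(sample["beam_former_outputs_at_t"].shape)  # (snapshots_in_session, n beamformer bins)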
26 changes: 20 additions & 6 deletions spf/model_training_and_inference/12_task2_model_training.py
@@ -10,12 +10,26 @@
import torch.optim as optim
from models.models import SingleSnapshotNet, SnapshotNet, Task1Net, UNet

from spf.dataset.spf_dataset import (
SessionsDatasetTask2,
collate_fn,
input_cols,
output_cols,
)
from spf.dataset.spf_dataset import SessionsDatasetTask2, collate_fn

output_cols = { # maybe this should get moved to the dataset part...
"src_pos": [0, 1],
"src_theta": [2],
"src_dist": [3],
"det_delta": [4, 5],
"det_theta": [6],
"det_space": [7],
"src_v": [8, 9],
}

input_cols = {
"det_pos": [0, 1],
"time": [2],
"space_delta": [3, 4],
"space_theta": [5],
"space_dist": [6],
"det_theta2": [7],
}
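
# illustrative use (tensor names hypothetical): inputs[..., input_cols["det_pos"]]
# selects the detector x, y columns; labels[..., output_cols["src_theta"]]
# selects the source bearing column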

torch.set_printoptions(precision=5, sci_mode=False, linewidth=1000)

21 changes: 19 additions & 2 deletions spf/model_training_and_inference/14_task3_model_training.py
@@ -15,10 +15,27 @@
SessionsDatasetRealTask2,
SessionsDatasetTask2,
collate_fn_transformer_filter,
input_cols,
output_cols,
)

output_cols = { # maybe this should get moved to the dataset part...
"src_pos": [0, 1],
"src_theta": [2],
"src_dist": [3],
"det_delta": [4, 5],
"det_theta": [6],
"det_space": [7],
"src_v": [8, 9],
}

input_cols = {
"det_pos": [0, 1],
"time": [2],
"space_delta": [3, 4],
"space_theta": [5],
"space_dist": [6],
"det_theta2": [7],
}

torch.set_num_threads(8)
torch.set_printoptions(precision=5, sci_mode=False, linewidth=1000)

4 changes: 2 additions & 2 deletions spf/model_training_and_inference/90_real_session_plotter.py
@@ -1,6 +1,6 @@
import argparse

from spf.dataset.spf_dataset import SessionsDatasetReal
from spf.dataset.spf_dataset import SessionsDatasetRealV1
from spf.plot.plot import filenames_to_gif, plot_full_session

if __name__ == "__main__":
@@ -19,7 +19,7 @@
)
args = parser.parse_args()
assert args.snapshots_in_sample >= args.steps
ds = SessionsDatasetReal(
ds = SessionsDatasetRealV1(
root_dir=args.root_dir,
snapshots_in_file=args.snapshots_in_file,
nthetas=args.nthetas,
(diffs for the remaining 6 changed files not shown)
