Skip to content

Commit

Permalink
test for partial dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Jul 29, 2024
1 parent 27a9c33 commit 0a65533
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 67 deletions.
179 changes: 146 additions & 33 deletions spf/dataset/fake_dataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import argparse
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml
import random
from spf.data_collector import rx_config_from_receiver_yaml
from spf.dataset.spf_dataset import pi_norm
import zarr

# V5 data format
from spf.dataset.v5_data import v5rx_2xf64_keys, v5rx_f64_keys, v5rx_new_dataset
Expand All @@ -15,9 +19,36 @@
pi_norm,
precompute_steering_vectors,
speed_of_light,
torch_get_avg_phase,
torch_get_avg_phase_notrim,
torch_pi_norm_pi,
)
from spf.sdrpluto.sdr_controller import rx_config_from_receiver_yaml
from spf.utils import random_signal_matrix
from spf.utils import (
random_signal_matrix,
torch_random_signal_matrix,
zarr_open_from_lmdb_store,
zarr_shrink,
)


@torch.jit.script
def phi_to_signal_matrix(
phi: torch.Tensor, buffer_size: int, noise: float, phi_drift: float
):
big_phi = phi.repeat(buffer_size).reshape(1, -1)
big_phi_with_noise = big_phi + torch.randn((1, buffer_size)) * noise
offsets = torch.zeros(big_phi.shape, dtype=torch.complex128)
return (
torch.vstack(
[
torch.exp(1j * (offsets + phi_drift)),
torch.exp(1j * (offsets + big_phi_with_noise)),
]
)
* 200
)


"""
theta is the angle from array normal to incident
Expand Down Expand Up @@ -123,24 +154,24 @@ def create_fake_dataset(
config=yaml_config,
)

thetas = pi_norm(
np.linspace(0, 2 * np.pi * orbits, yaml_config["n-records-per-receiver"])
thetas = torch_pi_norm_pi(
torch.linspace(0, 2 * torch.pi * orbits, yaml_config["n-records-per-receiver"])
)

def theta_to_phi(theta, antenna_spacing_m, _lambda):
return np.sin(theta) * antenna_spacing_m * 2 * np.pi / _lambda
return torch.sin(theta) * antenna_spacing_m * 2 * torch.pi / _lambda

def phi_to_theta(phi, antenna_spacing_m, _lambda, limit=False):
sin_arg = _lambda * phi / (antenna_spacing_m * 2 * np.pi)
sin_arg = _lambda * phi / (antenna_spacing_m * 2 * torch.pi)
# assert sin_arg.min()>-1
# assert sin_arg.max()<1
if limit:
edge = 1 - 1e-8
sin_arg = np.clip(sin_arg, a_min=-edge, a_max=edge)
v = np.arcsin(_lambda * phi / (antenna_spacing_m * 2 * np.pi))
return v, np.pi - v
sin_arg = torch.clip(sin_arg, min=-edge, max=edge)
v = torch.arcsin(_lambda * phi / (antenna_spacing_m * 2 * torch.pi))
return v, torch.pi - v

rnd_noise = np.random.randn(thetas.shape[0])
rnd_noise = torch.randn(thetas.shape[0])

# signal_matrix = np.vstack([np.exp(1j * phis), np.ones(phis.shape)])

Expand All @@ -149,29 +180,18 @@ def phi_to_theta(phi, antenna_spacing_m, _lambda, limit=False):
thetas - yaml_config["receivers"][receiver_idx]["theta-in-pis"] * np.pi
)
phis_nonoise = theta_to_phi(receiver_thetas, rx_config.rx_spacing, _lambda)
phis = pi_norm(phis_nonoise + rnd_noise * noise)
_thetas = phi_to_theta(phis, rx_config.rx_spacing, _lambda, limit=True)
phis = torch_pi_norm_pi(phis_nonoise + rnd_noise * noise)
# _thetas = phi_to_theta(phis, rx_config.rx_spacing, _lambda, limit=True)

for record_idx in range(yaml_config["n-records-per-receiver"]):
big_phi = phis[[record_idx], None].repeat(rx_config.buffer_size, axis=1)
big_phi_with_noise = big_phi + np.random.randn(*big_phi.shape) * noise
offsets = np.random.uniform(-np.pi, np.pi, big_phi.shape) * 0
signal_matrix = (
np.vstack(
[
np.exp(
1j
* (
offsets
+ phi_drift * np.pi * (1 if receiver_idx == 0 else -1)
)
),
np.exp(1j * (offsets + big_phi_with_noise)),
]
)
* 200
signal_matrix = phi_to_signal_matrix(
phis[[record_idx]],
rx_config.buffer_size,
noise,
phi_drift * torch.pi * (1 if receiver_idx == 0 else -1),
)
noise_matrix = random_signal_matrix(

noise_matrix = torch_random_signal_matrix(
signal_matrix.reshape(-1).shape[0]
).reshape(signal_matrix.shape)
# add stripes
Expand All @@ -185,22 +205,22 @@ def phi_to_theta(phi, antenna_spacing_m, _lambda, limit=False):
data = {
"rx_pos_x_mm": 0,
"rx_pos_y_mm": 0,
"tx_pos_x_mm": np.sin(thetas[record_idx]) * radius,
"tx_pos_y_mm": np.cos(thetas[record_idx]) * radius,
"tx_pos_x_mm": torch.sin(thetas[record_idx]) * radius,
"tx_pos_y_mm": torch.cos(thetas[record_idx]) * radius,
"system_timestamp": 1.0 + record_idx * 5.0,
"rx_theta_in_pis": yaml_config["receivers"][receiver_idx][
"theta-in-pis"
],
"rx_spacing": rx_config.rx_spacing,
"rx_lo": rx_config.lo,
"rx_bandwidth": rx_config.rf_bandwidth,
"avg_phase_diff": get_avg_phase(signal_matrix),
"avg_phase_diff": torch_get_avg_phase_notrim(signal_matrix), # , 0.0),
"rssis": [0, 0],
"gains": [0, 0],
}

z = m[f"receivers/r{receiver_idx}"]
z.signal_matrix[record_idx] = signal_matrix
z.signal_matrix[record_idx] = signal_matrix.numpy()
for k in v5rx_f64_keys + v5rx_2xf64_keys:
z[k][record_idx] = data[k]
# nthetas = 64 + 1
Expand All @@ -214,3 +234,96 @@ def phi_to_theta(phi, antenna_spacing_m, _lambda, limit=False):
# steering_vectors=steering_vectors,
# signal_matrix=signal_matrix,
# )


def compare_and_copy_n(prefix, src, dst, n):
if isinstance(src, zarr.hierarchy.Group):
for key in src.keys():
compare_and_copy_n(prefix + "/" + key, src[key], dst[key], n)
else:
if prefix == "/config":
if src.shape != ():
dst[:] = src[:]
else:
for x in range(n):
dst[x] = src[x]


def partial_dataset(input_fn, output_fn, n):
input_fn.replace(".zarr", "")
z = zarr_open_from_lmdb_store(input_fn + ".zarr")
timesteps, _, buffer_size = z["receivers/r0/signal_matrix"].shape
input_yaml_fn = input_fn + ".yaml"
output_yaml_fn = output_fn + ".yaml"
yaml_config = yaml.safe_load(open(input_yaml_fn, "r"))
shutil.copyfile(input_yaml_fn, output_yaml_fn)
new_z = v5rx_new_dataset(
filename=output_fn + ".zarr",
timesteps=timesteps,
buffer_size=buffer_size,
n_receivers=len(yaml_config["receivers"]),
chunk_size=512,
compressor=None,
config=yaml_config,
remove_if_exists=False,
)
compare_and_copy_n("", z, new_z, n)
new_z.store.close()
new_z = None
zarr_shrink(output_fn)


def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--filename",
type=str,
required=False,
default="fake_dataset",
)
parser.add_argument(
"--orbits",
type=int,
required=False,
default="2",
)
parser.add_argument(
"--n",
type=int,
required=False,
default="1024",
)
parser.add_argument(
"--noise",
type=float,
required=False,
default="0.3",
)
parser.add_argument(
"--phi-drift",
type=float,
required=False,
default=0.0,
)
parser.add_argument(
"--seed",
type=int,
required=False,
default=0,
)
return parser


if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
create_fake_dataset(
fake_yaml,
args.filename,
orbits=args.orbits,
n=args.n,
noise=args.noise,
phi_drift=args.phi_drift,
radius=10000,
seed=args.seed,
)
41 changes: 21 additions & 20 deletions spf/dataset/spf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,27 +512,7 @@ def __init__(
).expand(1, self.snapshots_per_session)

if not self.temp_file:

self.get_segmentation()
# get mean phase segmentation
self.mean_phase = {}
for receiver, results in self.segmentation[
"segmentation_by_receiver"
].items():
self.mean_phase[receiver] = torch.tensor(
[
(
# TODO:UNBUG (circular mean)
torch.tensor(
[x["mean"] for x in result["simple_segmentation"]]
).mean()
if len(result["simple_segmentation"]) > 0
else 0.0
)
for result in results
],
dtype=torch.float32,
)

self.all_phi_drifts = self.get_all_phi_drifts()
self.phi_drifts = torch.tensor(
Expand Down Expand Up @@ -850,6 +830,24 @@ def get_ground_truth_thetas(self):
assert self.n_receivers == 2
return torch.vstack(ground_truth_thetas)

def get_mean_phase(self):
self.mean_phase = {}
for receiver, results in self.segmentation["segmentation_by_receiver"].items():
self.mean_phase[receiver] = torch.tensor(
[
(
# TODO:UNBUG (circular mean)
torch.tensor(
[x["mean"] for x in result["simple_segmentation"]]
).mean()
if len(result["simple_segmentation"]) > 0
else 0.0
)
for result in results
],
dtype=torch.float32,
)

def __getitem__(self, idx):
if self.paired:
assert idx < self.n_snapshots
Expand Down Expand Up @@ -903,8 +901,10 @@ def get_segmentation(self, precompute_to_idx=-1):
precompute_to_idx=precompute_to_idx,
gpu=self.gpu,
)

try:
segmentation = pickle.load(open(results_fn, "rb"))

precomputed_zarr = zarr_open_from_lmdb_store(
results_fn.replace(".pkl", ".yarr"), mode="r"
)
Expand All @@ -930,6 +930,7 @@ def get_segmentation(self, precompute_to_idx=-1):
os.remove(results_fn)
return self.get_segmentation(precompute_to_idx=precompute_to_idx)
self.segmentation = segmentation
self.get_mean_phase()
self.precomputed_zarr = precomputed_zarr
return self.segmentation

Expand Down
3 changes: 2 additions & 1 deletion spf/model_training_and_inference/models/beamsegnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
pi_norm,
reduce_theta_to_positive_y,
torch_circular_mean,
# torch_circular_mean_weighted,
torch_pi_norm,
)
from torch.nn.functional import sigmoid
Expand Down Expand Up @@ -905,7 +906,7 @@ def forward(
)
if self.circular_mean:
weighted_input[:, self.beamnet.pd_track] = torch_circular_mean(
x[:, self.beamnet.pd_track], weights=seg_mask[:, 0], trim=0
x[:, self.beamnet.pd_track], weights=seg_mask[:, 0], trim=0.0
)[0]

if self.training and self.drop_in_gt > 0.0:
Expand Down
12 changes: 6 additions & 6 deletions spf/model_training_and_inference/models/particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,12 @@ def observation(self, idx):
# self.mean_phase[f"r{receiver_idx}"][snapshot_start_idx:snapshot_end_idx]
return self.ds.mean_phase[f"r{self.rx_idx}"][idx]
# breakpoint()
# return (
# self.ds[idx][self.rx_idx]["mean_phase_segmentation"]
# .detach()
# .numpy()
# .reshape(-1)
# )
return (
self.ds[idx][self.rx_idx]["mean_phase_segmentation"]
.detach()
.numpy()
.reshape(-1)
)

def fix_particles(self):
self.particles = fix_particles_single(self.particles)
Expand Down
Loading

0 comments on commit 0a65533

Please sign in to comment.