diff --git a/spf/dataset/fake_dataset.py b/spf/dataset/fake_dataset.py index 66cb36d2..99cd47f1 100644 --- a/spf/dataset/fake_dataset.py +++ b/spf/dataset/fake_dataset.py @@ -1,11 +1,15 @@ +import argparse import os +import shutil import matplotlib.pyplot as plt import numpy as np +import torch import yaml import random from spf.data_collector import rx_config_from_receiver_yaml from spf.dataset.spf_dataset import pi_norm +import zarr # V5 data format from spf.dataset.v5_data import v5rx_2xf64_keys, v5rx_f64_keys, v5rx_new_dataset @@ -15,9 +19,36 @@ pi_norm, precompute_steering_vectors, speed_of_light, + torch_get_avg_phase, + torch_get_avg_phase_notrim, + torch_pi_norm_pi, ) from spf.sdrpluto.sdr_controller import rx_config_from_receiver_yaml -from spf.utils import random_signal_matrix +from spf.utils import ( + random_signal_matrix, + torch_random_signal_matrix, + zarr_open_from_lmdb_store, + zarr_shrink, +) + + +@torch.jit.script +def phi_to_signal_matrix( + phi: torch.Tensor, buffer_size: int, noise: float, phi_drift: float +): + big_phi = phi.repeat(buffer_size).reshape(1, -1) + big_phi_with_noise = big_phi + torch.randn((1, buffer_size)) * noise + offsets = torch.zeros(big_phi.shape, dtype=torch.complex128) + return ( + torch.vstack( + [ + torch.exp(1j * (offsets + phi_drift)), + torch.exp(1j * (offsets + big_phi_with_noise)), + ] + ) + * 200 + ) + """ theta is the angle from array normal to incident @@ -123,24 +154,24 @@ def create_fake_dataset( config=yaml_config, ) - thetas = pi_norm( - np.linspace(0, 2 * np.pi * orbits, yaml_config["n-records-per-receiver"]) + thetas = torch_pi_norm_pi( + torch.linspace(0, 2 * torch.pi * orbits, yaml_config["n-records-per-receiver"]) ) def theta_to_phi(theta, antenna_spacing_m, _lambda): - return np.sin(theta) * antenna_spacing_m * 2 * np.pi / _lambda + return torch.sin(theta) * antenna_spacing_m * 2 * torch.pi / _lambda def phi_to_theta(phi, antenna_spacing_m, _lambda, limit=False): - sin_arg = _lambda * phi / (antenna_spacing_m * 2 * np.pi) + sin_arg = _lambda * phi / (antenna_spacing_m * 2 * torch.pi) # assert sin_arg.min()>-1 # assert sin_arg.max()<1 if limit: edge = 1 - 1e-8 - sin_arg = np.clip(sin_arg, a_min=-edge, a_max=edge) - v = np.arcsin(_lambda * phi / (antenna_spacing_m * 2 * np.pi)) - return v, np.pi - v + sin_arg = torch.clip(sin_arg, min=-edge, max=edge) + v = torch.arcsin(_lambda * phi / (antenna_spacing_m * 2 * torch.pi)) + return v, torch.pi - v - rnd_noise = np.random.randn(thetas.shape[0]) + rnd_noise = torch.randn(thetas.shape[0]) # signal_matrix = np.vstack([np.exp(1j * phis), np.ones(phis.shape)]) @@ -149,29 +180,18 @@ def phi_to_theta(phi, antenna_spacing_m, _lambda, limit=False): thetas - yaml_config["receivers"][receiver_idx]["theta-in-pis"] * np.pi ) phis_nonoise = theta_to_phi(receiver_thetas, rx_config.rx_spacing, _lambda) - phis = pi_norm(phis_nonoise + rnd_noise * noise) - _thetas = phi_to_theta(phis, rx_config.rx_spacing, _lambda, limit=True) + phis = torch_pi_norm_pi(phis_nonoise + rnd_noise * noise) + # _thetas = phi_to_theta(phis, rx_config.rx_spacing, _lambda, limit=True) for record_idx in range(yaml_config["n-records-per-receiver"]): - big_phi = phis[[record_idx], None].repeat(rx_config.buffer_size, axis=1) - big_phi_with_noise = big_phi + np.random.randn(*big_phi.shape) * noise - offsets = np.random.uniform(-np.pi, np.pi, big_phi.shape) * 0 - signal_matrix = ( - np.vstack( - [ - np.exp( - 1j - * ( - offsets - + phi_drift * np.pi * (1 if receiver_idx == 0 else -1) - ) - ), - np.exp(1j * (offsets + big_phi_with_noise)), - ] - ) - * 200 + signal_matrix = phi_to_signal_matrix( + phis[[record_idx]], + rx_config.buffer_size, + noise, + phi_drift * torch.pi * (1 if receiver_idx == 0 else -1), ) - noise_matrix = random_signal_matrix( + + noise_matrix = torch_random_signal_matrix( signal_matrix.reshape(-1).shape[0] ).reshape(signal_matrix.shape) # add stripes @@ -185,8 +205,8 @@ def phi_to_theta(phi, antenna_spacing_m, _lambda, limit=False): data = { "rx_pos_x_mm": 0, "rx_pos_y_mm": 0, - "tx_pos_x_mm": np.sin(thetas[record_idx]) * radius, - "tx_pos_y_mm": np.cos(thetas[record_idx]) * radius, + "tx_pos_x_mm": torch.sin(thetas[record_idx]) * radius, + "tx_pos_y_mm": torch.cos(thetas[record_idx]) * radius, "system_timestamp": 1.0 + record_idx * 5.0, "rx_theta_in_pis": yaml_config["receivers"][receiver_idx][ "theta-in-pis" @@ -194,13 +214,13 @@ def phi_to_theta(phi, antenna_spacing_m, _lambda, limit=False): "rx_spacing": rx_config.rx_spacing, "rx_lo": rx_config.lo, "rx_bandwidth": rx_config.rf_bandwidth, - "avg_phase_diff": get_avg_phase(signal_matrix), + "avg_phase_diff": torch_get_avg_phase_notrim(signal_matrix), # , 0.0), "rssis": [0, 0], "gains": [0, 0], } z = m[f"receivers/r{receiver_idx}"] - z.signal_matrix[record_idx] = signal_matrix + z.signal_matrix[record_idx] = signal_matrix.numpy() for k in v5rx_f64_keys + v5rx_2xf64_keys: z[k][record_idx] = data[k] # nthetas = 64 + 1 @@ -214,3 +234,96 @@ def phi_to_theta(phi, antenna_spacing_m, _lambda, limit=False): # steering_vectors=steering_vectors, # signal_matrix=signal_matrix, # ) + + +def compare_and_copy_n(prefix, src, dst, n): + if isinstance(src, zarr.hierarchy.Group): + for key in src.keys(): + compare_and_copy_n(prefix + "/" + key, src[key], dst[key], n) + else: + if prefix == "/config": + if src.shape != (): + dst[:] = src[:] + else: + for x in range(n): + dst[x] = src[x] + + +def partial_dataset(input_fn, output_fn, n): + input_fn.replace(".zarr", "") + z = zarr_open_from_lmdb_store(input_fn + ".zarr") + timesteps, _, buffer_size = z["receivers/r0/signal_matrix"].shape + input_yaml_fn = input_fn + ".yaml" + output_yaml_fn = output_fn + ".yaml" + yaml_config = yaml.safe_load(open(input_yaml_fn, "r")) + shutil.copyfile(input_yaml_fn, output_yaml_fn) + new_z = v5rx_new_dataset( + filename=output_fn + ".zarr", + timesteps=timesteps, + buffer_size=buffer_size, + n_receivers=len(yaml_config["receivers"]), + chunk_size=512, + compressor=None, + config=yaml_config, + remove_if_exists=False, + ) + compare_and_copy_n("", z, new_z, n) + new_z.store.close() + new_z = None + zarr_shrink(output_fn) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--filename", + type=str, + required=False, + default="fake_dataset", + ) + parser.add_argument( + "--orbits", + type=int, + required=False, + default="2", + ) + parser.add_argument( + "--n", + type=int, + required=False, + default="1024", + ) + parser.add_argument( + "--noise", + type=float, + required=False, + default="0.3", + ) + parser.add_argument( + "--phi-drift", + type=float, + required=False, + default=0.0, + ) + parser.add_argument( + "--seed", + type=int, + required=False, + default=0, + ) + return parser + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + create_fake_dataset( + fake_yaml, + args.filename, + orbits=args.orbits, + n=args.n, + noise=args.noise, + phi_drift=args.phi_drift, + radius=10000, + seed=args.seed, + ) diff --git a/spf/dataset/spf_dataset.py b/spf/dataset/spf_dataset.py index 9d3a1525..7283ff74 100644 --- a/spf/dataset/spf_dataset.py +++ b/spf/dataset/spf_dataset.py @@ -512,27 +512,7 @@ def __init__( ).expand(1, self.snapshots_per_session) if not self.temp_file: - self.get_segmentation() - # get mean phase segmentation - self.mean_phase = {} - for receiver, results in self.segmentation[ - "segmentation_by_receiver" - ].items(): - self.mean_phase[receiver] = torch.tensor( - [ - ( - # TODO:UNBUG (circular mean) - torch.tensor( - [x["mean"] for x in result["simple_segmentation"]] - ).mean() - if len(result["simple_segmentation"]) > 0 - else 0.0 - ) - for result in results - ], - dtype=torch.float32, - ) self.all_phi_drifts = self.get_all_phi_drifts() self.phi_drifts = torch.tensor( @@ -850,6 +830,24 @@ def get_ground_truth_thetas(self): assert self.n_receivers == 2 return torch.vstack(ground_truth_thetas) + def get_mean_phase(self): + self.mean_phase = {} + for receiver, results in self.segmentation["segmentation_by_receiver"].items(): + self.mean_phase[receiver] = torch.tensor( + [ + ( + # TODO:UNBUG (circular mean) + torch.tensor( + [x["mean"] for x in result["simple_segmentation"]] + ).mean() + if len(result["simple_segmentation"]) > 0 + else 0.0 + ) + for result in results + ], + dtype=torch.float32, + ) + def __getitem__(self, idx): if self.paired: assert idx < self.n_snapshots @@ -903,8 +901,10 @@ def get_segmentation(self, precompute_to_idx=-1): precompute_to_idx=precompute_to_idx, gpu=self.gpu, ) + try: segmentation = pickle.load(open(results_fn, "rb")) + precomputed_zarr = zarr_open_from_lmdb_store( results_fn.replace(".pkl", ".yarr"), mode="r" ) @@ -930,6 +930,7 @@ def get_segmentation(self, precompute_to_idx=-1): os.remove(results_fn) return self.get_segmentation(precompute_to_idx=precompute_to_idx) self.segmentation = segmentation + self.get_mean_phase() self.precomputed_zarr = precomputed_zarr return self.segmentation diff --git a/spf/model_training_and_inference/models/beamsegnet.py b/spf/model_training_and_inference/models/beamsegnet.py index 83e3bc65..f669494a 100644 --- a/spf/model_training_and_inference/models/beamsegnet.py +++ b/spf/model_training_and_inference/models/beamsegnet.py @@ -9,6 +9,7 @@ pi_norm, reduce_theta_to_positive_y, torch_circular_mean, + # torch_circular_mean_weighted, torch_pi_norm, ) from torch.nn.functional import sigmoid @@ -905,7 +906,7 @@ def forward( ) if self.circular_mean: weighted_input[:, self.beamnet.pd_track] = torch_circular_mean( - x[:, self.beamnet.pd_track], weights=seg_mask[:, 0], trim=0 + x[:, self.beamnet.pd_track], weights=seg_mask[:, 0], trim=0.0 )[0] if self.training and self.drop_in_gt > 0.0: diff --git a/spf/model_training_and_inference/models/particle_filter.py b/spf/model_training_and_inference/models/particle_filter.py index 0bfb811a..592b9611 100644 --- a/spf/model_training_and_inference/models/particle_filter.py +++ b/spf/model_training_and_inference/models/particle_filter.py @@ -180,12 +180,12 @@ def observation(self, idx): # self.mean_phase[f"r{receiver_idx}"][snapshot_start_idx:snapshot_end_idx] return self.ds.mean_phase[f"r{self.rx_idx}"][idx] # breakpoint() - # return ( - # self.ds[idx][self.rx_idx]["mean_phase_segmentation"] - # .detach() - # .numpy() - # .reshape(-1) - # ) + return ( + self.ds[idx][self.rx_idx]["mean_phase_segmentation"] + .detach() + .numpy() + .reshape(-1) + ) def fix_particles(self): self.particles = fix_particles_single(self.particles) diff --git a/spf/rf.py b/spf/rf.py index 1c88f1dd..988d2fdc 100644 --- a/spf/rf.py +++ b/spf/rf.py @@ -140,13 +140,16 @@ def reduce_theta_to_positive_y(ground_truth_thetas): return reduced_thetas +# @njit def circular_diff_to_mean(angles, means): assert means.ndim == 1 a = np.abs(means[:, None] - angles) % (2 * np.pi) b = 2 * np.pi - a + # breakpoint() return np.min(np.vstack([a[None], b[None]]), axis=0) +# @njit def circular_mean(angles, trim, weights=None): assert angles.ndim == 2 _sin_angles = np.sin(angles) @@ -182,7 +185,18 @@ def torch_circular_diff_to_mean(angles: torch.Tensor, means: torch.Tensor): return m -def torch_circular_mean(angles, trim, weights=None): +@torch.jit.script +def torch_circular_mean_notrim(angles: torch.Tensor): + assert angles.ndim == 2 + _sin_angles = torch.sin(angles) + _cos_angles = torch.cos(angles) + cm = torch.arctan2(_sin_angles.sum(dim=1), _cos_angles.sum(dim=1)) % (2 * torch.pi) + + r = torch_pi_norm_pi(cm) + return r, r + + +def torch_circular_mean(angles: torch.Tensor, trim: float, weights=None): assert angles.ndim == 2 _sin_angles = torch.sin(angles) _cos_angles = torch.cos(angles) @@ -190,9 +204,7 @@ def torch_circular_mean(angles, trim, weights=None): _sin_angles = _sin_angles * weights _cos_angles = _cos_angles * weights - cm = torch.arctan2(_sin_angles.sum(axis=1), _cos_angles.sum(axis=1)) % ( - 2 * torch.pi - ) + cm = torch.arctan2(_sin_angles.sum(dim=1), _cos_angles.sum(dim=1)) % (2 * torch.pi) if trim == 0.0: r = torch_pi_norm_pi(cm) @@ -200,7 +212,7 @@ def torch_circular_mean(angles, trim, weights=None): dists = torch_circular_diff_to_mean(angles=angles, means=cm) - mask = dists <= torch.quantile(dists, (1.0 - trim / 100), axis=1, keepdims=True) + mask = dists <= torch.quantile(dists, (1.0 - trim / 100), dim=1, keepdim=True) _cm = torch.zeros(angles.shape[0]) for idx in range(angles.shape[0]): _cm[idx] = torch.arctan2( @@ -450,6 +462,32 @@ def torch_get_phase_diff(signal_matrix: torch.Tensor): return torch_pi_norm_pi(signal_matrix[:, 0].angle() - signal_matrix[:, 1].angle()) +# @njit +def get_avg_phase(signal_matrix, trim=0.0): + return np.array( + circular_mean(get_phase_diff(signal_matrix=signal_matrix)[None], trim=trim) + ).reshape(-1) + + +@torch.jit.script +def torch_get_avg_phase_notrim(signal_matrix: torch.Tensor): + return torch.hstack( + torch_circular_mean_notrim( + torch_get_phase_diff(signal_matrix=signal_matrix)[None], + ) + ) + + +# @torch.jit.script +def torch_get_avg_phase(signal_matrix: torch.Tensor, trim: float): + return torch.tensor( + torch_circular_mean( + torch_get_phase_diff(signal_matrix=signal_matrix)[None], trim + ) + ) + + +# @njit def get_avg_phase(signal_matrix, trim=0.0): return np.array( circular_mean(get_phase_diff(signal_matrix=signal_matrix)[None], trim=trim) diff --git a/spf/utils.py b/spf/utils.py index 1aa63d20..abe357ae 100644 --- a/spf/utils.py +++ b/spf/utils.py @@ -3,6 +3,7 @@ from datetime import datetime import numpy as np +import torch import yaml import zarr from numcodecs import Blosc @@ -36,6 +37,11 @@ def random_signal_matrix(n, rng=None): return np.random.uniform(-1, 1, (n,)) + 1.0j * np.random.uniform(-1, 1, (n,)) +@torch.jit.script +def torch_random_signal_matrix(n: int): + return (torch.rand((n,)) - 0.5) * 2 + 1.0j * (torch.rand((n,)) - 0.5) + + def zarr_remove_if_exists(zarr_fn): for fn in ["data.mdb", "lock.mdb"]: if os.path.exists(zarr_fn + "/" + fn): diff --git a/tests/test_particle_filter.py b/tests/test_particle_filter.py index 13be4492..31fc822e 100644 --- a/tests/test_particle_filter.py +++ b/tests/test_particle_filter.py @@ -1,7 +1,9 @@ import tempfile -from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml -from spf.dataset.spf_dataset import v5spfdataset +import torch +from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml, partial_dataset +from spf.dataset.spf_dataset import v5spfdataset +import random from spf.model_training_and_inference.models.create_empirical_p_dist import ( apply_symmetry_rules_to_heatmap, get_heatmap, @@ -121,3 +123,40 @@ def test_single_theta_dual_radio(noise1_n128_obits2, heatmap): result = run_xy_dual_radio(**args) assert result[0]["metrics"]["mse_theta"] < 0.25 plot_xy_dual_radio(ds, heatmap) + + +def test_partial(noise1_n128_obits2): + dirname, ds_fn = noise1_n128_obits2 + ds_og = v5spfdataset( + ds_fn, + precompute_cache=dirname, + nthetas=65, + skip_signal_matrix=True, + paired=True, + ignore_qc=True, + gpu=False, + ) + with tempfile.TemporaryDirectory() as tmpdirname: + ds_fn_out = f"{tmpdirname}/partial" + for partial_n in [10, 100, 128]: + partial_dataset(ds_fn, ds_fn_out, partial_n) + ds = v5spfdataset( + ds_fn_out, + precompute_cache=tmpdirname, + nthetas=65, + skip_signal_matrix=True, + paired=True, + ignore_qc=True, + gpu=False, + temp_file=True, + temp_file_suffix="", + ) + assert min(ds.valid_entries) == partial_n + random.seed(0) + idxs = list(range(partial_n)) + random.shuffle(idxs) + for idx in idxs[:8]: + for r_idx in range(2): + for key in ds_og[0][0].keys(): + if isinstance(ds_og[idx][r_idx][key], torch.Tensor): + assert (ds_og[idx][r_idx][key] == ds[idx][r_idx][key]).all()