Skip to content

Commit

Permalink
add tests for particle filter
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Jul 28, 2024
1 parent 0eb4988 commit e6e3a3a
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 55 deletions.
5 changes: 4 additions & 1 deletion spf/dataset/fake_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import matplotlib.pyplot as plt
import numpy as np
import yaml

import random
from spf.data_collector import rx_config_from_receiver_yaml
from spf.dataset.spf_dataset import pi_norm

Expand Down Expand Up @@ -94,7 +94,10 @@ def create_fake_dataset(
noise=0.01,
phi_drift=0.0,
radius=10000,
seed=0,
):
random.seed(seed)
np.random.seed(seed)
yaml_fn = f"{filename}.yaml"
zarr_fn = f"{filename}.zar"
seg_fn = f"{filename}_segmentation.pkl"
Expand Down
48 changes: 28 additions & 20 deletions spf/dataset/spf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def __init__(
[
(
# TODO:UNBUG (circular mean)
np.array(
torch.tensor(
[x["mean"] for x in result["simple_segmentation"]]
).mean()
if len(result["simple_segmentation"]) > 0
Expand All @@ -462,31 +462,39 @@ def __init__(
)

self.all_phi_drifts = self.get_all_phi_drifts()
self.phi_drifts = np.array(
[np.nanmean(all_phi_drift) for all_phi_drift in self.all_phi_drifts]
self.phi_drifts = torch.tensor(
[torch.nanmean(all_phi_drift) for all_phi_drift in self.all_phi_drifts]
)
self.average_windows_in_segmentation = np.array(
[
self.average_windows_in_segmentation = (
torch.tensor(
[
len(x["simple_segmentation"])
for x in self.segmentation["segmentation_by_receiver"][
f"r{rx_idx}"
[
len(x["simple_segmentation"])
for x in self.segmentation["segmentation_by_receiver"][
f"r{rx_idx}"
]
]
for rx_idx in range(self.n_receivers)
]
for rx_idx in range(self.n_receivers)
]
).mean()
self.mean_sessions_with_maybe_valid_segmentation = np.array(
[
)
.to(torch.float32)
.mean()
)
self.mean_sessions_with_maybe_valid_segmentation = (
torch.tensor(
[
len(x["simple_segmentation"]) > 2
for x in self.segmentation["segmentation_by_receiver"][
f"r{rx_idx}"
[
len(x["simple_segmentation"]) > 2
for x in self.segmentation["segmentation_by_receiver"][
f"r{rx_idx}"
]
]
for rx_idx in [0, 1]
]
for rx_idx in [0, 1]
]
).mean()
)
.to(torch.float32)
.mean()
)

if not ignore_qc:
assert not temp_file
Expand Down Expand Up @@ -774,7 +782,7 @@ def get_ground_truth_thetas(self):
rx_to_tx_theta = torch.arctan2(d[0], d[1])
rx_theta_in_pis = self.cached_keys[ridx]["rx_theta_in_pis"]
ground_truth_thetas.append(
torch_pi_norm(rx_to_tx_theta - rx_theta_in_pis[:] * np.pi)
torch_pi_norm(rx_to_tx_theta - rx_theta_in_pis[:] * torch.pi)
)
# reduce GT thetas in case of two antennas
# in 2D there are generally two spots that satisfy phase diff
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def get_heatmap(dss, bins=50):
return heatmaps[0].copy() + heatmaps[1].copy()


def apply_symmetry_rules_to_heatmap(h):
half = h[:25] + np.flip(h[25:])
def apply_symmetry_rules_to_heatmap(h, bins=50):
half = h[: bins // 2] + np.flip(h[bins // 2 :])
half = half + np.flip(half, axis=0)
full = np.vstack([half, np.flip(half)])
return full / full.sum(axis=1, keepdims=True)
Expand Down
58 changes: 26 additions & 32 deletions spf/model_training_and_inference/models/particle_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,6 @@ def __init__(self, ds, full_p_fn):
# np.vstack([ds[0][0]["tx_pos_x_mm"], ds[0][0]["tx_pos_y_mm"]]).T
# )

def tx_state(self, idx):
return torch.concatenate(
[
self.ds[idx][0]["tx_pos_x_mm"].reshape(1),
self.ds[idx][0]["tx_pos_y_mm"].reshape(1),
],
axis=0,
)

def our_state(self, idx):
return torch.vstack(
[
Expand Down Expand Up @@ -431,12 +422,12 @@ def plot_single_theta_single_radio(ds, full_p_fn):
)
ax[0, rx_idx].plot(ds.ground_truth_phis[rx_idx][:n], label="perfect phi")
ax[1, rx_idx].plot(
ds[0][rx_idx]["ground_truth_theta"],
ds.ground_truth_thetas[rx_idx].reshape(-1),
label=f"r{rx_idx} gt theta",
)

xs = torch.vstack([x["mu"][0] for x in trajectory])
stds = torch.sqrt(torch.vstack([x["var"][0] for x in trajectory]))
xs = torch.hstack([x["mu"][0] for x in trajectory])
stds = torch.sqrt(torch.hstack([x["var"][0] for x in trajectory]))

ax[1, rx_idx].fill_between(
torch.arange(xs.shape[0]),
Expand All @@ -446,12 +437,13 @@ def plot_single_theta_single_radio(ds, full_p_fn):
color="red",
alpha=0.2,
)

ax[1, rx_idx].scatter(
range(ds.snapshots_per_session), xs, label="PF-x", color="orange", s=0.5
range(xs.shape[0]), xs, label="PF-x", color="orange", s=0.5
)

ax[1, rx_idx].plot(
reduce_theta_to_positive_y(ds[0][rx_idx]["ground_truth_theta"]),
reduce_theta_to_positive_y(ds.ground_truth_thetas[rx_idx]),
label=f"r{rx_idx} reduced gt theta",
color="black",
linestyle="dashed",
Expand Down Expand Up @@ -499,13 +491,14 @@ def plot_single_theta_dual_radio(ds, full_p_fn):
)

ax[1].plot(
torch_pi_norm_pi(ds[0][0]["craft_y_rad"][0]),
# torch_pi_norm_pi(ds[0][0]["craft_y_rad"][0]),
torch_pi_norm_pi(ds.craft_ground_truth_thetas),
label="craft gt theta",
linestyle="dashed",
)

xs = torch.vstack([x["mu"][0] for x in traj_paired])
stds = torch.sqrt(torch.vstack([x["var"][0] for x in traj_paired]))
xs = torch.hstack([x["mu"][0] for x in traj_paired])
stds = torch.sqrt(torch.hstack([x["var"][0] for x in traj_paired]))

ax[1].fill_between(
torch.arange(xs.shape[0]),
Expand Down Expand Up @@ -559,10 +552,11 @@ def plot_xy_dual_radio(ds, full_p_fn):
linestyle="dashed",
)

ax[1].plot(torch_pi_norm_pi(ds[0][0]["craft_y_rad"][0]))
# ax[1].plot(torch_pi_norm_pi(ds[0][0]["craft_y_rad"][0]))
ax[1].plot(torch_pi_norm_pi(ds.craft_ground_truth_thetas))

xs = torch.vstack([x["mu"][0] for x in traj_paired])
stds = torch.sqrt(torch.vstack([x["var"][0] for x in traj_paired]))
xs = torch.hstack([x["mu"][0] for x in traj_paired])
stds = torch.sqrt(torch.hstack([x["var"][0] for x in traj_paired]))

ax[1].fill_between(
torch.arange(xs.shape[0]),
Expand All @@ -572,20 +566,20 @@ def plot_xy_dual_radio(ds, full_p_fn):
color="red",
alpha=0.2,
)
ax[1].scatter(
range(ds.snapshots_per_session), xs, label="PF-x", color="orange", s=0.5
)
ax[1].scatter(range(xs.shape[0]), xs, label="PF-x", color="orange", s=0.5)

xys = torch.vstack([x["mu"][[1, 2]] for x in traj_paired])
ax[2].scatter(
range(ds.snapshots_per_session), xys[:, 0], label="PF-x", color="orange", s=0.5
)
ax[2].scatter(
range(ds.snapshots_per_session), xys[:, 1], label="PF-y", color="blue", s=0.5
gt_xy = torch.vstack(
[
ds.cached_keys[0]["tx_pos_x_mm"],
ds.cached_keys[0]["tx_pos_y_mm"],
]
)
tx = np.vstack([ds[0][0]["tx_pos_x_mm"], ds[0][0]["tx_pos_y_mm"]])
ax[2].plot(tx[0, :], label="gt-x", color="red")
ax[2].plot(tx[1, :], label="gt-y", color="black")
xys = torch.vstack([x["mu"][[1, 2]] for x in traj_paired])
ax[2].scatter(range(xys.shape[0]), xys[:, 0], label="PF-x", color="orange", s=0.5)
ax[2].scatter(range(xys.shape[0]), xys[:, 1], label="PF-y", color="blue", s=0.5)
# tx = np.vstack([ds[0][0]["tx_pos_x_mm"], ds[0][0]["tx_pos_y_mm"]])
ax[2].plot(gt_xy[0, :], label="gt-x", color="red")
ax[2].plot(gt_xy[1, :], label="gt-y", color="black")

ax[0].set_ylabel("radio phi")

Expand Down
123 changes: 123 additions & 0 deletions tests/test_particle_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import tempfile
from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml
from spf.dataset.spf_dataset import v5spfdataset

from spf.model_training_and_inference.models.create_empirical_p_dist import (
apply_symmetry_rules_to_heatmap,
get_heatmap,
)
import pytest
import pickle

from spf.model_training_and_inference.models.particle_filter import (
plot_single_theta_dual_radio,
plot_single_theta_single_radio,
plot_xy_dual_radio,
run_single_theta_dual_radio,
run_single_theta_single_radio,
run_xy_dual_radio,
)


@pytest.fixture
def noise1_n128_obits2():
with tempfile.TemporaryDirectory() as tmpdirname:
n = 128
fn = tmpdirname + f"/perfect_circle_n{n}_noise0"
create_fake_dataset(
filename=fn, yaml_config_str=fake_yaml, n=n, noise=0.3, orbits=2
)
yield tmpdirname, fn


@pytest.fixture
def heatmap(noise1_n128_obits2):
dirname, ds_fn = noise1_n128_obits2
ds = v5spfdataset(
ds_fn,
precompute_cache=dirname,
nthetas=65,
skip_signal_matrix=True,
paired=True,
ignore_qc=True,
gpu=False,
)
heatmap = get_heatmap([ds], bins=50)
heatmap = apply_symmetry_rules_to_heatmap(heatmap)
full_p_fn = f"{dirname}/full_p.pkl"
pickle.dump({"full_p": heatmap}, open(full_p_fn, "wb"))
return full_p_fn


def test_single_theta_single_radio(noise1_n128_obits2, heatmap):
dirname, ds_fn = noise1_n128_obits2
ds = v5spfdataset(
ds_fn,
precompute_cache=dirname,
nthetas=65,
skip_signal_matrix=True,
paired=True,
ignore_qc=True,
gpu=False,
)
args = {
"ds_fn": ds_fn,
"precompute_fn": dirname,
"full_p_fn": heatmap,
"N": 1024 * 4,
"theta_err": 0.01,
"theta_dot_err": 0.01,
}
results = run_single_theta_single_radio(**args)
for result in results:
assert result["metrics"]["mse_theta"] < 0.05
plot_single_theta_single_radio(ds, heatmap)


def test_single_theta_dual_radio(noise1_n128_obits2, heatmap):
dirname, ds_fn = noise1_n128_obits2
ds = v5spfdataset(
ds_fn,
precompute_cache=dirname,
nthetas=65,
skip_signal_matrix=True,
paired=True,
ignore_qc=True,
gpu=False,
)
args = {
"ds_fn": ds_fn,
"precompute_fn": dirname,
"full_p_fn": heatmap,
"N": 1024 * 4,
"theta_err": 0.01,
"theta_dot_err": 0.01,
}
result = run_single_theta_dual_radio(**args)
assert result[0]["metrics"]["mse_theta"] < 0.15
plot_single_theta_dual_radio(ds, heatmap)


def test_single_theta_dual_radio(noise1_n128_obits2, heatmap):
dirname, ds_fn = noise1_n128_obits2
ds = v5spfdataset(
ds_fn,
precompute_cache=dirname,
nthetas=65,
skip_signal_matrix=True,
paired=True,
ignore_qc=True,
gpu=False,
)
args = {
"ds_fn": ds_fn,
"precompute_fn": dirname,
"full_p_fn": heatmap,
"N": 1024 * 4,
"pos_err": 50,
"vel_err": 0.1,
}

result = run_xy_dual_radio(**args)
assert result[0]["metrics"]["mse_theta"] < 0.25
plot_xy_dual_radio(ds, heatmap)

0 comments on commit e6e3a3a

Please sign in to comment.