Skip to content

Commit

Permalink
add direct theta estimate output for multipaired
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Sep 28, 2024
1 parent f6a6565 commit b782b40
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 28 deletions.
47 changes: 38 additions & 9 deletions spf/model_training_and_inference/models/single_point_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class PairedSinglePointWithBeamformer(nn.Module):
def __init__(self, model_config, global_config):
super().__init__()
self.single_radio_net = SinglePointWithBeamformer(model_config, global_config)

self.detach = model_config["detach"]
self.net = FFNN(
inputs=global_config["nthetas"] * 2,
depth=model_config["depth"], # 4
Expand All @@ -91,9 +91,15 @@ def __init__(self, model_config, global_config):

def forward(self, batch):
single_radio_estimates = self.single_radio_net(batch)["single"]

single_radio_estimates_input = detach_or_not(
single_radio_estimates, self.detach
)

x = self.net(
torch.concatenate(
[single_radio_estimates[::2], single_radio_estimates[1::2]], dim=2
[single_radio_estimates_input[::2], single_radio_estimates_input[1::2]],
dim=2,
)
)
idxs = torch.arange(x.shape[0]).reshape(-1, 1).repeat(1, 2).reshape(-1)
Expand All @@ -111,13 +117,20 @@ def forward(self, x):
return torch.nn.functional.normalize(x.abs(), p=1, dim=2)


def detach_or_not(x, detach):
if detach:
return x.detach()
return x


class PairedMultiPointWithBeamformer(nn.Module):
def __init__(self, model_config, global_config):
def __init__(self, model_config, global_config, ntheta=65):
super().__init__()
self.ntheta = ntheta
self.multi_radio_net = PairedSinglePointWithBeamformer(
model_config, global_config
)

self.detach = model_config["detach"]
self.transformer_config = model_config["transformer"]
self.d_model = self.transformer_config["d_model"]

Expand All @@ -135,13 +148,17 @@ def __init__(self, model_config, global_config):
LayerNorm(self.d_model),
)
self.input_net = torch.nn.Sequential(
torch.nn.Linear(65 + 1, self.d_model) # 5 output beam_former R1+R2, time
torch.nn.Linear(
self.ntheta + 1, self.d_model
) # 5 output beam_former R1+R2, time
)
# self.output_net = torch.nn.Sequential(
# torch.nn.Linear(self.d_model, 65), # 5 output beam_former R1+R2, time
# NormP1Dim2(),
# )
self.output_net = torch.nn.Linear(self.d_model, 65)
self.output_net = torch.nn.Linear(
self.d_model, self.ntheta + 1
) # output discrete and actual
self.norm = NormP1Dim2()

def forward(self, batch):
Expand All @@ -150,14 +167,26 @@ def forward(self, batch):
normalized_time = batch["system_timestamp"] - batch["system_timestamp"][:, [0]]
normalized_time /= normalized_time[:, [-1]]

output_paired = detach_or_not(output["paired"], self.detach)

input_with_time = torch.concatenate(
[output["paired"], normalized_time.unsqueeze(2)], axis=2
[
output_paired,
normalized_time.unsqueeze(2),
],
axis=2,
)
full_output = self.output_net(
self.transformer_encoder(self.input_net(input_with_time))
)
discrete_output = full_output[..., : self.ntheta]
direct_output = full_output[..., [-1]]

output["multipaired"] = self.norm(
output["paired"] # skip connection for paired
+ self.output_net(self.transformer_encoder(self.input_net(input_with_time)))
output_paired + discrete_output # skip connection for paired
)

output["multipaired_direct"] = direct_output
return output


Expand Down
68 changes: 49 additions & 19 deletions spf/scripts/train_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
SinglePointPassThrough,
SinglePointWithBeamformer,
)
from spf.rf import torch_pi_norm
from spf.utils import StatefulBatchsampler


Expand Down Expand Up @@ -241,6 +242,7 @@ def new_log():
"paired_loss": [],
"multipaired_loss": [],
"learning_rate": [],
"multipaired_direct_loss": [],
}


Expand Down Expand Up @@ -302,10 +304,10 @@ def compute_loss(

fig = None
if plot:
fig, axs = plt.subplots(3, 1, figsize=(6, 9))
fig, axs = plt.subplots(3, 2, figsize=(12, 9))
n = output["single"].shape[0] * output["single"].shape[1]
d = output["single"].shape[2]
show_n = min(n, 10)
show_n = min(n, 20)

if single and "single" in output:
target = scatter_fn(
Expand All @@ -319,11 +321,14 @@ def compute_loss(
loss += loss_d["single_loss"]

if plot:
im = np.zeros((show_n * 2, d))
im[1::2] = output["single"].reshape(n, -1).cpu().detach().numpy()[:show_n]
im[::2] = target.reshape(n, -1).cpu().detach().numpy()[:show_n]
axs[0].imshow(im)
axs[0].set_title("Single")
axs[0, 0].imshow(
output["single"].reshape(n, -1).cpu().detach().numpy()[:show_n]
)
axs[0, 1].imshow(target.reshape(n, -1).cpu().detach().numpy()[:show_n])
axs[0, 0].set_title("1radio x 1timestep (Pred)")
axs[0, 0].set_xticks([0, d // 2, d - 1], labels=["-pi", "0", "+pi"])
axs[0, 1].set_title("1radio x 1timestep (label)")
axs[0, 1].set_xticks([0, d // 2, d - 1], labels=["-pi", "0", "+pi"])

if paired and "paired" in output or "multipaired" in output:
paired_target = scatter_fn(
Expand All @@ -337,22 +342,47 @@ def compute_loss(
loss_d["paired_loss"] = loss_fn(output["paired"], paired_target)
loss += loss_d["paired_loss"]
if plot:
im = np.zeros((show_n * 2, d))
im[1::2] = output["paired"].reshape(n, -1).cpu().detach().numpy()[:show_n]
im[::2] = paired_target.reshape(n, -1).cpu().detach().numpy()[:show_n]
axs[1].imshow(im)
axs[1].set_title("Paired")
axs[1, 0].imshow(
output["paired"].reshape(n, -1).cpu().detach().numpy()[:show_n]
)
axs[1, 1].imshow(
paired_target.reshape(n, -1).cpu().detach().numpy()[:show_n]
)
axs[1, 0].set_title("2radio x 1timestep (pred)")
axs[1, 1].set_title("2radio x 1timestep (label)")
axs[1, 0].set_xticks([0, d // 2, d - 1], labels=["-pi", "0", "+pi"])
axs[1, 1].set_xticks([0, d // 2, d - 1], labels=["-pi", "0", "+pi"])

if paired and "multipaired" in output:
loss_d["multipaired_loss"] = loss_fn(output["multipaired"], paired_target)
loss += loss_d["multipaired_loss"]

loss_d["multipaired_direct_loss"] = (
(
torch_pi_norm(
output["multipaired_direct"] - batch_data["craft_y_rad"][..., None]
)
)
** 2
).mean()
loss += (
loss_d["multipaired_direct_loss"] / 30
) # TODO constant to keep losses balanced
if plot:
im = np.zeros((show_n * 2, d))
im[1::2] = (
axs[2, 0].imshow(
output["multipaired"].reshape(n, -1).cpu().detach().numpy()[:show_n]
)
im[::2] = paired_target.reshape(n, -1).cpu().detach().numpy()[:show_n]
axs[2].imshow(im)
axs[2].set_title("MultiPaired")
axs[2, 1].imshow(
paired_target.reshape(n, -1).cpu().detach().numpy()[:show_n]
)
axs[2, 0].set_title(
f"2radio x {output['multipaired'].shape[1]}timestep (pred)"
)
axs[2, 1].set_title(
f"2radio x {output['multipaired'].shape[1]}timestep (target)"
)
axs[2, 0].set_xticks([0, d // 2, d - 1], labels=["-pi", "0", "+pi"])
axs[2, 1].set_xticks([0, d // 2, d - 1], labels=["-pi", "0", "+pi"])
loss_d["loss"] = loss
return loss_d, fig

Expand Down Expand Up @@ -448,7 +478,7 @@ def train_single_point(args):
datasets_config=config["datasets"],
scatter_fn=scatter_fn,
single=True,
paired=epoch > args.head_start_epoch,
paired=epoch >= config["optim"]["head_start"],
)

# scheduler.step(loss_d["loss"])
Expand Down Expand Up @@ -510,7 +540,7 @@ def train_single_point(args):
plot=step == 0 or (step + 1) % config["logger"]["log_every"] == 0,
scatter_fn=scatter_fn,
single=True,
paired=epoch > args.head_start_epoch,
paired=epoch >= config["optim"]["head_start"],
)

if fig is not None:
Expand Down

0 comments on commit b782b40

Please sign in to comment.