Skip to content

Commit

Permalink
add craft y rad to single point
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Sep 16, 2024
1 parent a3dc291 commit b40819b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 3 deletions.
3 changes: 3 additions & 0 deletions spf/dataset/spf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,9 @@ def render_session(self, receiver_idx, session_idx):
data["y_rad_binned"] = (
to_bin(data["y_rad"], self.nthetas).unsqueeze(0).to(torch.long)
)
data["craft_y_rad_binned"] = (
to_bin(data["craft_y_rad"], self.nthetas).unsqueeze(0).to(torch.long)
)

# convert to target dtype on CPU!
for key in data:
Expand Down
51 changes: 49 additions & 2 deletions spf/scripts/train_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def params_for_ds(ds, batch_size, resume_step):
"system_timestamp",
"empirical",
"y_rad_binned",
"craft_y_rad_binned",
]
if global_config["beamformer_input"]:
# keys_to_get += ["windowed_beamformer"]
Expand Down Expand Up @@ -154,24 +155,59 @@ def __init__(self, model_config, global_config):
act=nn.LeakyReLU,
bn=model_config["bn"], # False , bool
)
self.paired = False

def forward(self, batch):

# x = self.net(
# torch.concatenate([batch["empirical"], weighted_beamformer(batch)], dim=2)
# )
# breakpoint()
x = self.net(
torch.concatenate(
[batch["empirical"], batch["weighted_beamformer"] / 256], dim=2
)
)
# first dim odd / even is the radios
return torch.nn.functional.softmax(x, dim=2)


class PairedSinglePointWithBeamformer(nn.Module):
def __init__(self, model_config, global_config):
super().__init__()
self.single_radio_net = SinglePointWithBeamformer(model_config, global_config)
self.net = FFNN(
inputs=global_config["nthetas"] * 2,
depth=model_config["depth"], # 4
hidden=model_config["hidden"], # 128
outputs=global_config["nthetas"],
block=model_config["block"], # True
norm=model_config["norm"], # False, [batch,layer]
act=nn.LeakyReLU,
bn=model_config["bn"], # False , bool
)
self.paired = True

def forward(self, batch):
single_radio_estimates = self.single_radio_net(batch)
x = self.net(
torch.concatenate(
[single_radio_estimates[::2], single_radio_estimates[1::2]], dim=2
)
)
idxs = torch.arange(x.shape[0]).reshape(-1, 1).repeat(1, 2).reshape(-1)
# first dim odd / even is the radios
return (
torch.nn.functional.softmax(x[idxs], dim=2) * 0.5
+ single_radio_estimates * 0.5
)


class SinglePointPassThrough(nn.Module):
def __init__(self, model_config, global_config):
super().__init__()
self.w = torch.nn.Parameter(torch.zeros(1))
self.paired = False

def forward(self, batch):
# weighted_beamformer(batch)
Expand All @@ -191,6 +227,9 @@ def load_model(model_config, global_config):
return SinglePointPassThrough(model_config, global_config)
elif model_config["name"] == "beamformer":
return SinglePointWithBeamformer(model_config, global_config)
elif model_config["name"] == "pairedbeamformer":
return PairedSinglePointWithBeamformer(model_config, global_config)

raise ValueError


Expand All @@ -201,8 +240,12 @@ def uniform_preds(batch):
)


def target_from_scatter(batch, sigma=0.0, k=3):
y_rad_binned = batch["y_rad_binned"][..., None]
def target_from_scatter(batch, paired, sigma, k):
if paired:
y_rad_binned = batch["craft_y_rad_binned"][..., None]
else:
y_rad_binned = batch["y_rad_binned"][..., None]

scattered_targets = torch.zeros(
*batch["empirical"].shape, device=batch["empirical"].device
).scatter(
Expand Down Expand Up @@ -332,6 +375,7 @@ def train_single_point(args):
output = m(val_batch_data)
target = target_from_scatter(
val_batch_data,
paired=m.paired,
sigma=config["datasets"]["sigma"],
k=config["datasets"]["scatter_k"],
)
Expand All @@ -347,6 +391,7 @@ def train_single_point(args):
),
}
# breakpoint()
# breakpoint()
# for accumulaing and averaging
for key, value in loss_d.items():
val_losses[key].append(value.item())
Expand All @@ -369,6 +414,7 @@ def train_single_point(args):

target = target_from_scatter(
batch_data,
paired=m.paired,
sigma=config["datasets"]["sigma"],
k=config["datasets"]["scatter_k"],
)
Expand All @@ -377,6 +423,7 @@ def train_single_point(args):
loss_d = {
"loss": discrete_loss(output, target),
}
# print(loss_d)

# loss = loss_d["beamnet_loss"]
# loss.backward()
Expand Down
3 changes: 2 additions & 1 deletion spf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,5 +410,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.dropout(x)


def to_bin(x, bins):
@torch.jit.script
def to_bin(x: torch.Tensor, bins: int):
return ((x / (2 * torch.pi) + 0.5) * bins).to(torch.long)

0 comments on commit b40819b

Please sign in to comment.