Skip to content

Commit

Permalink
add latent dim as param
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Jul 15, 2024
1 parent 4f06ad6 commit 4c56c31
Showing 1 changed file with 42 additions and 4 deletions.
46 changes: 42 additions & 4 deletions spf/notebooks/simple_train_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
n_layers=24,
output_dim=1,
token_dropout=0.5,
latent=0,
):
super(FunkyNet, self).__init__()
self.z = torch.nn.Linear(10, 3)
Expand All @@ -127,7 +128,7 @@ def __init__(
)
self.input_net = torch.nn.Sequential(
torch.nn.Linear(
input_dim + 5 * 2 + 1, d_model
input_dim + (5 + latent) * 2 + 1, d_model
) # 5 output beam_former R1+R2, time
)
self.output_dim = output_dim
Expand All @@ -145,7 +146,7 @@ def __init__(
inputs=3, # + (1 if args.rx_spacing else 0),
norm="layer",
positional_encoding=False,
latent=0,
latent=latent,
max_angle=np.pi / 2,
linear_sigmas=True,
correction=True,
Expand Down Expand Up @@ -347,7 +348,14 @@ def simple_train(args):

# init model here
#######
m = FunkyNet().to(torch_device)
m = FunkyNet(
d_hid=args.tformer_dhid,
d_model=args.tformer_dmodel,
dropout=args.tformer_dropout,
token_dropout=args.tformer_snapshot_dropout,
n_layers=args.tformer_layers,
latent=args.beamnet_latent,
).to(torch_device)
# m = DebugFunkyNet().to(torch_device)
########

Expand Down Expand Up @@ -416,7 +424,7 @@ def new_log():
): # , total=len(train_dataloader)):
# if step > 200:
# return
if torch.rand(1).item() < 0.05:
if torch.rand(1).item() < 0.02:
gc.collect()
if step % args.save_every == 0:
m.eval()
Expand Down Expand Up @@ -721,6 +729,36 @@ def get_parser():
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--tformer-layers",
type=int,
default=24,
)
parser.add_argument(
"--tformer-dmodel",
type=int,
default=2048,
)
parser.add_argument(
"--tformer-dhid",
type=int,
default=512,
)
parser.add_argument(
"--tformer-dropout",
type=float,
default=0.1,
)
parser.add_argument(
"--tformer-snapshot-dropout",
type=float,
default=0.5,
)
parser.add_argument(
"--beamnet-latent",
type=int,
default=0,
)
parser.add_argument("--save-prefix", type=str, default="./this_model_")
return parser

Expand Down

0 comments on commit 4c56c31

Please sign in to comment.