add save load functionality to filter models
misko committed Jul 17, 2024
1 parent cd65a17 commit 54ba338
Showing 3 changed files with 239 additions and 82 deletions.
274 changes: 194 additions & 80 deletions spf/notebooks/simple_train_filter.py
@@ -29,6 +29,80 @@
LayerNorm,
)

from torch.utils.data import DistributedSampler, Sampler, BatchSampler


# from fair-chem repo
class StatefulDistributedSampler(DistributedSampler):
"""
    A DistributedSampler with more fine-grained state: it uses both the epoch
    and the training iteration when shuffling data. PyTorch's DistributedSampler
    only uses the epoch for shuffling and always starts sampling from the
    beginning. When training on very large data we train for one epoch only, so
    on resume we want the sampler to continue from the saved training iteration.
"""

def __init__(self, dataset, batch_size, **kwargs):
"""
        Initializes the StatefulDistributedSampler. The random seed is derived
        from the configured epoch and the data is shuffled accordingly. Sampling
        starts at start_iter (0 by default, or restored when resuming from a
        checkpoint), so only the remaining samples of the epoch are drawn.
        Args:
            dataset (Dataset): PyTorch dataset that the sampler will shuffle
            batch_size (int): batch size the sampler should assume when skipping ahead
            seed (int): seed for the torch generator
"""
super().__init__(dataset=dataset, **kwargs)

self.start_iter = 0
self.batch_size = batch_size
assert self.batch_size > 0, "batch_size not set for the sampler"
# logging.info(f"rank: {self.rank}: Sampler created...")

def __iter__(self):
        # TODO: For very large datasets (even virtual ones) this might be slow
        # or not work correctly. The issue is that we enumerate the full list of
        # samples for a single epoch and manipulate that list directly. A better
        # approach would be to keep this sequence strictly as an iterator that
        # stores only the current state (instead of the full sequence).
distributed_sampler_sequence = super().__iter__()
if self.start_iter > 0:
for i, _ in enumerate(distributed_sampler_sequence):
if i == self.start_iter * self.batch_size - 1:
break
return distributed_sampler_sequence

def set_epoch_and_start_iteration(self, epoch, start_iter):
self.set_epoch(epoch)
self.start_iter = start_iter


class StatefulBatchsampler(Sampler):
def __init__(self, dataset, batch_size, seed=0, shuffle=False):
self.single_sampler = StatefulDistributedSampler(
dataset,
num_replicas=1,
rank=0,
shuffle=shuffle,
drop_last=False,
batch_size=batch_size,
seed=seed,
)
self.batch_sampler = BatchSampler(
self.single_sampler,
batch_size,
drop_last=False,
)

def set_epoch_and_start_iteration(self, epoch: int, start_iteration: int) -> None:
self.single_sampler.set_epoch_and_start_iteration(epoch, start_iteration)

def __iter__(self):
yield from self.batch_sampler
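

# Illustrative sketch (an assumption, not part of the training flow below): how
# the stateful sampler resumes mid-epoch. torch is assumed to be imported at the
# top of this file (it is used throughout); the toy TensorDataset and the helper
# name are hypothetical.
def _stateful_sampler_resume_sketch():
    from torch.utils.data import TensorDataset

    toy_ds = TensorDataset(torch.arange(10))
    sampler = StatefulBatchsampler(toy_ds, batch_size=2, seed=0, shuffle=False)

    # Fresh epoch: all five batches of indices are produced.
    sampler.set_epoch_and_start_iteration(epoch=0, start_iteration=0)
    assert list(sampler) == [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

    # Resuming at iteration 2 skips the two batches consumed before the
    # checkpoint and continues with the remainder of the same epoch.
    sampler.set_epoch_and_start_iteration(epoch=0, start_iteration=2)
    assert list(sampler) == [[4, 5], [6, 7], [8, 9]]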


# torch.autograd.set_detect_anomaly(True)


@@ -41,10 +115,11 @@ def worker_init_fn(worker_id):
_dataset.reinit()


def save_everything(model, optimizer, config, step, path):
def save_everything(model, optimizer, config, step, epoch, path):
torch.save(
{
"step": step,
"epoch": epoch,
"config": config,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
@@ -248,11 +323,13 @@ def loss(self, output, y_rad, craft_y_rad, seg_mask):

beamnet_mse = self.beam_m.mse(output["pred_theta"], y_rad_reduced)
beamnet_mse_random = (
(
torch_pi_norm(
y_rad_reduced
- (torch.rand(y_rad_reduced.shape, device=target.device) - 0.5)
* 2
* np.pi
/ 2,
max_angle=torch.pi / 2,
)
** 2
).mean()
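        # A plausible reading of torch_pi_norm (an assumption; its implementation
        # lives elsewhere in spf): fold an angle error into [-max_angle, max_angle)
        # before squaring, e.g.
        #     wrapped = ((err + max_angle) % (2 * max_angle)) - max_angle
        # so an error of 1.7 rad with max_angle=pi/2 wraps to about -1.44 rad.
        # Under that reading, beamnet_mse_random above is the MSE against angles
        # drawn uniformly from [-pi/2, pi/2), i.e. a prediction-free baseline.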
@@ -267,7 +344,7 @@ def loss(self, output, y_rad, craft_y_rad, seg_mask):
}


def simple_train(args):
def simple_train_filter(args):
assert args.n_radios == 2
# torch.autograd.detect_anomaly()
# "/Volumes/SPFData/missions/april5/wallarrayv3_2024_05_06_19_04_15_nRX2_bounce",
@@ -278,64 +355,6 @@ def simple_train(args):
random.seed(args.seed)

assert args.n_radios in [1, 2]
# loop over and concat datasets here
datasets = [
v5spfdataset(
prefix,
precompute_cache=args.precompute_cache,
nthetas=args.nthetas,
skip_signal_matrix=True,
paired=args.n_radios > 1,
ignore_qc=args.skip_qc,
gpu=args.device == "cuda",
snapshots_per_session=args.snapshots_per_session,
readahead=False,
skip_simple_segmentations=True,
)
for prefix in args.datasets
]
for ds in datasets:
ds.get_segmentation()
complete_ds = torch.utils.data.ConcatDataset(datasets)

if args.val_on_train:
train_ds = complete_ds
val_ds = complete_ds
else:
n = len(complete_ds)
train_idxs = range(int((1.0 - args.val_holdout_fraction) * n))
val_idxs = list(range(train_idxs[-1] + 1, n))

shuffle(val_idxs)
val_idxs = val_idxs[: max(1, int(len(val_idxs) * args.val_subsample_fraction))]

train_ds = torch.utils.data.Subset(complete_ds, train_idxs)
val_ds = torch.utils.data.Subset(complete_ds, val_idxs)
print(f"Train-dataset size {len(train_ds)}, Val dataset size {len(val_ds)}")

dataloader_params = {
"batch_size": args.batch,
"shuffle": args.shuffle,
"num_workers": args.workers,
"collate_fn": partial(
v5_collate_keys_fast,
[
"all_windows_stats",
"rx_pos_xy",
"downsampled_segmentation_mask",
"rx_spacing",
"windowed_beamformer",
"y_rad",
"craft_y_rad",
"y_phi",
],
),
"worker_init_fn": worker_init_fn,
# "pin_memory": True,
"prefetch_factor": 2 if args.workers > 0 else None,
}
train_dataloader = torch.utils.data.DataLoader(train_ds, **dataloader_params)
val_dataloader = torch.utils.data.DataLoader(val_ds, **dataloader_params)

scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

@@ -350,15 +369,17 @@ def simple_train(args):

# init model here
#######
m = FunkyNet(
d_hid=args.tformer_dhid,
d_model=args.tformer_dmodel,
dropout=args.tformer_dropout,
token_dropout=args.tformer_snapshot_dropout,
n_layers=args.tformer_layers,
latent=args.beamnet_latent,
).to(torch_device)
# m = DebugFunkyNet().to(torch_device)
if args.debug_model:
m = DebugFunkyNet().to(torch_device)
else:
m = FunkyNet(
d_hid=args.tformer_dhid,
d_model=args.tformer_dmodel,
dropout=args.tformer_dropout,
token_dropout=args.tformer_snapshot_dropout,
n_layers=args.tformer_layers,
latent=args.beamnet_latent,
).to(torch_device)
########

if args.wandb_project:
@@ -420,13 +441,94 @@ def new_log():
}

to_log = new_log()
for _ in range(args.epochs):
step = 0
epoch = 0

if args.load_checkpoint is not None:
checkpoint = torch.load(args.load_checkpoint)
step = checkpoint["step"]
epoch = checkpoint["epoch"]
m.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
config = checkpoint["config"]

# loop over and concat datasets here
datasets = [
v5spfdataset(
prefix,
precompute_cache=args.precompute_cache,
nthetas=args.nthetas,
skip_signal_matrix=True,
paired=args.n_radios > 1,
ignore_qc=args.skip_qc,
gpu=args.device == "cuda",
snapshots_per_session=args.snapshots_per_session,
readahead=False,
skip_simple_segmentations=True,
)
for prefix in args.datasets
]
for ds in datasets:
ds.get_segmentation()
complete_ds = torch.utils.data.ConcatDataset(datasets)

if args.val_on_train:
train_ds = complete_ds
val_ds = complete_ds
else:
n = len(complete_ds)
train_idxs = range(int((1.0 - args.val_holdout_fraction) * n))
val_idxs = list(range(train_idxs[-1] + 1, n))

shuffle(val_idxs)
val_idxs = val_idxs[: max(1, int(len(val_idxs) * args.val_subsample_fraction))]

train_ds = torch.utils.data.Subset(complete_ds, train_idxs)
val_ds = torch.utils.data.Subset(complete_ds, val_idxs)
print(f"Train-dataset size {len(train_ds)}, Val dataset size {len(val_ds)}")

def params_for_ds(ds):
sampler = StatefulBatchsampler(
ds, shuffle=args.shuffle, seed=args.seed, batch_size=args.batch
)
sampler.set_epoch_and_start_iteration(epoch=epoch, start_iteration=step)
return {
# "batch_size": args.batch,
"num_workers": args.workers,
"collate_fn": partial(
v5_collate_keys_fast,
[
"all_windows_stats",
"rx_pos_xy",
"downsampled_segmentation_mask",
"rx_spacing",
"windowed_beamformer",
"y_rad",
"craft_y_rad",
"y_phi",
],
),
"worker_init_fn": worker_init_fn,
# "pin_memory": True,
"prefetch_factor": 2 if args.workers > 0 else None,
"batch_sampler": sampler,
}

train_dataloader = torch.utils.data.DataLoader(train_ds, **params_for_ds(train_ds))
val_dataloader = torch.utils.data.DataLoader(val_ds, **params_for_ds(val_ds))

for epoch in range(args.epochs):
# breakpoint()
for step, batch_data in enumerate(
if step >= args.steps:
break

for _, batch_data in enumerate(
tqdm(train_dataloader)
): # , total=len(train_dataloader)):
# if step > 200:
# return
if step >= args.steps:
break
if torch.rand(1).item() < 0.02:
gc.collect()
if step % args.save_every == 0:
@@ -436,6 +538,7 @@ def new_log():
optimizer=optimizer,
config=args,
step=step,
epoch=epoch,
path=f"{args.save_prefix}_step{step}.chkpnt",
)
if step % args.val_every == 0:
Expand Down Expand Up @@ -520,10 +623,9 @@ def new_log():
)
loss_d = m.loss(output, y_rad, craft_y_rad, seg_mask)

# if step < args.head_start:
# loss = loss_d["beamnet_loss"] * 100
# else:
loss = loss_d["loss"]
loss = loss_d["beamnet_loss"] * 100
if step > args.head_start:
loss += loss_d["loss"]
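            # Warm start: only the (x100 scaled) beamnet loss is optimized for the
            # first head_start steps; after that, loss_d["loss"] is added on top.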

# loss = loss_d["beamnet_loss"]
# loss.backward()
@@ -567,7 +669,7 @@ def new_log():
return {"losses": losses}


def get_parser():
def get_parser_filter():
parser = argparse.ArgumentParser()
parser.add_argument(
"-d",
@@ -625,6 +727,12 @@ def get_parser():
required=False,
default=1000,
)
parser.add_argument(
"--steps",
type=int,
required=False,
default=None,
)
parser.add_argument(
"--depth",
type=int,
@@ -697,6 +805,7 @@ def get_parser():
type=str,
required=True,
)
parser.add_argument("--load-checkpoint", type=str, required=False, default=None)
parser.add_argument(
"--precompute-cache",
type=str,
@@ -717,7 +826,7 @@ def get_parser():
parser.add_argument(
"--amp",
action=argparse.BooleanOptionalAction,
default=True,
default=False,
)
parser.add_argument(
"--val-on-train",
@@ -784,6 +893,11 @@ def get_parser():
type=int,
default=10000,
)
parser.add_argument(
"--debug-model",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument("--save-prefix", type=str, default="./this_model_")
return parser

@@ -792,12 +906,12 @@ def get_parser():
# from pyinstrument.renderers import ConsoleRenderer

if __name__ == "__main__":
parser = get_parser()
parser = get_parser_filter()
args = parser.parse_args()
# with Profile() as profile:
# profiler = Profiler()
# profiler.start()
simple_train(args)
simple_train_filter(args)

# session = profiler.stop()
# # (Stats(profile).strip_dirs().sort_stats(SortKey.TIME).print_stats(200))