diff --git a/spf/notebooks/simple_train_filter.py b/spf/notebooks/simple_train_filter.py index 7d41656a..cdc1e09d 100644 --- a/spf/notebooks/simple_train_filter.py +++ b/spf/notebooks/simple_train_filter.py @@ -519,7 +519,7 @@ def params_for_ds(ds): for epoch in range(args.epochs): # breakpoint() - if step >= args.steps: + if args.steps >= 0 and step >= args.steps: break for _, batch_data in enumerate( @@ -527,7 +527,7 @@ def params_for_ds(ds): ): # , total=len(train_dataloader)): # if step > 200: # return - if step >= args.steps: + if args.steps >= 0 and step >= args.steps: break if torch.rand(1).item() < 0.02: gc.collect() @@ -731,7 +731,7 @@ def get_parser_filter(): "--steps", type=int, required=False, - default=None, + default=-1, ) parser.add_argument( "--depth",