From bcf96cf44c6140fc2bb743cfb8182700ae25983b Mon Sep 17 00:00:00 2001 From: misko Date: Wed, 17 Jul 2024 04:33:21 +0000 Subject: [PATCH] fix steps --- spf/notebooks/simple_train_filter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spf/notebooks/simple_train_filter.py b/spf/notebooks/simple_train_filter.py index 7d41656a..cdc1e09d 100644 --- a/spf/notebooks/simple_train_filter.py +++ b/spf/notebooks/simple_train_filter.py @@ -519,7 +519,7 @@ def params_for_ds(ds): for epoch in range(args.epochs): # breakpoint() - if step >= args.steps: + if args.steps >= 0 and step >= args.steps: break for _, batch_data in enumerate( @@ -527,7 +527,7 @@ def params_for_ds(ds): ): # , total=len(train_dataloader)): # if step > 200: # return - if step >= args.steps: + if args.steps >= 0 and step >= args.steps: break if torch.rand(1).item() < 0.02: gc.collect() @@ -731,7 +731,7 @@ def get_parser_filter(): "--steps", type=int, required=False, - default=None, + default=-1, ) parser.add_argument( "--depth",