diff --git a/src/defense/train.py b/src/defense/train.py index e03a69f..26ea313 100644 --- a/src/defense/train.py +++ b/src/defense/train.py @@ -1,7 +1,7 @@ from utils.dataset import NoisyMnist, NoisyCifar10 from models import Reformer -import sys +import argparse import tensorflow as tf keras = tf.keras @@ -26,14 +26,31 @@ def train_cifar10_reformer(epochs: int = 100): if __name__ == "__main__": - if len(sys.argv) != 2: - exit("Error: select 'mnist' or 'cifar10'.") + parser = argparse.ArgumentParser(description="Training defensive reformer models.") + parser.add_argument( + "--dataset", + "-d", + metavar="DATASET", + type=str, + help="Dataset for training", + required=True, + choices=["mnist", "cifar10"], + ) + + parser.add_argument( + "--epochs", + "-e", + metavar="EPOCHS", + type=int, + help="Target epochs", + required=False, + ) + + args = parser.parse_args() - option = sys.argv[1] + epochs = args.epochs if args.epochs is not None else 500 - if option == "mnist": - train_mnist_reformer() - elif option == "cifar10": - train_cifar10_reformer() - else: - exit("Error: select 'mnist' or 'cifar10'.") + if args.dataset == "mnist": + train_mnist_reformer(epochs) + elif args.dataset == "cifar10": + train_cifar10_reformer(epochs)