Skip to content

Commit

Permalink
(#4) Refactor: argparse
Browse files Browse the repository at this point in the history
  • Loading branch information
betarixm committed Mar 24, 2022
1 parent 0a704d3 commit b6619e1
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions src/defense/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utils.dataset import NoisyMnist, NoisyCifar10
from models import Reformer

import sys
import argparse
import tensorflow as tf

keras = tf.keras
Expand All @@ -26,14 +26,31 @@ def train_cifar10_reformer(epochs: int = 100):


if __name__ == "__main__":
if len(sys.argv) != 2:
exit("Error: select 'mnist' or 'cifar10'.")
parser = argparse.ArgumentParser(description="Training defensive reformer models.")
parser.add_argument(
"--dataset",
"-d",
metavar="DATASET",
type=str,
help="Dataset for training",
required=True,
choices=["mnist", "cifar10"],
)

parser.add_argument(
"--epochs",
"-e",
metavar="EPOCHS",
type=int,
help="Target epochs",
required=False,
)

args = parser.parse_args()

option = sys.argv[1]
epochs = args.epochs if args.epochs is not None else 500

if option == "mnist":
train_mnist_reformer()
elif option == "cifar10":
train_cifar10_reformer()
else:
exit("Error: select 'mnist' or 'cifar10'.")
if args.dataset == "mnist":
train_mnist_reformer(epochs)
elif args.dataset == "cifar10":
train_cifar10_reformer(epochs)

0 comments on commit b6619e1

Please sign in to comment.