diff --git a/src/typings/models.py b/src/typings/models.py index 6ed29e4..f52cd85 100644 --- a/src/typings/models.py +++ b/src/typings/models.py @@ -18,9 +18,7 @@ def __init__( data_train: tf.data.Dataset = None, data_test: tf.data.Dataset = None, optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam(), - loss: keras.losses.Loss = keras.losses.SparseCategoricalCrossentropy( - from_logits=True - ), + loss: keras.losses.Loss = keras.losses.SparseCategoricalCrossentropy(), accuracy: keras.metrics.Accuracy = keras.metrics.SparseCategoricalAccuracy( name="accuracy" ), diff --git a/src/victim/models.py b/src/victim/models.py index e25e864..f523f93 100644 --- a/src/victim/models.py +++ b/src/victim/models.py @@ -11,13 +11,15 @@ class Classifier(Model): def _model(self) -> keras.Model: return keras.Sequential( [ - keras.layers.Conv2D(32, (3, 3), activation="relu"), - keras.layers.MaxPooling2D((2, 2)), keras.layers.Conv2D(64, (3, 3), activation="relu"), - keras.layers.MaxPooling2D((2, 2)), keras.layers.Conv2D(64, (3, 3), activation="relu"), + keras.layers.MaxPooling2D((2, 2)), + keras.layers.Conv2D(128, (3, 3), activation="relu"), + keras.layers.Conv2D(128, (3, 3), activation="relu"), + keras.layers.MaxPooling2D((2, 2)), keras.layers.Flatten(), - keras.layers.Dense(64, activation="relu"), + keras.layers.Dense(256, activation="relu"), + keras.layers.Dense(256, activation="relu"), keras.layers.Dense(10, activation="softmax"), ], )