diff --git a/src/victim/models.py b/src/victim/models.py index 78eb361..1e926cb 100644 --- a/src/victim/models.py +++ b/src/victim/models.py @@ -11,14 +11,27 @@ class Classifier(Model): def _model(self) -> keras.Model: return keras.Sequential( [ - keras.layers.Conv2D(32, (3, 3), activation="relu"), - keras.layers.Conv2D(32, (3, 3), activation="relu"), + keras.layers.Conv2D(32, (3, 3), activation="relu", padding="same"), + keras.layers.BatchNormalization(), + keras.layers.Conv2D(32, (3, 3), activation="relu", padding="same"), + keras.layers.BatchNormalization(), keras.layers.MaxPooling2D((2, 2)), - keras.layers.Conv2D(64, (3, 3), activation="relu"), - keras.layers.Conv2D(64, (3, 3), activation="relu"), + keras.layers.Dropout(0.2), + keras.layers.Conv2D(64, (3, 3), activation="relu", padding="same"), + keras.layers.BatchNormalization(), + keras.layers.Conv2D(64, (3, 3), activation="relu", padding="same"), + keras.layers.BatchNormalization(), keras.layers.MaxPooling2D((2, 2)), + keras.layers.Dropout(0.3), + keras.layers.Conv2D(128, (3, 3), activation="relu", padding="same"), + keras.layers.BatchNormalization(), + keras.layers.Conv2D(128, (3, 3), activation="relu", padding="same"), + keras.layers.BatchNormalization(), + keras.layers.MaxPooling2D((2, 2)), + keras.layers.Dropout(0.4), keras.layers.Flatten(), - keras.layers.Dense(512, activation="relu"), + keras.layers.Dense(256, activation="relu"), + keras.layers.Dropout(0.5), keras.layers.Dense(10, activation="softmax"), ], )