Skip to content

Commit

Permalink
feat(Architectures): Started working on Vision Transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerdo committed May 28, 2022
1 parent 09d4b13 commit 2c8f13a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
8 changes: 6 additions & 2 deletions aucmedi/neural_network/architectures/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
from aucmedi.neural_network.architectures.image.vgg19 import Architecture_VGG19
# Xception
from aucmedi.neural_network.architectures.image.xception import Architecture_Xception
# Vision Transformer (ViT)
from aucmedi.neural_network.architectures.image.vit_l32 import Architecture_ViT_L32

#-----------------------------------------------------#
# Access Functions to Architecture Classes #
Expand Down Expand Up @@ -97,7 +99,8 @@
"NASNetLarge": Architecture_NASNetLarge,
"VGG16": Architecture_VGG16,
"VGG19": Architecture_VGG19,
"Xception": Architecture_Xception
"Xception": Architecture_Xception,
"ViT_L32": Architecture_ViT_L32,
}
""" Dictionary of implemented 2D Architectures Methods in AUCMEDI.
Expand Down Expand Up @@ -178,7 +181,8 @@
"NASNetLarge": "tf",
"VGG16": "caffe",
"VGG19": "caffe",
"Xception": "tf"
"Xception": "tf",
"ViT_L32": "tf",
}
""" Dictionary of recommended [Standardize][aucmedi.data_processing.subfunctions.standardize] techniques for 2D Architectures Methods in AUCMEDI.
Expand Down
25 changes: 25 additions & 0 deletions tests/test_architectures_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,3 +702,28 @@ def test_Xception(self):
self.assertTrue(sdm_global["2D.Xception"] == "tf")
self.datagen_GRAY.sf_resize = Resize(shape=(32, 32))
self.datagen_RGB.sf_resize = Resize(shape=(32, 32))

#-------------------------------------------------#
# Architecture: ViT L32 #
#-------------------------------------------------#
def test_ViT_L32(self):
self.datagen_GRAY.sf_resize = Resize(shape=(384, 384))
self.datagen_RGB.sf_resize = Resize(shape=(384, 384))
arch = Architecture_ViT_L32(Classifier(n_labels=4), channels=1,
input_shape=(384, 384))
model = Neural_Network(n_labels=4, channels=1, architecture=arch,
batch_queue_size=1)
model.predict(self.datagen_GRAY)
arch = Architecture_ViT_L32(Classifier(n_labels=4), channels=3,
input_shape=(384, 384))
model = Neural_Network(n_labels=4, channels=3, architecture=arch,
batch_queue_size=1)
model.predict(self.datagen_RGB)
model = Neural_Network(n_labels=4, channels=3, architecture="2D.ViT_L32",
batch_queue_size=1, input_shape=(384, 384))
try : model.model.summary()
except : raise Exception()
self.assertTrue(supported_standardize_mode["ViT_L32"] == "tf")
self.assertTrue(sdm_global["2D.ViT_L32"] == "tf")
self.datagen_GRAY.sf_resize = Resize(shape=(384, 384))
self.datagen_RGB.sf_resize = Resize(shape=(384, 384))

0 comments on commit 2c8f13a

Please sign in to comment.