Skip to content

Commit

Permalink
fix(Architectures): Fixed incorrect input shapes for B16 and B32 Visi…
Browse files Browse the repository at this point in the history
…on Transformer Architecture
  • Loading branch information
muellerdo committed May 28, 2022
1 parent 33068ac commit e44916b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion aucmedi/neural_network/architectures/image/vit_b16.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class Architecture_ViT_B16(Architecture_Base):
#---------------------------------------------#
# Initialization #
#---------------------------------------------#
def __init__(self, classification_head, channels, input_shape=(384, 384),
def __init__(self, classification_head, channels, input_shape=(224, 224),
pretrained_weights=False):
self.classifier = classification_head
self.input = input_shape + (channels,)
Expand Down
2 changes: 1 addition & 1 deletion aucmedi/neural_network/architectures/image/vit_b32.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class Architecture_ViT_B32(Architecture_Base):
#---------------------------------------------#
# Initialization #
#---------------------------------------------#
def __init__(self, classification_head, channels, input_shape=(384, 384),
def __init__(self, classification_head, channels, input_shape=(224, 224),
pretrained_weights=False):
self.classifier = classification_head
self.input = input_shape + (channels,)
Expand Down

0 comments on commit e44916b

Please sign in to comment.