From 92b752178f3578966620e4f8994d11c4ce6e81d4 Mon Sep 17 00:00:00 2001 From: Shkarupa Alex Date: Sun, 11 Dec 2022 16:16:48 +0300 Subject: [PATCH 1/2] TF 2.11 compatibility, large model --- convert_weights.py | 6 ++++-- requirements.txt | 6 +++--- tfgcvit/__init__.py | 2 +- tfgcvit/block.py | 2 +- tfgcvit/drop.py | 2 +- tfgcvit/embed.py | 2 +- tfgcvit/extract.py | 2 +- tfgcvit/level.py | 2 +- tfgcvit/mlp.py | 2 +- tfgcvit/model.py | 11 ++++++++++- tfgcvit/norm.py | 2 +- tfgcvit/pad.py | 2 +- tfgcvit/reduce.py | 2 +- tfgcvit/se.py | 2 +- tfgcvit/tests/test_application.py | 3 ++- tfgcvit/tests/test_block.py | 2 +- tfgcvit/tests/test_winatt.py | 2 +- tfgcvit/winatt.py | 2 +- 18 files changed, 33 insertions(+), 21 deletions(-) diff --git a/convert_weights.py b/convert_weights.py index a668f19..9d234bb 100755 --- a/convert_weights.py +++ b/convert_weights.py @@ -10,14 +10,16 @@ 'micro': 'https://drive.google.com/file/d/15kt8VOXdAH_jF77g7pEPk-ZmZF13sHRd/view?usp=sharing', 'tiny': 'https://drive.google.com/file/d/1C9lLgykooDF6CxZDFDnUqw5lEqoFgULh/view?usp=sharing', 'small': 'https://drive.google.com/file/d/1bfEJQNutyDkPHAkgYcKWhjVTT_ZnYXp4/view?usp=sharing', - 'base': 'https://drive.google.com/file/d/1PFugO7dqfS-eubZi-yksM_FcYvUNjXBn/view?usp=sharing' + 'base': 'https://drive.google.com/file/d/1PFugO7dqfS-eubZi-yksM_FcYvUNjXBn/view?usp=sharing', + 'large': 'https://drive.google.com/file/d/1XDvFQrCkK-6QIpdLU1QrXWzjwnzNcH3E/view?usp=sharing' } TF_MODELS = { 'nano': tfgcvit.GCViTNano, 'micro': tfgcvit.GCViTMicro, 'tiny': tfgcvit.GCViTTiny, 'small': tfgcvit.GCViTSmall, - 'base': tfgcvit.GCViTBase + 'base': tfgcvit.GCViTBase, + 'large': tfgcvit.GCViTLarge } diff --git a/requirements.txt b/requirements.txt index 1883327..449295f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -tensorflow>=2.9.0 -keras>=2.9.0 -numpy>=1.22.4 +tensorflow>=2.11.0 +keras>=2.11.0 +numpy>=1.21.4 diff --git a/tfgcvit/__init__.py b/tfgcvit/__init__.py index e84aaac..0889ce0 100644 --- a/tfgcvit/__init__.py +++ b/tfgcvit/__init__.py @@ -1,2 +1,2 @@ -from tfgcvit.model import GCViT, GCViTNano, GCViTMicro, GCViTTiny, GCViTSmall, GCViTBase +from tfgcvit.model import GCViT, GCViTNano, GCViTMicro, GCViTTiny, GCViTSmall, GCViTBase, GCViTLarge from tfgcvit.prep import preprocess_input diff --git a/tfgcvit/block.py b/tfgcvit/block.py index 337d5c7..fde3493 100644 --- a/tfgcvit/block.py +++ b/tfgcvit/block.py @@ -1,6 +1,6 @@ import tensorflow as tf from keras import initializers, layers -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from keras.utils.tf_utils import shape_type_conversion from tfgcvit.drop import DropPath from tfgcvit.mlp import MLP diff --git a/tfgcvit/drop.py b/tfgcvit/drop.py index bd35a73..cbff239 100644 --- a/tfgcvit/drop.py +++ b/tfgcvit/drop.py @@ -1,7 +1,7 @@ import tensorflow as tf from keras import backend, layers from keras.utils.control_flow_util import smart_cond -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from keras.utils.tf_utils import shape_type_conversion diff --git a/tfgcvit/embed.py b/tfgcvit/embed.py index 5ee1dbc..e17aac3 100644 --- a/tfgcvit/embed.py +++ b/tfgcvit/embed.py @@ -1,5 +1,5 @@ from keras import layers -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from keras.utils.tf_utils import shape_type_conversion diff --git a/tfgcvit/extract.py b/tfgcvit/extract.py index f610b2a..b512cbb 100644 --- a/tfgcvit/extract.py +++ b/tfgcvit/extract.py @@ -1,5 +1,5 @@ from keras import layers -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from keras.utils.tf_utils import shape_type_conversion from .pad import SymmetricPadding from .se import SE diff --git a/tfgcvit/level.py b/tfgcvit/level.py index 77dd34b..7fd6821 100644 --- a/tfgcvit/level.py +++ b/tfgcvit/level.py @@ -1,6 +1,6 @@ import tensorflow as tf from keras import layers -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from keras.utils.tf_utils import shape_type_conversion from tfgcvit.block import Block from tfgcvit.extract import FeatExtract diff --git a/tfgcvit/mlp.py b/tfgcvit/mlp.py index 27ebd0e..3e74f2d 100644 --- a/tfgcvit/mlp.py +++ b/tfgcvit/mlp.py @@ -1,5 +1,5 @@ from keras import layers -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from keras.utils.tf_utils import shape_type_conversion diff --git a/tfgcvit/model.py b/tfgcvit/model.py index 6b4b5a0..28e1291 100644 --- a/tfgcvit/model.py +++ b/tfgcvit/model.py @@ -15,13 +15,15 @@ 'gcvit_tiny': BASE_URL.format('2.0.0', 'tiny'), 'gcvit_small': BASE_URL.format('2.0.0', 'small'), 'gcvit_base': BASE_URL.format('2.0.0', 'base'), + 'gcvit_large': BASE_URL.format('2.0.3', 'large'), } WEIGHT_HASHES = { 'gcvit_nano': '752926536d36707415c8b17d819fb1bfc48d22fd878edde1f622c76bfe23f690', 'gcvit_micro': 'fcea210cd00d79de3fc681ddaad965ca3601077a27db256d4aacddc1154b5517', 'gcvit_tiny': 'b55e8de5e64174619bf1ffeb11ea2d9b553ce527d6aa4370f5ade875c6e7b1f5', 'gcvit_small': '0d9755ce464c8f4eece85493697c694ea616036d41136a602204c2fddec67b1b', - 'gcvit_base': 'bcf1dd6a59f2ef12b0aa657f30aef0ed67bb6d17e9e91186c03e4da651b28b10' + 'gcvit_base': 'bcf1dd6a59f2ef12b0aa657f30aef0ed67bb6d17e9e91186c03e4da651b28b10', + 'gcvit_large': 'ec0faee8dc7a3537d8fc64d2fdf6011cbfb468cbc37426e9596a4fdef30b475a' } @@ -202,3 +204,10 @@ def GCViTBase(model_name='gcvit_base', window_size=(7, 7, 14, 7), embed_dim=128, return GCViT(model_name=model_name, window_size=window_size, embed_dim=embed_dim, depths=depths, num_heads=num_heads, mlp_ratio=mlp_ratio, path_drop=path_drop, layer_scale=layer_scale, weights=weights, **kwargs) + + +def GCViTLarge(model_name='gcvit_large', window_size=(7, 7, 14, 7), embed_dim=192, depths=(3, 4, 19, 5), + num_heads=(6, 12, 24, 48), mlp_ratio=2., path_drop=0.5, layer_scale=1e-5, weights='imagenet', **kwargs): + return GCViT(model_name=model_name, window_size=window_size, embed_dim=embed_dim, depths=depths, + num_heads=num_heads, mlp_ratio=mlp_ratio, path_drop=path_drop, layer_scale=layer_scale, + weights=weights, **kwargs) diff --git a/tfgcvit/norm.py b/tfgcvit/norm.py index 904d4fd..b415ba4 100644 --- a/tfgcvit/norm.py +++ b/tfgcvit/norm.py @@ -1,7 +1,7 @@ import tensorflow as tf import warnings from keras import layers -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from keras.utils.tf_utils import shape_type_conversion diff --git a/tfgcvit/pad.py b/tfgcvit/pad.py index e191713..8e649b5 100644 --- a/tfgcvit/pad.py +++ b/tfgcvit/pad.py @@ -1,6 +1,6 @@ import tensorflow as tf from keras import backend, layers -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable @register_keras_serializable(package='TFGCVit') diff --git a/tfgcvit/reduce.py b/tfgcvit/reduce.py index 6c3a545..424d211 100644 --- a/tfgcvit/reduce.py +++ b/tfgcvit/reduce.py @@ -1,5 +1,5 @@ from keras import layers, models -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from keras.utils.tf_utils import shape_type_conversion from tfgcvit.norm import LayerNorm from tfgcvit.se import SE diff --git a/tfgcvit/se.py b/tfgcvit/se.py index 5a37b31..24ce1da 100644 --- a/tfgcvit/se.py +++ b/tfgcvit/se.py @@ -1,5 +1,5 @@ from keras import layers -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from keras.utils.tf_utils import shape_type_conversion diff --git a/tfgcvit/tests/test_application.py b/tfgcvit/tests/test_application.py index 726cbfa..48770b8 100644 --- a/tfgcvit/tests/test_application.py +++ b/tfgcvit/tests/test_application.py @@ -11,7 +11,8 @@ (tfgcvit.GCViTMicro, 224, 512), (tfgcvit.GCViTTiny, 224, 512), (tfgcvit.GCViTSmall, 224, 768), - (tfgcvit.GCViTBase, 224, 1024) + (tfgcvit.GCViTBase, 224, 1024), + (tfgcvit.GCViTLarge, 224, 1536) ] diff --git a/tfgcvit/tests/test_block.py b/tfgcvit/tests/test_block.py index f728635..335002c 100644 --- a/tfgcvit/tests/test_block.py +++ b/tfgcvit/tests/test_block.py @@ -2,7 +2,7 @@ import tensorflow as tf from keras import layers from keras.testing_infra import test_combinations -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from tfgcvit.block import Block from testing_utils import layer_multi_io_test diff --git a/tfgcvit/tests/test_winatt.py b/tfgcvit/tests/test_winatt.py index e92cb00..f929beb 100644 --- a/tfgcvit/tests/test_winatt.py +++ b/tfgcvit/tests/test_winatt.py @@ -2,7 +2,7 @@ import tensorflow as tf from keras import layers from keras.testing_infra import test_combinations -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from tfgcvit.winatt import WindowAttention from testing_utils import layer_multi_io_test diff --git a/tfgcvit/winatt.py b/tfgcvit/winatt.py index 795656f..a041e04 100644 --- a/tfgcvit/winatt.py +++ b/tfgcvit/winatt.py @@ -2,7 +2,7 @@ import tensorflow as tf from keras import initializers, layers from keras.utils.control_flow_util import smart_cond -from keras.utils.generic_utils import register_keras_serializable +from keras.saving.object_registration import register_keras_serializable from keras.utils.tf_utils import shape_type_conversion From 9478aa21512161c9af2724d0fd16b143f1e8ad97 Mon Sep 17 00:00:00 2001 From: Shkarupa Alex Date: Sun, 11 Dec 2022 16:17:16 +0300 Subject: [PATCH 2/2] Update version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 55b4b50..51b3478 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name='tfgcvit', - version='2.0.2', + version='2.0.3', description='Keras (TensorFlow v2) reimplementation of Global Context Vision Transformer models.', long_description=long_description, long_description_content_type="text/markdown",