Skip to content

Commit

Permalink
Merge pull request #8 from ariG23498/masking
Browse files Browse the repository at this point in the history
Masking from 60% to 75%
  • Loading branch information
sayakpaul authored Nov 21, 2021
2 parents aa44408 + 45fc66d commit 7a15186
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Our main objective is to present the core idea of the proposed method in a minim
</div><br>


With just **100 epochs** of pre-training and a fairly lightweight Autoencoder architecture we achieve **44.25%** accuracy
With just **100 epochs** of pre-training and a fairly lightweight Autoencoder architecture we achieve **46.85%** accuracy
with linear probing on the **CIFAR-10** dataset. Our training logs and encoder weights are available inside the
[`encoder_weights_logs`](https://github.com/ariG23498/mae-scalable-vision-learners/tree/master/encoder_weights_logs) directory.
For comparison, we took the encoder architecture and trained it from scratch (refer to [`regular-classification.ipynb`](https://github.com/ariG23498/mae-scalable-vision-learners/blob/master/regular-classification.ipynb)) in a fully supervised manner. This gave us ~76% test top-1 accuracy.
Expand Down
Binary file not shown.
Binary file not shown.
60 changes: 39 additions & 21 deletions mae-pretraining.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "80ZKaTtG9zw9",
"outputId": "944d16c1-42b0-47aa-d3f8-04593bf18faa"
"outputId": "467ea57f-03ac-43d5-e19a-a5a8ac99c393"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -91,7 +91,7 @@
"IMAGE_SIZE = 48 # We'll resize input images to this size.\n",
"PATCH_SIZE = 6 # Size of the patches to be extract from the input images.\n",
"NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2\n",
"MASK_PROPORTION = 0.6\n",
"MASK_PROPORTION = 0.75\n",
"\n",
"# ENCODER and DECODER\n",
"LAYER_NORM_EPS = 1e-6\n",
Expand Down Expand Up @@ -130,7 +130,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "lMOYr_h1_QY6",
"outputId": "bfe4e094-3b98-450b-955a-e25b8b80eac5"
"outputId": "e0201dc6-c933-4cf8-c54f-7fde8e3d27de"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -291,7 +291,7 @@
"height": 496
},
"id": "ptI3I2aMB_rS",
"outputId": "2dc21f01-9c8d-43c5-de0f-d16f2a613d93"
"outputId": "895d0714-f9ba-4235-ca8a-b3c21ec7091c"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -322,7 +322,7 @@
"height": 248
},
"id": "qv0Va_68CPmF",
"outputId": "6e7e902e-cf50-48e9-f3f6-d52ae3fcd8e8"
"outputId": "5dd6d051-2740-452b-f57c-2131b7b8a3b1"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -497,7 +497,7 @@
"height": 301
},
"id": "UlxangzdFwMJ",
"outputId": "bee08ebe-cc62-43be-f650-498cb4bd3c52"
"outputId": "72caf44d-71d8-4e95-b66d-f4a7a75b0c90"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -749,7 +749,9 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "oCOI8_9BX_6g"
},
"source": [
"# Model init"
]
Expand Down Expand Up @@ -783,7 +785,9 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "TCDOA_kdX_6h"
},
"source": [
"## Training callbacks"
]
Expand Down Expand Up @@ -850,7 +854,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"id": "s4PEJxppX_6h"
},
"outputs": [],
"source": [
"# Some code is taken from:\n",
Expand Down Expand Up @@ -901,7 +907,14 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 283
},
"id": "iTn6VcaBX_6h",
"outputId": "c0df7a58-763f-4a1f-bbf3-1e1f2e64784f"
},
"outputs": [],
"source": [
"total_steps = int((len(x_train) / BATCH_SIZE) * EPOCHS)\n",
Expand Down Expand Up @@ -938,7 +951,9 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "-dUZ1RYqX_6h"
},
"source": [
"# Compilation and training"
]
Expand Down Expand Up @@ -967,7 +982,7 @@
"height": 1000
},
"id": "ZUAXzpDoJiXG",
"outputId": "1488c9ba-08b1-4a11-970b-f9191d1af41a"
"outputId": "2015187b-0bfb-445e-a687-018a9284a649"
},
"outputs": [],
"source": [
Expand All @@ -984,7 +999,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "S0aVUm63Lj-L",
"outputId": "0634a201-0b1a-4fa3-c079-11ede2cae9d5"
"outputId": "cbe80727-6cc3-4db3-ec5d-cf01469f38d0"
},
"outputs": [],
"source": [
Expand All @@ -1010,7 +1025,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "kXnO5jNALndF",
"outputId": "2f1dc3e9-70f7-44b4-e4cf-9a15ea3e0bf2"
"outputId": "8b6c45a7-10bf-4e58-f866-f691b5352b12"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -1084,7 +1099,7 @@
"base_uri": "https://localhost:8080/"
},
"id": "xdeuZ98oLvis",
"outputId": "75bdb66e-ffcb-454a-e4ff-1bb29e64cd73"
"outputId": "c9f7b566-47e9-4afa-a86d-728df915cd99"
},
"outputs": [],
"source": [
Expand Down Expand Up @@ -1116,7 +1131,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0o8DHVfAMrSL"
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0o8DHVfAMrSL",
"outputId": "c98b307d-c002-4507-ac21-0ca777cf3f06"
},
"outputs": [],
"source": [
Expand All @@ -1127,10 +1146,9 @@
"metadata": {
"accelerator": "GPU",
"colab": {
"authorship_tag": "ABX9TyMop/SYrAmInThOJSwKchR0",
"collapsed_sections": [],
"machine_shape": "hm",
"name": "mae.ipynb",
"name": "mae-pretraining.ipynb",
"provenance": []
},
"environment": {
Expand All @@ -1140,7 +1158,7 @@
"uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m84"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand All @@ -1154,9 +1172,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.2"
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 1
}

0 comments on commit 7a15186

Please sign in to comment.