-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add docker tea
- Loading branch information
Showing
8 changed files
with
524 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Image that installs the Fast-Transformer package from its GitHub repository.
FROM python:3.9

# Directory the repository is cloned into and installed from.
WORKDIR /usr/src/fast-transformer

# A shallow clone is enough: only the current source tree is needed for the install.
RUN git clone --depth 1 https://github.com/Rishit-dagli/Fast-Transformer .

# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir .
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import tensorflow as tf | ||
from fast_transformer import FastTransformer | ||
|
||
# Demo dimensions: one sequence of 4096 token ids drawn from a 20000-entry vocabulary.
batch_size, seq_len, vocab_size = 1, 4096, 20000

# All-True boolean mask with one entry per token position.
mask = tf.ones([batch_size, seq_len], dtype=tf.bool)

model = FastTransformer(
    num_tokens=vocab_size,
    dim=512,
    depth=2,
    max_seq_len=seq_len,
    absolute_pos_emb=True,  # Absolute positional embeddings
    mask=mask,
)

# Random integer token ids in [0, vocab_size) used as dummy input.
x = tf.experimental.numpy.random.randint(0, vocab_size, (batch_size, seq_len))

logits = model(x)  # (1, 4096, 20000)
print("Should be (1, 4096, 20000):", logits.shape)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"name": "Fast Transformer Example.ipynb", | ||
"provenance": [], | ||
"authorship_tag": "ABX9TyN2dLGeOJKFUxGr48unCTld", | ||
"include_colab_link": true | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/Rishit-dagli/Fast-Transformer/blob/main/example/fast-transformer-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "GTM5ACwLHfen" | ||
}, | ||
"source": [ | ||
"# Fast Transformer Example\n", | ||
"\n", | ||
"This notebook shows the process of using the [fast-transformer](https://pypi.org/project/fast-transformer/) Python package. Fast Transformer is a Transformer variant based on additive attention that can handle long sequences efficiently with linear complexity. Fastformer is much more efficient than many existing Transformer models and can meanwhile achieve comparable or even better long text modeling performance.\n", | ||
"\n", | ||
"If you find this useful please consider giving a ⭐ to [the repo](https://github.com/Rishit-dagli/Fast-Transformer)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "fHvBydTKIF1-" | ||
}, | ||
"source": [ | ||
"## Install the package" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "UAJ67rErHTLH", | ||
"outputId": "61f07cd1-4f60-4173-e0ce-6593d43726bb" | ||
}, | ||
"source": [ | ||
"!pip install fast-transformer" | ||
], | ||
"execution_count": 1, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Collecting fast-transformer\n", | ||
" Downloading fast_transformer-0.1.0-py3-none-any.whl (10 kB)\n", | ||
"Collecting rotary-embedding-tensorflow~=0.1.0\n", | ||
" Downloading rotary-embedding-tensorflow-0.1.1.tar.gz (5.6 kB)\n", | ||
"Collecting einops~=0.3.0\n", | ||
" Downloading einops-0.3.2-py3-none-any.whl (25 kB)\n", | ||
"Collecting tensorflow~=2.5.0\n", | ||
" Downloading tensorflow-2.5.1-cp37-cp37m-manylinux2010_x86_64.whl (454.4 MB)\n", | ||
"\u001b[K |████████████████████████████████| 454.4 MB 9.7 kB/s \n", | ||
"\u001b[?25hRequirement already satisfied: numpy>=1.18.5 in /usr/local/lib/python3.7/dist-packages (from rotary-embedding-tensorflow~=0.1.0->fast-transformer) (1.19.5)\n", | ||
"Requirement already satisfied: termcolor~=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.1.0)\n", | ||
"Requirement already satisfied: keras-preprocessing~=1.1.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.1.2)\n", | ||
"Collecting keras-nightly~=2.5.0.dev\n", | ||
" Downloading keras_nightly-2.5.0.dev2021032900-py2.py3-none-any.whl (1.2 MB)\n", | ||
"\u001b[K |████████████████████████████████| 1.2 MB 37.0 MB/s \n", | ||
"\u001b[?25hRequirement already satisfied: flatbuffers~=1.12.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.12)\n", | ||
"Collecting tensorflow-estimator<2.6.0,>=2.5.0\n", | ||
" Downloading tensorflow_estimator-2.5.0-py2.py3-none-any.whl (462 kB)\n", | ||
"\u001b[K |████████████████████████████████| 462 kB 54.1 MB/s \n", | ||
"\u001b[?25hRequirement already satisfied: wrapt~=1.12.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.12.1)\n", | ||
"Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.17.3)\n", | ||
"Requirement already satisfied: wheel~=0.35 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.37.0)\n", | ||
"Requirement already satisfied: absl-py~=0.10 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.12.0)\n", | ||
"Requirement already satisfied: h5py~=3.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.1.0)\n", | ||
"Requirement already satisfied: astunparse~=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.6.3)\n", | ||
"Requirement already satisfied: google-pasta~=0.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.2.0)\n", | ||
"Requirement already satisfied: tensorboard~=2.5 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (2.6.0)\n", | ||
"Collecting grpcio~=1.34.0\n", | ||
" Downloading grpcio-1.34.1-cp37-cp37m-manylinux2014_x86_64.whl (4.0 MB)\n", | ||
"\u001b[K |████████████████████████████████| 4.0 MB 34.1 MB/s \n", | ||
"\u001b[?25hRequirement already satisfied: six~=1.15.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.15.0)\n", | ||
"Requirement already satisfied: typing-extensions~=3.7.4 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.7.4.3)\n", | ||
"Requirement already satisfied: gast==0.4.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.4.0)\n", | ||
"Requirement already satisfied: opt-einsum~=3.3.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.3.0)\n", | ||
"Requirement already satisfied: cached-property in /usr/local/lib/python3.7/dist-packages (from h5py~=3.1.0->tensorflow~=2.5.0->fast-transformer) (1.5.2)\n", | ||
"Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.6.1)\n", | ||
"Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (57.4.0)\n", | ||
"Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.4.6)\n", | ||
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.0.1)\n", | ||
"Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.8.0)\n", | ||
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.3.4)\n", | ||
"Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.35.0)\n", | ||
"Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (2.23.0)\n", | ||
"Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (4.2.4)\n", | ||
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.2.8)\n", | ||
"Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (4.7.2)\n", | ||
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.3.0)\n", | ||
"Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (4.8.1)\n", | ||
"Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.4.8)\n", | ||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (2021.5.30)\n", | ||
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.0.4)\n", | ||
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.24.3)\n", | ||
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (2.10)\n", | ||
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.1.1)\n", | ||
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->markdown>=2.6.8->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.6.0)\n", | ||
"Building wheels for collected packages: rotary-embedding-tensorflow\n", | ||
" Building wheel for rotary-embedding-tensorflow (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | ||
" Created wheel for rotary-embedding-tensorflow: filename=rotary_embedding_tensorflow-0.1.1-py3-none-any.whl size=6219 sha256=3d7948e76623a564a497dc3410c405b3a1dadff936e124f9a01fc6c5abba1168\n", | ||
" Stored in directory: /root/.cache/pip/wheels/ae/eb/9e/30f415d973c560e533c270bfc4deafda32b3cf6e27109dc06c\n", | ||
"Successfully built rotary-embedding-tensorflow\n", | ||
"Installing collected packages: grpcio, tensorflow-estimator, keras-nightly, tensorflow, rotary-embedding-tensorflow, einops, fast-transformer\n", | ||
" Attempting uninstall: grpcio\n", | ||
" Found existing installation: grpcio 1.41.0\n", | ||
" Uninstalling grpcio-1.41.0:\n", | ||
" Successfully uninstalled grpcio-1.41.0\n", | ||
" Attempting uninstall: tensorflow-estimator\n", | ||
" Found existing installation: tensorflow-estimator 2.6.0\n", | ||
" Uninstalling tensorflow-estimator-2.6.0:\n", | ||
" Successfully uninstalled tensorflow-estimator-2.6.0\n", | ||
" Attempting uninstall: tensorflow\n", | ||
" Found existing installation: tensorflow 2.6.0\n", | ||
" Uninstalling tensorflow-2.6.0:\n", | ||
" Successfully uninstalled tensorflow-2.6.0\n", | ||
"Successfully installed einops-0.3.2 fast-transformer-0.1.0 grpcio-1.34.1 keras-nightly-2.5.0.dev2021032900 rotary-embedding-tensorflow-0.1.1 tensorflow-2.5.1 tensorflow-estimator-2.5.0\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "JmSV-36nIIwt" | ||
}, | ||
"source": [ | ||
"## Setup" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "uAFR15LtHfB4" | ||
}, | ||
"source": [ | ||
"import tensorflow as tf" | ||
], | ||
"execution_count": 2, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "QkIP13ZNIQI7" | ||
}, | ||
"source": [ | ||
"## Create a `FastTransformer` class" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "djQDvtTPIYoN" | ||
}, | ||
"source": [ | ||
"from fast_transformer import FastTransformer" | ||
], | ||
"execution_count": 3, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "dlmKiMCkIMsv" | ||
}, | ||
"source": [ | ||
"mask = tf.ones([1, 4096], dtype=tf.bool)\n", | ||
"model = FastTransformer(\n", | ||
" num_tokens = 20000,\n", | ||
" dim = 512,\n", | ||
" depth = 2,\n", | ||
" max_seq_len = 4096,\n", | ||
" absolute_pos_emb = True, # Absolute positional embeddings\n", | ||
" mask = mask\n", | ||
")\n", | ||
"x = tf.experimental.numpy.random.randint(0, 20000, (1, 4096))" | ||
], | ||
"execution_count": 4, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "b8J5-wNrIa3l", | ||
"outputId": "082c3fad-b623-42be-b902-cbfb25d9ed54" | ||
}, | ||
"source": [ | ||
"logits = model(x) # (1, 4096, 20000)\n", | ||
"logits.shape" | ||
], | ||
"execution_count": 5, | ||
"outputs": [ | ||
{ | ||
"output_type": "execute_result", | ||
"data": { | ||
"text/plain": [ | ||
"TensorShape([1, 4096, 20000])" | ||
] | ||
}, | ||
"metadata": {}, | ||
"execution_count": 5 | ||
} | ||
] | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .fast_transformer import FastTransformer | ||
from .version import __version__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import tensorflow as tf | ||
from einops import rearrange, reduce | ||
from rotary_embedding_tensorflow import apply_rotary_emb | ||
|
||
|
||
class FastAttention(tf.keras.layers.Layer):
    """Multi-head additive attention layer.

    Queries are pooled into one global query per head, which biases the keys;
    the biased keys are pooled into one global key per head, which biases the
    values. The result is projected, added to the queries as a residual and
    the heads are merged back together.

    Args:
        dim: Size of the last axis of the layer output (and the declared
            input size of the QKV projection).
        heads: Number of attention heads.
        dim_head: Feature size of each head.
        max_seq_len: Number of positions for which rotary frequencies are
            generated. Required when ``pos_emb`` is given.
        pos_emb: Optional callable producing rotary frequencies; it is called
            as ``pos_emb(tf.range(max_seq_len), cache_key=max_seq_len)``.
            ``None`` disables rotary embeddings.
        mask: Optional boolean tensor laid out as ``(batch, seq_len)``.
            ``True`` marks positions that take part in attention, matching
            the example usage that passes an all-True mask for a fully valid
            sequence. ``None`` disables masking.

    Raises:
        ValueError: If ``pos_emb`` is given without ``max_seq_len``.
    """

    def __init__(
        self, dim, heads=8, dim_head=64, max_seq_len=None, pos_emb=None, mask=None
    ):
        super().__init__()

        # Rotary frequencies are generated for ``max_seq_len`` positions in
        # ``call``, so the length is only mandatory when ``pos_emb`` is used.
        if pos_emb is not None and max_seq_len is None:
            raise ValueError(
                "If you are using Rotary positional embeddings, max_seq_len must be passed in"
            )

        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.mask = mask
        self.pos_emb = pos_emb
        self.max_seq_len = max_seq_len

        self.to_qkv = tf.keras.layers.Dense(
            inner_dim * 3, input_dim=dim, use_bias=False
        )

        # reduce pairs of consecutive feature dimension before doing projection to attention logits
        kv_attn_proj_divisor = 1 if pos_emb is None else 2

        # project queries to query attention logits
        self.to_q_attn_logits = tf.keras.layers.Dense(
            1, input_dim=dim_head, use_bias=False
        )

        # project keys to key attention logits
        self.to_k_attn_logits = tf.keras.layers.Dense(
            1, input_dim=dim_head // kv_attn_proj_divisor, use_bias=False
        )

        self.to_r = tf.keras.layers.Dense(
            dim_head, input_dim=dim_head // kv_attn_proj_divisor
        )

        self.to_out = tf.keras.layers.Dense(dim, input_dim=inner_dim)

    @staticmethod
    def _mask_logits(logits, mask, mask_value):
        """Replace logits at positions where ``mask`` is False with ``mask_value``."""
        if mask is None:
            return logits
        return tf.where(mask, logits, mask_value)

    def call(self, x, **kwargs):
        """Apply additive attention to ``x``.

        Args:
            x: Tensor laid out as ``(batch, seq_len, features)``.
            **kwargs: Accepted for Keras compatibility; unused.

        Returns:
            Tensor laid out as ``(batch, seq_len, dim)``.
        """
        n, h = x.shape[1], self.heads
        use_rotary_emb = self.pos_emb is not None

        qkv = tf.split(self.to_qkv(x), num_or_size_splits=3, axis=-1)
        queries, keys, values = (
            rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv
        )

        # Large negative fill value so masked positions vanish under softmax.
        mask_value = -tf.experimental.numpy.finfo(x.dtype).max

        # Keep the stored mask untouched: reshaping ``self.mask`` in place
        # would break every call after the first one.
        mask = None
        if self.mask is not None:
            mask = rearrange(self.mask, "b n -> b () n")

        # relative positional encoding is needed
        if use_rotary_emb:
            frequencies = self.pos_emb(
                tf.range(self.max_seq_len), cache_key=self.max_seq_len
            )
            frequencies = rearrange(frequencies[:n], "n d -> () () n d")
            query_aggr, keys_aggr, values_aggr = (
                apply_rotary_emb(frequencies, t) for t in (queries, keys, values)
            )
        else:
            query_aggr, keys_aggr, values_aggr = queries, keys, values

        # query attention logits
        query_attn_logits = (
            rearrange(self.to_q_attn_logits(queries), "b h n () -> b h n") * self.scale
        )
        query_attn_logits = self._mask_logits(query_attn_logits, mask, mask_value)
        query_attn = tf.nn.softmax(query_attn_logits)

        # global query token
        global_query = tf.einsum("b h n, b h n d -> b h d", query_attn, query_aggr)
        global_query = rearrange(global_query, "b h d -> b h () d")

        # bias keys with global query token
        keys = keys * global_query

        # if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension
        if use_rotary_emb:
            keys = reduce(keys, "b h n (d r) -> b h n d", "sum", r=2)

        # key attention logits
        keys_attn_logits = (
            rearrange(self.to_k_attn_logits(keys), "b h n () -> b h n") * self.scale
        )
        keys_attn_logits = self._mask_logits(keys_attn_logits, mask, mask_value)
        keys_attn = tf.nn.softmax(keys_attn_logits)

        # global key token
        global_keys = tf.einsum("b h n, b h n d -> b h d", keys_attn, keys_aggr)
        global_keys = rearrange(global_keys, "b h d -> b h () d")

        # bias the values
        u = values_aggr * global_keys

        if use_rotary_emb:
            u = reduce(u, "b h n (d r) -> b h n d", "sum", r=2)

        r = self.to_r(u)

        # add queries as a residual
        r = r + queries

        # combine heads
        r = rearrange(r, "b h n d -> b n (h d)")
        return self.to_out(r)
Oops, something went wrong.