Commit

Add files via upload
add docker tea
cmalf authored Feb 25, 2024
1 parent c66b70a commit 08aac2c
Showing 8 changed files with 524 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docker/Dockerfile
@@ -0,0 +1,7 @@
FROM python:3.9

WORKDIR /usr/src/fast-transformer

RUN git clone https://github.com/Rishit-dagli/Fast-Transformer .

RUN pip install --no-cache-dir .
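
To try this image locally, one would typically build it from the repository root and then run the bundled example inside the container, along these lines (the fast-transformer tag is illustrative, and the example path assumes this commit's files are present on the branch being cloned):

docker build -t fast-transformer ./docker
docker run --rm fast-transformer python example/docker_example.py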
16 changes: 16 additions & 0 deletions example/docker_example.py
@@ -0,0 +1,16 @@
import tensorflow as tf
from fast_transformer import FastTransformer

mask = tf.ones([1, 4096], dtype=tf.bool)
model = FastTransformer(
num_tokens=20000,
dim=512,
depth=2,
max_seq_len=4096,
absolute_pos_emb=True, # Absolute positional embeddings
mask=mask,
)
x = tf.experimental.numpy.random.randint(0, 20000, (1, 4096))

logits = model(x) # (1, 4096, 20000)
print("Should be (1, 4096, 20000):", logits.shape)
236 changes: 236 additions & 0 deletions example/transformer-example.ipynb
@@ -0,0 +1,236 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Fast Transformer Example.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyN2dLGeOJKFUxGr48unCTld",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/Rishit-dagli/Fast-Transformer/blob/main/example/fast-transformer-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GTM5ACwLHfen"
},
"source": [
"# Fast Transformer Example\n",
"\n",
"This notebook shows the the process of using the [fast-transformer](https://pypi.org/project/tf-watcher/) Python package. Fast Transformer is a Transformer variant based on additive attention that can handle long sequences efficiently with linear complexity. Fastformer is much more efficient than many existing Transformer models and can meanwhile achieve comparable or even better long text modeling performance.\n",
"\n",
"If you find this useful please consider giving a ⭐ to [the repo](https://github.com/Rishit-dagli/Fast-Transformer)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fHvBydTKIF1-"
},
"source": [
"## Install the package"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UAJ67rErHTLH",
"outputId": "61f07cd1-4f60-4173-e0ce-6593d43726bb"
},
"source": [
"!pip install fast-transformer"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting fast-transformer\n",
" Downloading fast_transformer-0.1.0-py3-none-any.whl (10 kB)\n",
"Collecting rotary-embedding-tensorflow~=0.1.0\n",
" Downloading rotary-embedding-tensorflow-0.1.1.tar.gz (5.6 kB)\n",
"Collecting einops~=0.3.0\n",
" Downloading einops-0.3.2-py3-none-any.whl (25 kB)\n",
"Collecting tensorflow~=2.5.0\n",
" Downloading tensorflow-2.5.1-cp37-cp37m-manylinux2010_x86_64.whl (454.4 MB)\n",
"\u001b[K |████████████████████████████████| 454.4 MB 9.7 kB/s \n",
"\u001b[?25hRequirement already satisfied: numpy>=1.18.5 in /usr/local/lib/python3.7/dist-packages (from rotary-embedding-tensorflow~=0.1.0->fast-transformer) (1.19.5)\n",
"Requirement already satisfied: termcolor~=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.1.0)\n",
"Requirement already satisfied: keras-preprocessing~=1.1.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.1.2)\n",
"Collecting keras-nightly~=2.5.0.dev\n",
" Downloading keras_nightly-2.5.0.dev2021032900-py2.py3-none-any.whl (1.2 MB)\n",
"\u001b[K |████████████████████████████████| 1.2 MB 37.0 MB/s \n",
"\u001b[?25hRequirement already satisfied: flatbuffers~=1.12.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.12)\n",
"Collecting tensorflow-estimator<2.6.0,>=2.5.0\n",
" Downloading tensorflow_estimator-2.5.0-py2.py3-none-any.whl (462 kB)\n",
"\u001b[K |████████████████████████████████| 462 kB 54.1 MB/s \n",
"\u001b[?25hRequirement already satisfied: wrapt~=1.12.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.12.1)\n",
"Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.17.3)\n",
"Requirement already satisfied: wheel~=0.35 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.37.0)\n",
"Requirement already satisfied: absl-py~=0.10 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.12.0)\n",
"Requirement already satisfied: h5py~=3.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.1.0)\n",
"Requirement already satisfied: astunparse~=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.6.3)\n",
"Requirement already satisfied: google-pasta~=0.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.2.0)\n",
"Requirement already satisfied: tensorboard~=2.5 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (2.6.0)\n",
"Collecting grpcio~=1.34.0\n",
" Downloading grpcio-1.34.1-cp37-cp37m-manylinux2014_x86_64.whl (4.0 MB)\n",
"\u001b[K |████████████████████████████████| 4.0 MB 34.1 MB/s \n",
"\u001b[?25hRequirement already satisfied: six~=1.15.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.15.0)\n",
"Requirement already satisfied: typing-extensions~=3.7.4 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.7.4.3)\n",
"Requirement already satisfied: gast==0.4.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.4.0)\n",
"Requirement already satisfied: opt-einsum~=3.3.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.3.0)\n",
"Requirement already satisfied: cached-property in /usr/local/lib/python3.7/dist-packages (from h5py~=3.1.0->tensorflow~=2.5.0->fast-transformer) (1.5.2)\n",
"Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.6.1)\n",
"Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (57.4.0)\n",
"Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.4.6)\n",
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.0.1)\n",
"Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.8.0)\n",
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.3.4)\n",
"Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.35.0)\n",
"Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (2.23.0)\n",
"Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (4.2.4)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.2.8)\n",
"Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (4.7.2)\n",
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.3.0)\n",
"Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (4.8.1)\n",
"Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.4.8)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (2021.5.30)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (2.10)\n",
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.1.1)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->markdown>=2.6.8->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.6.0)\n",
"Building wheels for collected packages: rotary-embedding-tensorflow\n",
" Building wheel for rotary-embedding-tensorflow (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for rotary-embedding-tensorflow: filename=rotary_embedding_tensorflow-0.1.1-py3-none-any.whl size=6219 sha256=3d7948e76623a564a497dc3410c405b3a1dadff936e124f9a01fc6c5abba1168\n",
" Stored in directory: /root/.cache/pip/wheels/ae/eb/9e/30f415d973c560e533c270bfc4deafda32b3cf6e27109dc06c\n",
"Successfully built rotary-embedding-tensorflow\n",
"Installing collected packages: grpcio, tensorflow-estimator, keras-nightly, tensorflow, rotary-embedding-tensorflow, einops, fast-transformer\n",
" Attempting uninstall: grpcio\n",
" Found existing installation: grpcio 1.41.0\n",
" Uninstalling grpcio-1.41.0:\n",
" Successfully uninstalled grpcio-1.41.0\n",
" Attempting uninstall: tensorflow-estimator\n",
" Found existing installation: tensorflow-estimator 2.6.0\n",
" Uninstalling tensorflow-estimator-2.6.0:\n",
" Successfully uninstalled tensorflow-estimator-2.6.0\n",
" Attempting uninstall: tensorflow\n",
" Found existing installation: tensorflow 2.6.0\n",
" Uninstalling tensorflow-2.6.0:\n",
" Successfully uninstalled tensorflow-2.6.0\n",
"Successfully installed einops-0.3.2 fast-transformer-0.1.0 grpcio-1.34.1 keras-nightly-2.5.0.dev2021032900 rotary-embedding-tensorflow-0.1.1 tensorflow-2.5.1 tensorflow-estimator-2.5.0\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JmSV-36nIIwt"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"metadata": {
"id": "uAFR15LtHfB4"
},
"source": [
"import tensorflow as tf"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "QkIP13ZNIQI7"
},
"source": [
"## Create a `FastTransformer` class"
]
},
{
"cell_type": "code",
"metadata": {
"id": "djQDvtTPIYoN"
},
"source": [
"from fast_transformer import FastTransformer"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dlmKiMCkIMsv"
},
"source": [
"mask = tf.ones([1, 4096], dtype=tf.bool)\n",
"model = FastTransformer(\n",
" num_tokens = 20000,\n",
" dim = 512,\n",
" depth = 2,\n",
" max_seq_len = 4096,\n",
" absolute_pos_emb = True, # Absolute positional embeddings\n",
" mask = mask\n",
")\n",
"x = tf.experimental.numpy.random.randint(0, 20000, (1, 4096))"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "b8J5-wNrIa3l",
"outputId": "082c3fad-b623-42be-b902-cbfb25d9ed54"
},
"source": [
"logits = model(x) # (1, 4096, 20000)\n",
"logits.shape"
],
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorShape([1, 4096, 20000])"
]
},
"metadata": {},
"execution_count": 5
}
]
}
]
}
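
The additive-attention idea the notebook describes (one global query token and one global key token in place of a full pairwise attention matrix, so the cost is linear in sequence length) can be made concrete with a short sketch. This is a deliberately simplified single-head version with no mask and no rotary embeddings; all names below are illustrative assumptions, not the package's internal API:

import tensorflow as tf

def additive_attention_sketch(x, w_q, w_k, w_v, w_r, a_q, a_k):
    # x: (batch, n, dim); w_*: (dim, dim) projections; a_q, a_k: (dim,) logit vectors
    scale = tf.math.sqrt(tf.cast(tf.shape(x)[-1], x.dtype))
    q = tf.einsum("bnd,de->bne", x, w_q)
    k = tf.einsum("bnd,de->bne", x, w_k)
    v = tf.einsum("bnd,de->bne", x, w_v)

    # one scalar logit per token, softmax over the sequence: O(n), not O(n^2)
    alpha = tf.nn.softmax(tf.einsum("bnd,d->bn", q, a_q) / scale, axis=-1)
    global_q = tf.einsum("bn,bnd->bd", alpha, q)   # single global query token

    p = k * global_q[:, None, :]                   # bias keys with the global query
    beta = tf.nn.softmax(tf.einsum("bnd,d->bn", p, a_k) / scale, axis=-1)
    global_k = tf.einsum("bn,bnd->bd", beta, p)    # single global key token

    u = v * global_k[:, None, :]                   # bias values with the global key
    return tf.einsum("bnd,de->bne", u, w_r) + q    # queries added as a residual

d = 64
weights = [tf.random.normal((d, d)) for _ in range(4)]
logit_vecs = [tf.random.normal((d,)) for _ in range(2)]
out = additive_attention_sketch(tf.random.normal((1, 4096, d)), *weights, *logit_vecs)
print(out.shape)  # (1, 4096, 64)

Because each softmax is taken over one scalar logit per token, every step touches each position once, which is where the linear-complexity claim comes from.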
2 changes: 2 additions & 0 deletions transformer/__init__.py
@@ -0,0 +1,2 @@
from .fast_transformer import FastTransformer
from .version import __version__
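
Since the package root re-exports the layer, user code can simply do "from transformer import FastTransformer" rather than importing from the submodule.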
123 changes: 123 additions & 0 deletions transformer/fast_attention.py
@@ -0,0 +1,123 @@
import tensorflow as tf
from einops import rearrange, reduce
from rotary_embedding_tensorflow import apply_rotary_emb


class FastAttention(tf.keras.layers.Layer):
def __init__(
self, dim, heads=8, dim_head=64, max_seq_len=None, pos_emb=None, mask=None
):
super(FastAttention, self).__init__()

inner_dim = heads * dim_head
self.heads = heads
self.scale = dim_head ** -0.5
self.mask = mask

self.to_qkv = tf.keras.layers.Dense(
inner_dim * 3, input_dim=dim, use_bias=False
)

        if pos_emb is not None and max_seq_len is None:
            raise ValueError(
                "If you are using Rotary positional embeddings, max_seq_len must be passed in"
            )

self.pos_emb = pos_emb
self.max_seq_len = max_seq_len

        # reduce pairs of consecutive feature dimensions before projecting to attention logits
if pos_emb is None:
kv_attn_proj_divisor = 1
else:
kv_attn_proj_divisor = 2

# project queries to query attention logits
self.to_q_attn_logits = tf.keras.layers.Dense(
1, input_dim=dim_head, use_bias=False
)

# project keys to key attention logits
self.to_k_attn_logits = tf.keras.layers.Dense(
1, input_dim=dim_head // kv_attn_proj_divisor, use_bias=False
)

self.to_r = tf.keras.layers.Dense(
dim_head, input_dim=dim_head // kv_attn_proj_divisor
)

self.to_out = tf.keras.layers.Dense(dim, input_dim=inner_dim)

def call(self, x, **kwargs):
n, h = x.shape[1], self.heads

use_rotary_emb = False
if self.pos_emb is not None:
use_rotary_emb = True

qkv = self.to_qkv(x)
qkv = tf.split(qkv, num_or_size_splits=3, axis=-1)

queries, keys, values = map(
lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv
)

        # large negative value used to fill masked-out positions before the softmax
        mask_value = -tf.experimental.numpy.finfo(x.dtype).max
        # broadcast the (batch, seq) mask over heads; use a local variable so
        # repeated calls do not keep re-shaping the stored mask
        mask = rearrange(self.mask, "b n -> b () n")

        # apply rotary embeddings when relative positional encoding is in use
if use_rotary_emb:
frequencies = self.pos_emb(
tf.range(self.max_seq_len), cache_key=self.max_seq_len
)
frequencies = rearrange(frequencies[:n], "n d -> () () n d")
query_aggr, keys_aggr, values_aggr = map(
lambda t: apply_rotary_emb(frequencies, t), (queries, keys, values)
)
else:
query_aggr, keys_aggr, values_aggr = queries, keys, values

# query attention logits
query_attn_logits = (
rearrange(self.to_q_attn_logits(queries), "b h n () -> b h n") * self.scale
)

        # keep logits where the mask is True; fill padded positions with mask_value
        query_attn_logits = tf.where(mask, query_attn_logits, mask_value)
query_attn = tf.nn.softmax(query_attn_logits)

# global query token
global_query = tf.einsum("b h n, b h n d -> b h d", query_attn, query_aggr)
global_query = rearrange(global_query, "b h d -> b h () d")

# bias keys with global query token
keys = keys * global_query

# if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension
if use_rotary_emb:
keys = reduce(keys, "b h n (d r) -> b h n d", "sum", r=2)

# key attention logits
keys_attn_logits = (
rearrange(self.to_k_attn_logits(keys), "b h n () -> b h n") * self.scale
)
        keys_attn_logits = tf.where(mask, keys_attn_logits, mask_value)
keys_attn = tf.nn.softmax(keys_attn_logits)

# global key token
global_keys = tf.einsum("b h n, b h n d -> b h d", keys_attn, keys_aggr)
global_keys = rearrange(global_keys, "b h d -> b h () d")

# bias the values
u = values_aggr * global_keys

if use_rotary_emb:
u = reduce(u, "b h n (d r) -> b h n d", "sum", r=2)

r = self.to_r(u)

# add queries as a residual
r = r + queries

# combine heads
r = rearrange(r, "b h n d -> b n (h d)")
return self.to_out(r)
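
For reference, the layer above can be exercised on its own along these lines (a minimal sketch: shapes follow the constructor defaults, pos_emb is left off so no rotary embeddings are applied, and the all-True mask marks every position as valid, matching the docker example above):

import tensorflow as tf
from transformer.fast_attention import FastAttention

mask = tf.ones([1, 128], dtype=tf.bool)  # True = this position may be attended to
layer = FastAttention(
    dim=512, heads=8, dim_head=64, max_seq_len=128, pos_emb=None, mask=mask
)
x = tf.random.normal((1, 128, 512))  # (batch, seq_len, dim)
out = layer(x)  # (1, 128, 512), same shape as the input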