diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..186ab6a
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,7 @@
+FROM python:3.9
+
+WORKDIR /usr/src/fast-transformer
+
+RUN git clone https://github.com/Rishit-dagli/Fast-Transformer .
+
+RUN pip install --no-cache-dir .
\ No newline at end of file
diff --git a/example/docker_example.py b/example/docker_example.py
new file mode 100644
index 0000000..9ab6d28
--- /dev/null
+++ b/example/docker_example.py
@@ -0,0 +1,16 @@
+import tensorflow as tf
+from fast_transformer import FastTransformer
+
+mask = tf.ones([1, 4096], dtype=tf.bool)
+model = FastTransformer(
+    num_tokens=20000,
+    dim=512,
+    depth=2,
+    max_seq_len=4096,
+    absolute_pos_emb=True, # Absolute positional embeddings
+    mask=mask,
+)
+x = tf.experimental.numpy.random.randint(0, 20000, (1, 4096))
+
+logits = model(x) # (1, 4096, 20000)
+print("Should be (1, 4096, 20000):", logits.shape)
diff --git a/example/transformer-example.ipynb b/example/transformer-example.ipynb
new file mode 100644
index 0000000..ea2466a
--- /dev/null
+++ b/example/transformer-example.ipynb
@@ -0,0 +1,236 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "Fast Transformer Example.ipynb",
+      "provenance": [],
+      "authorship_tag": "ABX9TyN2dLGeOJKFUxGr48unCTld",
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        "\"Open"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "GTM5ACwLHfen"
+      },
+      "source": [
+        "# Fast Transformer Example\n",
+        "\n",
+        "This notebook shows the process of using the [fast-transformer](https://pypi.org/project/fast-transformer/) Python package. Fast Transformer is a Transformer variant based on additive attention that can handle long sequences efficiently with linear complexity. Fastformer is much more efficient than many existing Transformer models while achieving comparable or even better long-text modeling performance.\n",
+        "\n",
+        "If you find this useful please consider giving a ⭐ to [the repo](https://github.com/Rishit-dagli/Fast-Transformer)."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fHvBydTKIF1-" + }, + "source": [ + "## Install the package" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UAJ67rErHTLH", + "outputId": "61f07cd1-4f60-4173-e0ce-6593d43726bb" + }, + "source": [ + "!pip install fast-transformer" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting fast-transformer\n", + " Downloading fast_transformer-0.1.0-py3-none-any.whl (10 kB)\n", + "Collecting rotary-embedding-tensorflow~=0.1.0\n", + " Downloading rotary-embedding-tensorflow-0.1.1.tar.gz (5.6 kB)\n", + "Collecting einops~=0.3.0\n", + " Downloading einops-0.3.2-py3-none-any.whl (25 kB)\n", + "Collecting tensorflow~=2.5.0\n", + " Downloading tensorflow-2.5.1-cp37-cp37m-manylinux2010_x86_64.whl (454.4 MB)\n", + "\u001b[K |████████████████████████████████| 454.4 MB 9.7 kB/s \n", + "\u001b[?25hRequirement already satisfied: numpy>=1.18.5 in /usr/local/lib/python3.7/dist-packages (from rotary-embedding-tensorflow~=0.1.0->fast-transformer) (1.19.5)\n", + "Requirement already satisfied: termcolor~=1.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.1.0)\n", + "Requirement already satisfied: keras-preprocessing~=1.1.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.1.2)\n", + "Collecting keras-nightly~=2.5.0.dev\n", + " Downloading keras_nightly-2.5.0.dev2021032900-py2.py3-none-any.whl (1.2 MB)\n", + "\u001b[K |████████████████████████████████| 1.2 MB 37.0 MB/s \n", + "\u001b[?25hRequirement already satisfied: flatbuffers~=1.12.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.12)\n", + "Collecting tensorflow-estimator<2.6.0,>=2.5.0\n", + " Downloading tensorflow_estimator-2.5.0-py2.py3-none-any.whl (462 kB)\n", + "\u001b[K |████████████████████████████████| 462 kB 54.1 MB/s \n", + "\u001b[?25hRequirement already satisfied: wrapt~=1.12.1 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.12.1)\n", + "Requirement already satisfied: protobuf>=3.9.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.17.3)\n", + "Requirement already satisfied: wheel~=0.35 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.37.0)\n", + "Requirement already satisfied: absl-py~=0.10 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.12.0)\n", + "Requirement already satisfied: h5py~=3.1.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.1.0)\n", + "Requirement already satisfied: astunparse~=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.6.3)\n", + "Requirement already satisfied: google-pasta~=0.2 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.2.0)\n", + "Requirement already satisfied: tensorboard~=2.5 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (2.6.0)\n", + "Collecting grpcio~=1.34.0\n", + " Downloading grpcio-1.34.1-cp37-cp37m-manylinux2014_x86_64.whl (4.0 MB)\n", + "\u001b[K |████████████████████████████████| 4.0 MB 34.1 MB/s \n", + "\u001b[?25hRequirement already satisfied: six~=1.15.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (1.15.0)\n", + "Requirement 
already satisfied: typing-extensions~=3.7.4 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.7.4.3)\n", + "Requirement already satisfied: gast==0.4.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (0.4.0)\n", + "Requirement already satisfied: opt-einsum~=3.3.0 in /usr/local/lib/python3.7/dist-packages (from tensorflow~=2.5.0->fast-transformer) (3.3.0)\n", + "Requirement already satisfied: cached-property in /usr/local/lib/python3.7/dist-packages (from h5py~=3.1.0->tensorflow~=2.5.0->fast-transformer) (1.5.2)\n", + "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.6.1)\n", + "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (57.4.0)\n", + "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.4.6)\n", + "Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.0.1)\n", + "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.8.0)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.3.4)\n", + "Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.35.0)\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (2.23.0)\n", + "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (4.2.4)\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.2.8)\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (4.7.2)\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.3.0)\n", + "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from markdown>=2.6.8->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (4.8.1)\n", + "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (0.4.8)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (2021.5.30)\n", + "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.0.4)\n", + 
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (1.24.3)\n", + "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3,>=2.21.0->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (2.10)\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.1.1)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->markdown>=2.6.8->tensorboard~=2.5->tensorflow~=2.5.0->fast-transformer) (3.6.0)\n", + "Building wheels for collected packages: rotary-embedding-tensorflow\n", + " Building wheel for rotary-embedding-tensorflow (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for rotary-embedding-tensorflow: filename=rotary_embedding_tensorflow-0.1.1-py3-none-any.whl size=6219 sha256=3d7948e76623a564a497dc3410c405b3a1dadff936e124f9a01fc6c5abba1168\n", + " Stored in directory: /root/.cache/pip/wheels/ae/eb/9e/30f415d973c560e533c270bfc4deafda32b3cf6e27109dc06c\n", + "Successfully built rotary-embedding-tensorflow\n", + "Installing collected packages: grpcio, tensorflow-estimator, keras-nightly, tensorflow, rotary-embedding-tensorflow, einops, fast-transformer\n", + " Attempting uninstall: grpcio\n", + " Found existing installation: grpcio 1.41.0\n", + " Uninstalling grpcio-1.41.0:\n", + " Successfully uninstalled grpcio-1.41.0\n", + " Attempting uninstall: tensorflow-estimator\n", + " Found existing installation: tensorflow-estimator 2.6.0\n", + " Uninstalling tensorflow-estimator-2.6.0:\n", + " Successfully uninstalled tensorflow-estimator-2.6.0\n", + " Attempting uninstall: tensorflow\n", + " Found existing installation: tensorflow 2.6.0\n", + " Uninstalling tensorflow-2.6.0:\n", + " Successfully uninstalled tensorflow-2.6.0\n", + "Successfully installed einops-0.3.2 fast-transformer-0.1.0 grpcio-1.34.1 keras-nightly-2.5.0.dev2021032900 rotary-embedding-tensorflow-0.1.1 tensorflow-2.5.1 tensorflow-estimator-2.5.0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JmSV-36nIIwt" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "uAFR15LtHfB4" + }, + "source": [ + "import tensorflow as tf" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QkIP13ZNIQI7" + }, + "source": [ + "## Create a `FastTransformer` class" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "djQDvtTPIYoN" + }, + "source": [ + "from fast_transformer import FastTransformer" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "dlmKiMCkIMsv" + }, + "source": [ + "mask = tf.ones([1, 4096], dtype=tf.bool)\n", + "model = FastTransformer(\n", + " num_tokens = 20000,\n", + " dim = 512,\n", + " depth = 2,\n", + " max_seq_len = 4096,\n", + " absolute_pos_emb = True, # Absolute positional embeddings\n", + " mask = mask\n", + ")\n", + "x = tf.experimental.numpy.random.randint(0, 20000, (1, 4096))" + ], + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b8J5-wNrIa3l", + "outputId": "082c3fad-b623-42be-b902-cbfb25d9ed54" + }, + "source": [ + 
"logits = model(x) # (1, 4096, 20000)\n", + "logits.shape" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "TensorShape([1, 4096, 20000])" + ] + }, + "metadata": {}, + "execution_count": 5 + } + ] + } + ] +} \ No newline at end of file diff --git a/transformer/__init__.py b/transformer/__init__.py new file mode 100644 index 0000000..eb36b4f --- /dev/null +++ b/transformer/__init__.py @@ -0,0 +1,2 @@ +from .fast_transformer import FastTransformer +from .version import __version__ diff --git a/transformer/fast_attention.py b/transformer/fast_attention.py new file mode 100644 index 0000000..5d05d9b --- /dev/null +++ b/transformer/fast_attention.py @@ -0,0 +1,123 @@ +import tensorflow as tf +from einops import rearrange, reduce +from rotary_embedding_tensorflow import apply_rotary_emb + + +class FastAttention(tf.keras.layers.Layer): + def __init__( + self, dim, heads=8, dim_head=64, max_seq_len=None, pos_emb=None, mask=None + ): + super(FastAttention, self).__init__() + + inner_dim = heads * dim_head + self.heads = heads + self.scale = dim_head ** -0.5 + self.mask = mask + + self.to_qkv = tf.keras.layers.Dense( + inner_dim * 3, input_dim=dim, use_bias=False + ) + + if pos_emb is None and max_seq_len is None: + raise Exception( + "If you are using Rotary positional embeddings, max_seq_len must be passed in" + ) + + self.pos_emb = pos_emb + self.max_seq_len = max_seq_len + + # reduce pairs of consecutive feature dimension before doing projection to attention logits + if pos_emb is None: + kv_attn_proj_divisor = 1 + else: + kv_attn_proj_divisor = 2 + + # project queries to query attention logits + self.to_q_attn_logits = tf.keras.layers.Dense( + 1, input_dim=dim_head, use_bias=False + ) + + # project keys to key attention logits + self.to_k_attn_logits = tf.keras.layers.Dense( + 1, input_dim=dim_head // kv_attn_proj_divisor, use_bias=False + ) + + self.to_r = tf.keras.layers.Dense( + dim_head, input_dim=dim_head // kv_attn_proj_divisor + ) + + self.to_out = tf.keras.layers.Dense(dim, input_dim=inner_dim) + + def call(self, x, **kwargs): + n, h = x.shape[1], self.heads + + use_rotary_emb = False + if self.pos_emb is not None: + use_rotary_emb = True + + qkv = self.to_qkv(x) + qkv = tf.split(qkv, num_or_size_splits=3, axis=-1) + + queries, keys, values = map( + lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv + ) + + mask_value = -tf.experimental.numpy.finfo(x.dtype).max + self.mask = rearrange(self.mask, "b n -> b () n") + + # relative positional encoding is needed + if use_rotary_emb: + frequencies = self.pos_emb( + tf.range(self.max_seq_len), cache_key=self.max_seq_len + ) + frequencies = rearrange(frequencies[:n], "n d -> () () n d") + query_aggr, keys_aggr, values_aggr = map( + lambda t: apply_rotary_emb(frequencies, t), (queries, keys, values) + ) + else: + query_aggr, keys_aggr, values_aggr = queries, keys, values + + # query attention logits + query_attn_logits = ( + rearrange(self.to_q_attn_logits(queries), "b h n () -> b h n") * self.scale + ) + + query_attn_logits = tf.where(self.mask, mask_value, query_attn_logits) + query_attn = tf.nn.softmax(query_attn_logits) + + # global query token + global_query = tf.einsum("b h n, b h n d -> b h d", query_attn, query_aggr) + global_query = rearrange(global_query, "b h d -> b h () d") + + # bias keys with global query token + keys = keys * global_query + + # if using rotary embeddings, do an inner product between adjacent pairs in the feature dimension + if use_rotary_emb: 
+ keys = reduce(keys, "b h n (d r) -> b h n d", "sum", r=2) + + # key attention logits + keys_attn_logits = ( + rearrange(self.to_k_attn_logits(keys), "b h n () -> b h n") * self.scale + ) + keys_attn_logits = tf.where(self.mask, mask_value, keys_attn_logits) + keys_attn = tf.nn.softmax(keys_attn_logits) + + # global key token + global_keys = tf.einsum("b h n, b h n d -> b h d", keys_attn, keys_aggr) + global_keys = rearrange(global_keys, "b h d -> b h () d") + + # bias the values + u = values_aggr * global_keys + + if use_rotary_emb: + u = reduce(u, "b h n (d r) -> b h n d", "sum", r=2) + + r = self.to_r(u) + + # add queries as a residual + r = r + queries + + # combine heads + r = rearrange(r, "b h n d -> b n (h d)") + return self.to_out(r) diff --git a/transformer/test_transformer.py b/transformer/test_transformer.py new file mode 100644 index 0000000..253b4c8 --- /dev/null +++ b/transformer/test_transformer.py @@ -0,0 +1,30 @@ +import numpy as np +import tensorflow as tf + +from fast_transformer.fast_transformer import FastTransformer + + +class FastTransformerTest(tf.test.TestCase): + def setUp(self): + super(FastTransformerTest, self).setUp() + + mask = tf.ones([1, 4096], dtype=tf.bool) + self.model = FastTransformer( + num_tokens=20000, + dim=512, + depth=2, + max_seq_len=4096, + absolute_pos_emb=True, + mask=mask, + ) + + def test_shape_and_rank(self): + inputs = tf.experimental.numpy.random.randint(0, 20000, (1, 4096)) + outputs = self.model(inputs) + + self.assertEqual(tf.rank(outputs), 3) + self.assertShapeEqual(np.zeros((1, 4096, 20000)), outputs) + + +if __name__ == "__main__": + tf.test.main() diff --git a/transformer/transformer.py b/transformer/transformer.py new file mode 100644 index 0000000..5a95039 --- /dev/null +++ b/transformer/transformer.py @@ -0,0 +1,109 @@ +import tensorflow as tf +from einops import rearrange +from rotary_embedding_tensorflow import RotaryEmbedding + +from .fast_attention import FastAttention + + +class PreNorm(tf.keras.layers.Layer): + def __init__(self, dim, fn): + super(PreNorm, self).__init__() + self.fn = fn + self.norm = tf.keras.layers.LayerNormalization(axis=-1) + + def call(self, x, **kwargs): + x = self.norm(x) + return self.fn(x) + + +class FeedForward(tf.keras.layers.Layer): + def __init__(self, dim, mult=4): + super(FeedForward, self).__init__() + self.dim = dim + self.mult = mult + + self.net = tf.keras.Sequential( + [ + tf.keras.layers.Dense(dim * mult), + tf.keras.layers.Activation(tf.nn.gelu), + tf.keras.layers.Dense(dim, input_dim=dim * mult), + ] + ) + + def call(self, inputs, **kwargs): + return self.net(inputs) + + +class FastTransformer(tf.keras.Model): + def __init__( + self, + num_tokens, + dim, + depth, + max_seq_len, + heads=8, + dim_head=64, + ff_mult=4, + absolute_pos_emb=False, + mask=None, + ): + super(FastTransformer, self).__init__() + + self.token_emb = tf.keras.layers.Embedding(num_tokens, dim) + self.mask = mask + + # positional embeddings + if absolute_pos_emb: + self.abs_pos_emb = tf.keras.layers.Embedding(max_seq_len, dim) + else: + self.abs_pos_emb = None + + layer_pos_emb = None + if not absolute_pos_emb: + assert ( + dim_head % 4 + ) == 0, ( + "dimension of the head must be divisible by 4 to use rotary embeddings" + ) + layer_pos_emb = RotaryEmbedding(dim_head // 2) + + self.fast_tranformer_layers = [] + + for _ in range(depth): + attn = FastAttention( + dim, + dim_head=dim_head, + heads=heads, + pos_emb=layer_pos_emb, + max_seq_len=max_seq_len, + mask=self.mask, + ) + ff = FeedForward(dim, mult=ff_mult) + 
+            self.fast_transformer_layers.append(PreNorm(dim, attn))
+            self.fast_transformer_layers.append(PreNorm(dim, ff))
+
+        first_block = self.fast_transformer_layers[0]
+        for block in self.fast_transformer_layers[1:]:
+            block.fn.to_q_attn_logits = first_block.fn.to_q_attn_logits
+            block.fn.to_k_attn_logits = first_block.fn.to_k_attn_logits
+
+        self.to_logits = tf.keras.Sequential(
+            [
+                tf.keras.layers.LayerNormalization(axis=-1),
+                tf.keras.layers.Dense(num_tokens, input_dim=dim),
+            ]
+        )
+
+    def call(self, x, **kwargs):
+        n = x.shape[1]
+        x = self.token_emb(x)
+
+        if self.abs_pos_emb is not None:
+            pos_emb = self.abs_pos_emb(tf.range(n))
+            x = x + rearrange(pos_emb, "n d -> () n d")
+
+        for current_layer in self.fast_transformer_layers:
+            x = current_layer(x) + x
+
+        return self.to_logits(x)
diff --git a/transformer/version.py b/transformer/version.py
new file mode 100644
index 0000000..d3ec452
--- /dev/null
+++ b/transformer/version.py
@@ -0,0 +1 @@
+__version__ = "0.2.0"
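Note on the mask argument used throughout this change: every example above passes an all-True mask, so no position is ever masked out. As a minimal sketch (not part of this diff; the sequence lengths are made up), a padding mask for shorter sequences could be built with tf.sequence_mask, following the convention that True marks positions to attend to:

import tensorflow as tf
from fast_transformer import FastTransformer

# Hypothetical batch of two sequences padded to max_seq_len=4096,
# with true lengths 1000 and 2500; True = attend, False = padding.
lengths = tf.constant([1000, 2500])
mask = tf.sequence_mask(lengths, maxlen=4096)  # bool tensor of shape (2, 4096)

model = FastTransformer(
    num_tokens=20000,
    dim=512,
    depth=2,
    max_seq_len=4096,
    absolute_pos_emb=True,
    mask=mask,
)

x = tf.experimental.numpy.random.randint(0, 20000, (2, 4096))
logits = model(x)  # expected shape: (2, 4096, 20000)

For the new Dockerfile, something like `docker build -t fast-transformer docker/` followed by `docker run --rm fast-transformer python example/docker_example.py` should exercise the example inside the container (the image tag is an assumption, and the run step relies on example/docker_example.py being present on the default branch that the Dockerfile clones).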