diff --git a/ivy/stateful/layers.py b/ivy/stateful/layers.py
index 65b445b102ff9..2cada923ee860 100644
--- a/ivy/stateful/layers.py
+++ b/ivy/stateful/layers.py
@@ -5,6 +5,8 @@
 from ivy.func_wrapper import handle_nestable
 from ivy.stateful.initializers import GlorotUniform, Zeros
 from ivy.stateful.module import Module
+from ivy.stateful.norms import LayerNorm
+import numpy as np
 
 # ToDo: update docstrings and typehints according to ivy\layers
 
@@ -2224,3 +2226,64 @@ def _forward(self, x):
             The input array as it is.
         """
         return x
+
+
+# Transformer #
+# ----------#
+
+
+class Transformer(Module):
+    def __init__(
+        self,
+        d_model,
+        num_heads,
+        dff,
+        rate=0.1,
+        max_sequence_length=1000,
+        device=None,
+        v=None,
+        dtype=None,
+    ):
+        self.multihead_attention = MultiHeadAttention(d_model, num_heads=num_heads)
+        # position-wise feed-forward: expand to dff, then project back to d_model
+        # so the residual connection in _forward has matching dimensions
+        self.feedforward = Linear(d_model, dff)
+        self.feedforward_out = Linear(dff, d_model)
+        self.layernorm1 = LayerNorm([d_model])
+        self.layernorm2 = LayerNorm([d_model])
+        self.dropout1 = Dropout(rate)
+        self.dropout2 = Dropout(rate)
+        self.positional_encoding = self._get_positional_encoding(
+            max_sequence_length, d_model
+        )
+        Module.__init__(self, device=device, v=v, dtype=dtype)
+
+    def _get_positional_encoding(self, max_sequence_length, d_model):
+        # standard sinusoidal positional encoding, shape (1, max_seq_len, d_model)
+        pos_enc = np.zeros((1, max_sequence_length, d_model), dtype=np.float32)
+        position = np.arange(0, max_sequence_length, dtype=np.float32)[:, np.newaxis]
+        div_term = np.exp(
+            np.arange(0, d_model, 2).astype(np.float32) * -(np.log(10000.0) / d_model)
+        )
+        pos_enc[:, :, 0::2] = np.sin(position * div_term)
+        pos_enc[:, :, 1::2] = np.cos(position * div_term[: d_model // 2])
+        return ivy.asarray(pos_enc)
+
+    def _forward(self, inputs, training=True):
+        inputs_with_pos = inputs + self.positional_encoding[:, : inputs.shape[1], :]
+
+        # self-attention block with residual connection and layer norm
+        attn_output = self.multihead_attention(
+            inputs_with_pos, inputs_with_pos, inputs_with_pos
+        )
+        attn_output = self.dropout1(attn_output, training=training)
+        out1 = self.layernorm1(inputs + attn_output)
+
+        # feed-forward block with residual connection and layer norm
+        ffn_output = self.feedforward(out1)
+        ffn_output = ivy.gelu(ffn_output)
+        ffn_output = self.feedforward_out(ffn_output)
+        ffn_output = self.dropout2(ffn_output, training=training)
+        out2 = self.layernorm2(out1 + ffn_output)
+
+        return out2
diff --git a/ivy_tests/test_ivy/test_stateful/test_layers.py b/ivy_tests/test_ivy/test_stateful/test_layers.py
index 327b948a51f26..255a2cf025db8 100644
--- a/ivy_tests/test_ivy/test_stateful/test_layers.py
+++ b/ivy_tests/test_ivy/test_stateful/test_layers.py
@@ -1639,3 +1639,95 @@ def test_identity_layer(
         test_gradients=test_gradients,
         on_device=on_device,
     )
+
+
+# Transformer #
+# ----------- #
+
+
+@st.composite
+def transformer_data(draw):
+    dtype = draw(
+        helpers.get_dtypes("float", full=False).filter(lambda x: x != ["float16"])
+    )
+    num_heads = draw(st.integers(min_value=1, max_value=8))
+    # keep d_model a multiple of num_heads so the attention heads split evenly
+    d_model = num_heads * draw(st.integers(min_value=1, max_value=16))
+    ff_dim = draw(st.integers(min_value=1, max_value=512))
+    dropout_rate = draw(st.floats(min_value=0.0, max_value=0.9))
+    max_sequence_length = draw(st.integers(min_value=1, max_value=1000))
+
+    return (
+        dtype,
+        d_model,
+        num_heads,
+        ff_dim,
+        dropout_rate,
+        max_sequence_length,
+    )
+
+
+@handle_method(
+    method_tree="Transformer.__call__",
+    transformer_data=transformer_data(),
+    init_with_v=st.booleans(),
+    method_with_v=st.booleans(),
+    method_num_positional_args=helpers.num_positional_args(
+        fn_name="Transformer._forward"
+    ),
+    build_mode=st.just("on_init"),
+)
+def test_transformer_layer(
+    transformer_data,
+    init_with_v,
+    method_with_v,
+    build_mode,
+    on_device,
+    class_name,
+    method_name,
+    backend_fw,
+    ground_truth_backend,
+    init_flags,
+    method_flags,
+):
+    (
+        input_dtype,
+        d_model,
+        num_heads,
+        ff_dim,
+        dropout_rate,
+        max_sequence_length,
+    ) = transformer_data
+    ret_np_flat, ret_np_from_gt_flat = helpers.test_method(
+        backend_to_test=backend_fw,
+        ground_truth_backend=ground_truth_backend,
+        init_flags=init_flags,
+        method_flags=method_flags,
+        init_all_as_kwargs_np={
+            "d_model": d_model,
+            "num_heads": num_heads,
+            "dff": ff_dim,
+            "rate": dropout_rate,
+            "max_sequence_length": max_sequence_length,
+            "device": on_device,
+            "dtype": input_dtype[0],
+        },
+        method_input_dtypes=input_dtype,
+        method_all_as_kwargs_np={
+            # a small batch with the sequence length capped at 32 keeps the test light
+            "inputs": np.random.randn(2, min(max_sequence_length, 32), d_model).astype(
+                input_dtype[0]
+            ),
+            "training": True,
+        },
+        class_name=class_name,
+        method_name=method_name,
+        init_with_v=init_with_v,
+        method_with_v=method_with_v,
+        rtol_=1e-2,
+        atol_=1e-2,
+        test_values=False,
+        return_flat_np_arrays=True,
+        on_device=on_device,
+    )
+    assert_same_type_and_shape([ret_np_flat, ret_np_from_gt_flat])
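
Usage sketch for the new layer, assuming the constructor and forward signature added above; the backend, hyper-parameters, and input shapes here are illustrative only, not part of the diff:

    import ivy
    from ivy.stateful.layers import Transformer

    ivy.set_backend("numpy")  # any installed backend should work

    # d_model=64, 4 heads, 128-unit feed-forward, 10% dropout
    block = Transformer(64, 4, 128, rate=0.1, max_sequence_length=256)

    x = ivy.random_normal(shape=(2, 16, 64))  # (batch, seq_len, d_model)
    out = block(x, training=False)
    print(out.shape)  # expected: (2, 16, 64), same shape as the input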