Resolve Dilated #87 (#88)

Merged: 8 commits, Apr 20, 2023
4 changes: 1 addition & 3 deletions tests/layers/test_dot_product_attention.py
@@ -93,14 +93,13 @@ def test_call_with_different_values(

@pytest.fixture
def attention_layer(self):
return DotProductAttention(dropout_rate=0.2, scaled=True, normalize=False)
return DotProductAttention(dropout_rate=0.2, scaled=True)

def test_from_config(self, attention_layer):
config = attention_layer.get_config()
new_layer = DotProductAttention.from_config(config)
assert attention_layer.dropout.rate == new_layer.dropout.rate
assert attention_layer.scaled == new_layer.scaled
assert attention_layer.normalize == new_layer.normalize

def test_get_attention_weights(self, attention_layer):
attention_layer.attention_weights = np.random.rand(5, 10)
@@ -113,4 +112,3 @@ def test_get_config(self, attention_layer):
assert isinstance(config, dict)
assert config["dropout_rate"] == 0.2
assert config["scaled"] == True
assert config["normalize"] == False
52 changes: 26 additions & 26 deletions transformerx/layers/dot_product_attention.py
@@ -1,8 +1,7 @@
import os

import numpy as np
import tensorflow as tf

from transformerx.layers.masks.global_attention_mask import GlobalAttentionMask
from transformerx.utils import masked_softmax


@@ -86,20 +85,30 @@ def __init__(
self,
dropout_rate: float = 0,
scaled: bool = True,
normalize: bool = False,
kernel_initializer: str = "ones",
kernel_regularizer: str = None,
**kwargs
mask_type="dilated",
mask_prob=0.0,
dilation_rate=1,
**kwargs,
):
super().__init__(**kwargs)
self.dropout_rate = dropout_rate
self.dropout = tf.keras.layers.Dropout(self.dropout_rate)
self.scaled = scaled
self.normalize = normalize
self.attention_weights = None
self.kernel_initializer = kernel_initializer
self.kernel_regularizer = kernel_regularizer

self.mask_type = mask_type
self.mask_prob = mask_prob
self.dilation_rate = dilation_rate
self.global_mask = GlobalAttentionMask(
mask_type=self.mask_type,
mask_prob=self.mask_prob,
dilation_rate=self.dilation_rate,
)

def build(self, input_shape):
super().build(input_shape)

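For reference, a minimal construction-and-call sketch with the new arguments (the argument names come from this hunk; the positional order queries, keys, values in call below is an assumption, since the start of that signature is folded out of the diff):

import tensorflow as tf

from transformerx.layers.dot_product_attention import DotProductAttention

layer = DotProductAttention(
    dropout_rate=0.1,
    scaled=True,
    mask_type="dilated",  # new: which global mask to build
    mask_prob=0.0,        # new: zeroing probability for the "random" mask type
    dilation_rate=1,      # new: band width for the "dilated" mask type
)

q = tf.random.normal((2, 5, 8))  # (batch, seq_len, depth)
k = tf.random.normal((2, 5, 8))
v = tf.random.normal((2, 5, 8))
out = layer(q, k, v)  # -> (2, 5, 8); the global mask is built here but not yet applied in call()
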
@@ -115,28 +124,13 @@ def call(
attention_mask: tf.Tensor = None,
causal_mask: bool = None,
training=None,
**kwargs
**kwargs,
) -> tf.Tensor:
scores = tf.matmul(queries, keys, transpose_b=True)
if self.scaled:
# self.scale = self.add_weight(
# name="scale",
# shape=(scores.shape),
# initializer=self.kernel_initializer,
# regularizer=self.kernel_regularizer,
# trainable=True,
# )
depth = queries.shape[-1]
# print(self.scale, scores.shape)
# self.scale = tf.broadcast_to(scores.shape)
# self.scale = tf.broadcast_to(
# tf.expand_dims(tf.expand_dims(self.scale, -1), -1), scores.shape
# )
scores = (
scores
/ tf.math.sqrt(tf.cast(depth, dtype=tf.float32))
# * self.scale
)

scores = scores / tf.math.sqrt(tf.cast(depth, dtype=tf.float32))

# apply causal mask
if causal_mask:
@@ -151,12 +145,18 @@
tf.expand_dims(causal_mask, -1), scores.shape
) # broadcast across batch dimension

# to be uncommented later
# apply global mask
# gmask = self.global_mask.get_mask(keys.shape)
# masked_attention_scores = tf.math.multiply(scores, gmask)
# attention_probs = tf.nn.softmax(masked_attention_scores, axis=-1)
# uncomment until here

self.attention_weights = masked_softmax(scores, attention_mask)
# self.attention_weights = tf.nn.softmax(scores, axis=-1, mask=attention_mask)
# scores = tf.matmul(self.dropout(self.attention_weights, **kwargs), values)
scores = tf.matmul(self.dropout(self.attention_weights, **kwargs), values)
if self.normalize:
depth = tf.cast(tf.shape(keys)[-1], tf.float32)
scores /= tf.sqrt(depth)

return scores

def get_attention_weights(self):
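
The commented-out block in call() above shows how the global mask is intended to be combined with the raw scores once it is enabled. A standalone sketch of that step, using the GlobalAttentionMask class added below (tensor values are illustrative; note that multiplying scores by a 0/1 mask zeroes them rather than pushing them toward -inf before the softmax, exactly as in the commented-out code):

import tensorflow as tf

from transformerx.layers.masks.global_attention_mask import GlobalAttentionMask

global_mask = GlobalAttentionMask(mask_type="dilated", dilation_rate=1)

scores = tf.random.normal((2, 5, 5))      # raw (batch, seq_len, seq_len) attention scores
keys_shape = (2, 5, 8)                    # (batch, seq_len, depth) of the keys
gmask = global_mask.get_mask(keys_shape)  # -> (2, 5, 5) tensor of zeros and ones
masked_attention_scores = tf.math.multiply(scores, gmask)
attention_probs = tf.nn.softmax(masked_attention_scores, axis=-1)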
Empty file.
65 changes: 65 additions & 0 deletions transformerx/layers/masks/global_attention_mask.py
@@ -0,0 +1,65 @@
import tensorflow as tf


class GlobalAttentionMask:
"""Global attention mask class.

This class creates a global attention mask for the input.
It is created by applying a random mask to the input.
The mask is applied by randomly selecting a number of tokens from the input.
The number of tokens selected is determined by the mask probability."""

def __init__(self, mask_type="none", mask_prob=0.0, dilation_rate=1):
self.mask_type = mask_type
self.mask_prob = mask_prob
self.dilation_rate = dilation_rate

def get_mask(self, input_shape):
        if len(input_shape) == 4:
            # Assumes the input shape is 4-d ('b', 'h', 'l', 'd')
            batch_size, seq_len = input_shape[0], input_shape[2]
elif len(input_shape) == 3:
# Assumes the input shape is 3-d ('b', 'l', 'd')
batch_size, seq_len = input_shape[0], input_shape[1]
elif len(input_shape) == 2:
# Assumes the input shape is 2-d ('b', 'd')
batch_size, seq_len = input_shape[0], input_shape[1]
else:
raise ValueError(
"The input shape must be 2-d ('b', 'd'), 3-d ('b', 'l', 'd') or 4-d ('b', 'h', 'l', 'd')"
)

mask = tf.ones((batch_size, seq_len, seq_len), dtype=tf.float32)

if self.mask_type == "none":
pass

elif self.mask_type == "random":
mask = tf.where(
tf.random.uniform((batch_size, seq_len, seq_len)) < self.mask_prob,
tf.zeros((batch_size, seq_len, seq_len)),
mask,
)
elif self.mask_type == "dilated":
mask = self.create_dilated_mask(mask, self.dilation_rate)

return mask

# create a dilated mask method
def create_dilated_mask(self, mask, dilation_rate):
batch_size, seq_len = mask.shape[0], mask.shape[1]

# Create a boolean mask where True indicates positions that need to be masked
mask_bool = tf.math.logical_and(
tf.math.abs(tf.range(seq_len) - tf.range(seq_len)[:, tf.newaxis])
<= dilation_rate,
tf.math.not_equal(tf.range(seq_len)[:, tf.newaxis], tf.range(seq_len)),
)

# Convert the boolean mask to float32
mask_float = tf.cast(mask_bool, dtype=tf.float32)

# Multiply the original mask with the dilated mask to apply the dilation
dilated_mask = mask * mask_float
return dilated_mask
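
A quick illustration of the dilated pattern the class produces; shapes and values follow directly from the code above:

import tensorflow as tf

from transformerx.layers.masks.global_attention_mask import GlobalAttentionMask

gmask = GlobalAttentionMask(mask_type="dilated", dilation_rate=1)
mask = gmask.get_mask((1, 4, 8))  # 3-d ('b', 'l', 'd') input shape -> (1, 4, 4) mask
print(mask[0].numpy())
# [[0. 1. 0. 0.]
#  [1. 0. 1. 0.]
#  [0. 1. 0. 1.]
#  [0. 0. 1. 0.]]
# Ones mark pairs with |i - j| <= dilation_rate and i != j; the diagonal is
# zeroed because create_dilated_mask excludes i == j.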
2 changes: 1 addition & 1 deletion transformerx/layers/transformer_encoder_block.py
@@ -378,7 +378,7 @@ def main():
attention_mask = tf.cast(attention_mask, dtype=tf.bool)
# Initialize a TransformerEncoderBlock object
encoder_block = TransformerEncoderBlock(
d_model=256,
d_model=512,
num_heads=4,
dropout_rate=0.1,
norm_type="batch",
64 changes: 5 additions & 59 deletions transformerx/utils.py
@@ -3,7 +3,7 @@
import tensorflow as tf


def sequence_mask1(X, attention_mask, value=-1e9):
def sequence_mask(X, attention_mask, value=-1e9):
if not isinstance(X, tf.Tensor):
raise TypeError("X must be a Tensor")
if not isinstance(attention_mask, tf.Tensor):
@@ -20,6 +20,8 @@ def sequence_mask1(X, attention_mask, value=-1e9):
)
else:
maxlen = X.shape[0]
print("range.shape: ", tf.range(start=0, limit=maxlen, dtype=tf.float32).shape)
print("attention mask.shape: ", tf.cast(attention_mask, dtype=tf.float32).shape)
mask = tf.range(start=0, limit=maxlen, dtype=tf.float32) < tf.cast(
attention_mask, dtype=tf.float32
)
@@ -30,48 +32,14 @@
return tf.where(mask, X, value)


def sequence_mask(X, attention_mask, value=-1e9):
if not isinstance(X, tf.Tensor):
raise TypeError("X must be a Tensor")
if not isinstance(attention_mask, tf.Tensor):
raise TypeError("attention_mask must be a Tensor")
if len(X.shape) not in (2, 3):
raise ValueError("X must be a 2D or 3D tensor")
if len(attention_mask.shape) not in (1, 2):
raise ValueError("attention_mask must be a 1D or 2D tensor")

# Check if the attention mask is a valid mask.
if not tf.reduce_all(attention_mask):
raise ValueError(
"attention_mask must be a binary matrix where each row and column corresponds to a token in the sequence, and the value of each entry is 1 if the corresponding tokens are to be attended to and 0 otherwise."
)

# Check if the value parameter is a valid value.
if not isinstance(value, float):
raise TypeError("value must be a float")

print(X.shape, X.dtype, attention_mask.shape, attention_mask.dtype)
# Handle the case where the sequence length is greater than the attention mask length.
if X.shape[1] > attention_mask.shape[1]:
attention_mask = tf.pad(
attention_mask,
[[0, 0], [0, len(X.shape[1]) - len(attention_mask.shape[1])]],
)

# Create the mask.
mask = tf.cast(attention_mask, dtype=tf.float32)

# Return the masked sequence.
return tf.where(mask, X, value)


def masked_softmax_old(X, attention_mask, temperature=1.0):
def masked_softmax(X, attention_mask, temperature=1.0):
"""Perform softmax operation by masking elements on the last axis."""

# x: 3D tensor, attention_mask: 1D or 2D tensor
if attention_mask is None:
return tf.nn.softmax(X / temperature, axis=-1)
else:
print("attention mask len: ", len(attention_mask.shape))
shape = X.shape
if isinstance(attention_mask, tf.SparseTensor):
attention_mask = tf.sparse.reshape(attention_mask, shape=(-1,))
@@ -87,28 +55,6 @@ def masked_softmax_old(X, attention_mask, temperature=1.0):
return tf.nn.softmax(tf.reshape(X, shape=shape) / temperature, axis=-1)


def masked_softmax(logits, mask):
"""Compute masked softmax over the last dimension of logits."""
# Cast the mask to float32
mask = tf.cast(mask, dtype=tf.float32)
mask = tf.reshape(mask, logits.shape)
# Subtract a large negative number from masked positions to make them close to zero after softmax
logits -= (1.0 - mask) * 1e32

# Apply softmax along the last dimension of logits
softmax_output = tf.nn.softmax(logits, axis=-1)

# Apply the mask to the softmax output
masked_softmax_output = softmax_output * mask

# Normalize the masked softmax output along the last dimension
masked_softmax_output /= tf.reduce_sum(
masked_softmax_output, axis=-1, keepdims=True
)

return masked_softmax_output


def use_device(device):
if device == "cpu":
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
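
For the surviving masked_softmax (renamed from masked_softmax_old), the unmasked path visible above is just a temperature-scaled softmax. A minimal check of that path (the masked path relies on code folded out of this diff, so it is not exercised here):

import tensorflow as tf

from transformerx.utils import masked_softmax

x = tf.constant([[[1.0, 2.0, 3.0]]])        # 3-d tensor, as the function expects
probs = masked_softmax(x, None)             # attention_mask=None -> plain softmax
expected = tf.nn.softmax(x / 1.0, axis=-1)  # default temperature is 1.0
tf.debugging.assert_near(probs, expected)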