Add Gemma 2 model #1673

Merged · 4 commits · Jun 27, 2024
51 changes: 50 additions & 1 deletion keras_nlp/src/models/gemma/gemma_attention.py
@@ -28,19 +28,28 @@ def __init__(
num_query_heads,
num_key_value_heads,
kernel_initializer="glorot_uniform",
logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=4096,
query_head_dim_normalize=True,
dropout=0,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.logit_soft_cap = logit_soft_cap
self.use_sliding_window_attention = use_sliding_window_attention
self.sliding_window_size = sliding_window_size
self.query_head_dim_normalize = query_head_dim_normalize
self.dropout = dropout

self._kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)
self.num_key_value_groups = num_query_heads // num_key_value_heads

def build(self, inputs_shape):
self.hidden_dim = inputs_shape[-1]
@@ -114,7 +123,12 @@ def _compute_attention(
attention_mask,
training=False,
):
query_normalization = 1 / np.sqrt(self.head_dim)
if self.query_head_dim_normalize:
query_normalization = 1 / np.sqrt(self.head_dim)
else:
query_normalization = 1 / np.sqrt(
self.hidden_dim // self.num_query_heads
)

q *= ops.cast(query_normalization, dtype=q.dtype)
q_shape = ops.shape(q)
@@ -130,6 +144,38 @@
b, q_len, _, _, h = ops.shape(q)

attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)

if self.logit_soft_cap is not None:
attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
attention_logits = ops.multiply(
ops.tanh(attention_logits), self.logit_soft_cap
)

if self.use_sliding_window_attention:
all_ones = ops.ones_like(attention_mask)
if keras.config.backend() == "tensorflow":
import tensorflow as tf

sliding_window_size = ops.minimum(
self.sliding_window_size - 1, q_len
)
sliding_window_size = ops.cast(
sliding_window_size, dtype="int32"
)
sliding_mask = tf.linalg.band_part(
all_ones, sliding_window_size - 1, sliding_window_size - 1
)
sliding_mask = ops.cast(sliding_mask, dtype="bool")
bool_attention_mask = ops.cast(attention_mask, dtype="bool")
attention_mask = tf.math.logical_and(
sliding_mask, bool_attention_mask
)
else:
sliding_mask = ops.triu(
all_ones, -1 * self.sliding_window_size + 1
) * ops.tril(all_ones, self.sliding_window_size - 1)
attention_mask = sliding_mask * attention_mask

attention_mask = attention_mask[:, None, None, :, :]
orig_dtype = attention_logits.dtype
attention_softmax = self.softmax(attention_logits, mask=attention_mask)
@@ -186,3 +232,6 @@ def call(
if cache is not None:
return attention_output, cache
return attention_output

def compute_output_shape(self, input_shape):
return input_shape
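
A minimal NumPy sketch (not part of the PR; names are illustrative) of the two new attention tricks above: logit soft-capping via a scaled tanh, and the band-shaped sliding-window mask built with triu/tril on non-TensorFlow backends.

import numpy as np

def soft_cap(logits, cap):
    # divide -> tanh -> multiply: bounds logits to (-cap, cap)
    # while preserving their ordering.
    return np.tanh(logits / cap) * cap

def sliding_window_mask(seq_len, window_size):
    # Position j is visible from position i only when |i - j| < window_size,
    # mirroring the triu/tril construction in _compute_attention above.
    all_ones = np.ones((seq_len, seq_len), dtype=bool)
    return np.triu(all_ones, k=-(window_size - 1)) & np.tril(
        all_ones, k=window_size - 1
    )

print(soft_cap(np.array([-100.0, 0.0, 100.0]), 50.0))  # ~[-48.2, 0., 48.2]
print(sliding_window_mask(6, 3).astype(int))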
43 changes: 43 additions & 0 deletions keras_nlp/src/models/gemma/gemma_backbone.py
@@ -54,6 +54,21 @@ class GemmaBackbone(Backbone):
layer_norm_epsilon: float. The epsilon value used for every layer norm
in the transformer model.
dropout: float. Dropout probability for the Transformer encoder.
query_head_dim_normalize: boolean. Whether to scale the query by
`1 / sqrt(head_dim)` (True) or by `1 / sqrt(hidden_dim // num_query_heads)`
(False). Gemma 2 uses the second option. Defaults to True.
use_post_ffw_norm: boolean. Whether to normalize after the feedforward
block. Defaults to False.
use_post_attention_norm: boolean. Whether to normalize after the attention
block. Defaults to False.
attention_logit_soft_cap: None or int. Soft cap for the attention logits.
Defaults to None.
final_logit_soft_cap: None or int. Soft cap for the final logits.
Defaults to None.
use_sliding_window_attention: boolean. Whether to use sliding local
window attention. Defaults to False.
sliding_window_size: int. Size of the sliding local window. Defaults to
4096.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for the model's computations and weights. Note that some
computations, such as softmax and layer normalization, will always
@@ -93,6 +108,13 @@ def __init__(
hidden_dim,
intermediate_dim,
head_dim,
query_head_dim_normalize=True,
use_post_ffw_norm=False,
use_post_attention_norm=False,
attention_logit_soft_cap=None,
final_logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=4096,
layer_norm_epsilon=1e-6,
dropout=0,
dtype=None,
@@ -114,12 +136,19 @@
)
self.transformer_layers = []
for i in range(num_layers):
sliding_window = use_sliding_window_attention and (i % 2 == 0)
layer = GemmaDecoderBlock(
intermediate_dim=intermediate_dim,
hidden_dim=hidden_dim,
num_query_heads=num_query_heads,
head_dim=head_dim,
num_key_value_heads=num_key_value_heads,
query_head_dim_normalize=query_head_dim_normalize,
use_post_ffw_norm=use_post_ffw_norm,
use_post_attention_norm=use_post_attention_norm,
logit_soft_cap=attention_logit_soft_cap,
use_sliding_window_attention=sliding_window,
sliding_window_size=sliding_window_size,
dropout=dropout,
dtype=dtype,
name=f"decoder_block_{i}",
@@ -163,6 +192,13 @@ def __init__(
self.head_dim = head_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout
self.query_head_dim_normalize = query_head_dim_normalize
self.use_post_ffw_norm = use_post_ffw_norm
self.use_post_attention_norm = use_post_attention_norm
self.attention_logit_soft_cap = attention_logit_soft_cap
self.final_logit_soft_cap = final_logit_soft_cap
self.sliding_window_size = sliding_window_size
self.use_sliding_window_attention = use_sliding_window_attention

def get_config(self):
config = super().get_config()
@@ -177,6 +213,13 @@ def get_config(self):
"head_dim": self.head_dim,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
"query_head_dim_normalize": self.query_head_dim_normalize,
"use_post_ffw_norm": self.use_post_ffw_norm,
"use_post_attention_norm": self.use_post_attention_norm,
"final_logit_soft_cap": self.final_logit_soft_cap,
"attention_logit_soft_cap": self.attention_logit_soft_cap,
"sliding_window_size": self.sliding_window_size,
"use_sliding_window_attention": self.use_sliding_window_attention,
}
)
return config
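
For context, a hedged usage sketch (not from the PR) showing how the new flags are wired up together, assuming a keras_nlp install that includes these changes. The tiny sizes mirror the Gemma2BackboneTest configuration below; real Gemma 2 presets use much larger values. Note that with use_sliding_window_attention=True, only even-indexed decoder blocks (the `i % 2 == 0` check above) use the local window.

from keras import ops
from keras_nlp.models import GemmaBackbone

# Tiny Gemma 2 style backbone; sizes mirror the test configuration below.
backbone = GemmaBackbone(
    vocabulary_size=20,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=16,
    intermediate_dim=32,
    head_dim=4,
    query_head_dim_normalize=False,
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    attention_logit_soft_cap=50,
    final_logit_soft_cap=30,
    use_sliding_window_attention=True,
    sliding_window_size=5,
)
outputs = backbone(
    {
        "token_ids": ops.ones((2, 10), dtype="int32"),
        "padding_mask": ops.ones((2, 10), dtype="int32"),
    }
)
print(outputs.shape)  # (2, 10, hidden_dim) == (2, 10, 16)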
58 changes: 50 additions & 8 deletions keras_nlp/src/models/gemma/gemma_backbone_test.py
@@ -22,13 +22,13 @@
class GemmaBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"vocabulary_size": 256128,
"vocabulary_size": 20,
"num_layers": 2,
"num_query_heads": 8,
"num_key_value_heads": 8,
"hidden_dim": 128,
"intermediate_dim": 256,
"head_dim": 128,
"num_query_heads": 4,
"num_key_value_heads": 1,
"hidden_dim": 16,
"intermediate_dim": 32,
"head_dim": 4,
"layer_norm_epsilon": 1e-6,
}
self.input_data = {
@@ -41,7 +41,7 @@ def test_backbone_basics(self):
cls=GemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 128),
expected_output_shape=(2, 5, 16),
)

@pytest.mark.large
@@ -82,7 +82,7 @@ def test_all_presets(self):

def test_architecture_characteristics(self):
model = GemmaBackbone(**self.init_kwargs)
self.assertEqual(model.count_params(), 33931904)
self.assertEqual(model.count_params(), 3216)
self.assertEqual(len(model.layers), 6)

def test_distribution(self):
@@ -169,3 +169,45 @@ def test_distribution_with_lora(self):
)
if "attention/value/lora_kernel_b" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, None))


@pytest.mark.keras_3_only
class Gemma2BackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"vocabulary_size": 20, # 256128
"num_layers": 2, # 46
"num_query_heads": 4, # 32
"num_key_value_heads": 2, # 16
"hidden_dim": 16, # 4608
"intermediate_dim": 32, # 73728
"head_dim": 4, # 128
"sliding_window_size": 5, # 4096
"attention_logit_soft_cap": 50,
"final_logit_soft_cap": 30,
"layer_norm_epsilon": 1e-6,
"query_head_dim_normalize": False,
"use_post_ffw_norm": True,
"use_post_attention_norm": True,
"use_sliding_window_attention": True,
}
self.input_data = {
"token_ids": ops.ones((2, 10), dtype="int32"),
"padding_mask": ops.ones((2, 10), dtype="int32"),
}

def test_backbone_basics(self):
self.run_backbone_test(
cls=GemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 10, 16),
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=GemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
8 changes: 8 additions & 0 deletions keras_nlp/src/models/gemma/gemma_causal_lm.py
@@ -223,9 +223,17 @@ def call_with_cache(
cache_update_index=cache_update_index,
)
caches.append(next_cache)

cache = ops.stack(caches, axis=1)
hidden_states = x = self.backbone.layer_norm(x)
logits = self.backbone.token_embedding(x, reverse=True)

if self.backbone.final_logit_soft_cap is not None:
logits = ops.divide(logits, self.backbone.final_logit_soft_cap)
logits = ops.multiply(
ops.tanh(logits), self.backbone.final_logit_soft_cap
)

return logits, hidden_states, cache

def _build_cache(self, token_ids):
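
The final-logit cap is the same divide -> tanh -> multiply pattern applied to the LM head output. A small backend-agnostic sketch (illustrative, not part of the PR):

from keras import ops

def cap_final_logits(logits, cap=30.0):
    # Squash logits into (-cap, cap) before softmax/sampling; the ordering
    # of the logits (and hence the argmax) is unchanged.
    return ops.multiply(ops.tanh(ops.divide(logits, cap)), cap)

print(cap_final_logits(ops.convert_to_tensor([[-120.0, 0.0, 45.0]])))
# values bounded to (-30, 30), roughly [[-30., 0., 27.2]]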
51 changes: 51 additions & 0 deletions keras_nlp/src/models/gemma/gemma_decoder_block.py
@@ -32,6 +32,12 @@ def __init__(
head_dim,
num_query_heads,
num_key_value_heads,
query_head_dim_normalize=True,
use_post_ffw_norm=False,
use_post_attention_norm=False,
logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=4096,
layer_norm_epsilon=1e-6,
dropout=0,
**kwargs,
@@ -45,17 +51,34 @@
self.head_dim = head_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout
self.query_head_dim_normalize = query_head_dim_normalize
self.use_post_ffw_norm = use_post_ffw_norm
self.use_post_attention_norm = use_post_attention_norm
self.logit_soft_cap = logit_soft_cap
self.use_sliding_window_attention = use_sliding_window_attention
self.sliding_window_size = sliding_window_size

self.pre_attention_norm = RMSNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="pre_attention_norm",
)

if use_post_attention_norm:
self.post_attention_norm = RMSNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="pre_attention_norm",
)

self.attention = CachedGemmaAttention(
head_dim=head_dim,
num_query_heads=num_query_heads,
num_key_value_heads=num_key_value_heads,
logit_soft_cap=logit_soft_cap,
use_sliding_window_attention=use_sliding_window_attention,
sliding_window_size=sliding_window_size,
query_head_dim_normalize=query_head_dim_normalize,
dropout=dropout,
dtype=self.dtype_policy,
name="attention",
Expand All @@ -71,6 +94,13 @@ def __init__(
name="pre_ffw_norm",
)

if use_post_ffw_norm:
self.post_ffw_norm = RMSNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="post_ffw_norm",
)

self.gating_ffw = keras.layers.EinsumDense(
equation="btd,df->btf",
output_shape=(None, self.intermediate_dim // 2),
@@ -96,13 +126,22 @@ def build(self, input_shape):
self.pre_attention_norm.build(input_shape)
self.attention.build(input_shape)

if self.use_post_attention_norm:
shape = self.attention.compute_output_shape(input_shape)
self.post_attention_norm.build(shape)

shape = input_shape
self.pre_ffw_norm.build(shape)
self.gating_ffw.build(shape)
self.gating_ffw_2.build(shape)

shape = self.gating_ffw.compute_output_shape(shape)
self.ffw_linear.build(shape)

if self.use_post_ffw_norm:
shape = self.ffw_linear.compute_output_shape(shape)
self.post_ffw_norm.build(shape)

self.built = True

def compute_output_shape(self, input_shape):
@@ -157,6 +196,9 @@ def call(
attention_mask=attention_mask,
)

if self.use_post_attention_norm:
attention = self.post_attention_norm(attention)

if self.dropout:
attention = self.attention_dropout(attention)

@@ -168,6 +210,9 @@
x = keras.activations.gelu(x1, approximate=True) * x2
x = self.ffw_linear(x)

if self.use_post_ffw_norm:
x = self.post_ffw_norm(x)

x = x + attention_x

if cache is not None:
@@ -185,6 +230,12 @@ def get_config(self):
"num_key_value_heads": self.num_key_value_heads,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
"use_post_ffw_norm": self.use_post_ffw_norm,
"use_post_attention_norm": self.use_post_attention_norm,
"logit_soft_cap": self.logit_soft_cap,
"use_sliding_window_attention": self.use_sliding_window_attention,
"sliding_window_size": self.sliding_window_size,
"query_head_dim_normalize": self.query_head_dim_normalize,
}
)
return config
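
To summarize the residual ordering with the new optional norms enabled, here is a schematic stand-in (plain NumPy with identity callables, purely illustrative; the real block also handles caching, dropout, and the gated feedforward):

import numpy as np

def gemma2_block_sketch(x, attn, ffw, pre_attn_norm, post_attn_norm,
                        pre_ffw_norm, post_ffw_norm):
    # Residual 1: pre-norm -> attention -> (new) post-attention norm.
    x = x + post_attn_norm(attn(pre_attn_norm(x)))
    # Residual 2: pre-norm -> feedforward -> (new) post-ffw norm.
    x = x + post_ffw_norm(ffw(pre_ffw_norm(x)))
    return x

identity = lambda t: t
out = gemma2_block_sketch(
    np.ones((2, 4, 8), dtype="float32"),
    attn=identity,
    ffw=identity,
    pre_attn_norm=identity,
    post_attn_norm=identity,
    pre_ffw_norm=identity,
    post_ffw_norm=identity,
)
print(out.shape)  # (2, 4, 8)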