Add Gemma 2 model (#1673)
* Add Gemma2 to Keras (#91)

Add Gemma2 building blocks and presets.

---------

Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com>

* Set presets to one

* Remove extra preset test

* Pin Keras version to 3.3.3

---------

Co-authored-by: Matt Watson <1389937+mattdangerw@users.noreply.github.com>
grasskin and mattdangerw authored Jun 27, 2024
1 parent c459519 commit da3e4ab
Showing 9 changed files with 245 additions and 9 deletions.
51 changes: 50 additions & 1 deletion keras_nlp/src/models/gemma/gemma_attention.py
@@ -28,19 +28,28 @@ def __init__(
num_query_heads,
num_key_value_heads,
kernel_initializer="glorot_uniform",
logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=4096,
query_head_dim_normalize=True,
dropout=0,
**kwargs,
):
super().__init__(**kwargs)
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.logit_soft_cap = logit_soft_cap
self.use_sliding_window_attention = use_sliding_window_attention
self.sliding_window_size = sliding_window_size
self.query_head_dim_normalize = query_head_dim_normalize
self.dropout = dropout

self._kernel_initializer = keras.initializers.get(
clone_initializer(kernel_initializer)
)
self.num_key_value_groups = num_query_heads // num_key_value_heads
self.query_head_dim_normalize = query_head_dim_normalize

def build(self, inputs_shape):
self.hidden_dim = inputs_shape[-1]
@@ -114,7 +123,12 @@ def _compute_attention(
attention_mask,
training=False,
):
query_normalization = 1 / np.sqrt(self.head_dim)
if self.query_head_dim_normalize:
query_normalization = 1 / np.sqrt(self.head_dim)
else:
query_normalization = 1 / np.sqrt(
self.hidden_dim // self.num_query_heads
)

q *= ops.cast(query_normalization, dtype=q.dtype)
q_shape = ops.shape(q)
@@ -130,6 +144,38 @@
b, q_len, _, _, h = ops.shape(q)

attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)

if self.logit_soft_cap is not None:
attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
attention_logits = ops.multiply(
ops.tanh(attention_logits), self.logit_soft_cap
)

if self.use_sliding_window_attention:
all_ones = ops.ones_like(attention_mask)
if keras.config.backend() == "tensorflow":
import tensorflow as tf

sliding_window_size = ops.minimum(
self.sliding_window_size - 1, q_len
)
sliding_window_size = ops.cast(
sliding_window_size, dtype="int32"
)
sliding_mask = tf.linalg.band_part(
all_ones, sliding_window_size - 1, sliding_window_size - 1
)
sliding_mask = ops.cast(sliding_mask, dtype="bool")
bool_attention_mask = ops.cast(attention_mask, dtype="bool")
attention_mask = tf.math.logical_and(
sliding_mask, bool_attention_mask
)
else:
sliding_mask = ops.triu(
all_ones, -1 * self.sliding_window_size + 1
) * ops.tril(all_ones, self.sliding_window_size - 1)
attention_mask = sliding_mask * attention_mask

attention_mask = attention_mask[:, None, None, :, :]
orig_dtype = attention_logits.dtype
attention_softmax = self.softmax(attention_logits, mask=attention_mask)
@@ -186,3 +232,6 @@ def call(
if cache is not None:
return attention_output, cache
return attention_output

def compute_output_shape(self, input_shape):
return input_shape
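
For orientation, below is a minimal, backend-agnostic sketch of the two attention changes in this file: tanh soft-capping of the attention logits and the banded sliding-window mask used on the non-TensorFlow path. It relies only on Keras 3's `keras.ops`; the cap and window values are illustrative, not preset defaults.

import numpy as np
from keras import ops

def soft_cap(logits, cap=50.0):
    # Divide -> tanh -> multiply, mirroring _compute_attention above: keeps
    # logits roughly linear near zero while bounding them to (-cap, cap).
    return ops.multiply(ops.tanh(ops.divide(logits, cap)), cap)

def sliding_window_mask(seq_len, window_size=4):
    # Band matrix allowing position i to attend only to positions within
    # window_size - 1 steps of i; in the layer this is combined with the
    # causal/padding attention mask.
    all_ones = ops.ones((seq_len, seq_len))
    return ops.triu(all_ones, -window_size + 1) * ops.tril(all_ones, window_size - 1)

logits = ops.convert_to_tensor(np.random.randn(8, 8).astype("float32") * 100.0)
capped = soft_cap(logits)                     # values now bounded to (-50, 50)
mask = sliding_window_mask(8, window_size=4)  # ones on a band of width 2 * 4 - 1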
43 changes: 43 additions & 0 deletions keras_nlp/src/models/gemma/gemma_backbone.py
@@ -54,6 +54,21 @@ class GemmaBackbone(Backbone):
layer_norm_epsilon: float. The epsilon value used for every layer norm
in the transformer model.
dropout: float. Dropout probability for the Transformer encoder.
query_head_dim_normalize: boolean. Whether to normalize attention with
head dimension or hidden_dim/num_query_heads. Gemma2 uses the
second option. Defaults to True.
use_post_ffw_norm: boolean. Whether to normalize after the feedforward
block. Defaults to False.
use_post_attention_norm: boolean. Whether to normalize after the attention
block. Defaults to False.
attention_logit_soft_cap: None or int. Soft cap for the attention logits.
Defaults to None.
final_logit_soft_cap: None or int. Soft cap for the final logits.
Defaults to None.
use_sliding_window_attention: boolean. Whether to use sliding local
window attention. Defaults to False.
sliding_window_size: int. Size of the sliding local window. Defaults to
4096.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for the model's computations and weights. Note that some
computations, such as softmax and layer normalization, will always
@@ -93,6 +108,13 @@ def __init__(
hidden_dim,
intermediate_dim,
head_dim,
query_head_dim_normalize=True,
use_post_ffw_norm=False,
use_post_attention_norm=False,
attention_logit_soft_cap=None,
final_logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=4096,
layer_norm_epsilon=1e-6,
dropout=0,
dtype=None,
@@ -114,12 +136,19 @@ def __init__(
)
self.transformer_layers = []
for i in range(num_layers):
sliding_window = use_sliding_window_attention and (i % 2 == 0)
layer = GemmaDecoderBlock(
intermediate_dim=intermediate_dim,
hidden_dim=hidden_dim,
num_query_heads=num_query_heads,
head_dim=head_dim,
num_key_value_heads=num_key_value_heads,
query_head_dim_normalize=query_head_dim_normalize,
use_post_ffw_norm=use_post_ffw_norm,
use_post_attention_norm=use_post_attention_norm,
logit_soft_cap=attention_logit_soft_cap,
use_sliding_window_attention=sliding_window,
sliding_window_size=sliding_window_size,
dropout=dropout,
dtype=dtype,
name=f"decoder_block_{i}",
@@ -163,6 +192,13 @@ def __init__(
self.head_dim = head_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout
self.query_head_dim_normalize = query_head_dim_normalize
self.use_post_ffw_norm = use_post_ffw_norm
self.use_post_attention_norm = use_post_attention_norm
self.attention_logit_soft_cap = attention_logit_soft_cap
self.final_logit_soft_cap = final_logit_soft_cap
self.sliding_window_size = sliding_window_size
self.use_sliding_window_attention = use_sliding_window_attention

def get_config(self):
config = super().get_config()
@@ -177,6 +213,13 @@ def get_config(self):
"head_dim": self.head_dim,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
"query_head_dim_normalize": self.query_head_dim_normalize,
"use_post_ffw_norm": self.use_post_ffw_norm,
"use_post_attention_norm": self.use_post_attention_norm,
"final_logit_soft_cap": self.final_logit_soft_cap,
"attention_logit_soft_cap": self.attention_logit_soft_cap,
"sliding_window_size": self.sliding_window_size,
"use_sliding_window_attention": self.use_sliding_window_attention,
}
)
return config
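
As a usage sketch of the new arguments documented above, the snippet below builds a tiny backbone with the Gemma 2 options enabled. The values mirror the small Gemma2BackboneTest configuration later in this commit rather than a real preset, and the in-repo import path is assumed.

import numpy as np
from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone

backbone = GemmaBackbone(
    vocabulary_size=20,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=16,
    intermediate_dim=32,
    head_dim=4,
    query_head_dim_normalize=False,     # Gemma 2 scales by hidden_dim / num_query_heads
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    attention_logit_soft_cap=50,
    final_logit_soft_cap=30,
    use_sliding_window_attention=True,  # applied on alternating layers (i % 2 == 0)
    sliding_window_size=5,
)
outputs = backbone(
    {
        "token_ids": np.ones((2, 10), dtype="int32"),
        "padding_mask": np.ones((2, 10), dtype="int32"),
    }
)
print(outputs.shape)  # (2, 10, 16): (batch, sequence, hidden_dim)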
58 changes: 50 additions & 8 deletions keras_nlp/src/models/gemma/gemma_backbone_test.py
@@ -22,13 +22,13 @@
class GemmaBackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"vocabulary_size": 256128,
"vocabulary_size": 20,
"num_layers": 2,
"num_query_heads": 8,
"num_key_value_heads": 8,
"hidden_dim": 128,
"intermediate_dim": 256,
"head_dim": 128,
"num_query_heads": 4,
"num_key_value_heads": 1,
"hidden_dim": 16,
"intermediate_dim": 32,
"head_dim": 4,
"layer_norm_epsilon": 1e-6,
}
self.input_data = {
@@ -41,7 +41,7 @@ def test_backbone_basics(self):
cls=GemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 128),
expected_output_shape=(2, 5, 16),
)

@pytest.mark.large
@@ -82,7 +82,7 @@ def test_all_presets(self):

def test_architecture_characteristics(self):
model = GemmaBackbone(**self.init_kwargs)
self.assertEqual(model.count_params(), 33931904)
self.assertEqual(model.count_params(), 3216)
self.assertEqual(len(model.layers), 6)

def test_distribution(self):
@@ -169,3 +169,45 @@ def test_distribution_with_lora(self):
)
if "attention/value/lora_kernel_b" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, None))


@pytest.mark.keras_3_only
class Gemma2BackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
"vocabulary_size": 20, # 256128
"num_layers": 2, # 46
"num_query_heads": 4, # 32
"num_key_value_heads": 2, # 16
"hidden_dim": 16, # 4608
"intermediate_dim": 32, # 73728
"head_dim": 4, # 128
"sliding_window_size": 5, # 4096
"attention_logit_soft_cap": 50,
"final_logit_soft_cap": 30,
"layer_norm_epsilon": 1e-6,
"query_head_dim_normalize": False,
"use_post_ffw_norm": True,
"use_post_attention_norm": True,
"use_sliding_window_attention": True,
}
self.input_data = {
"token_ids": ops.ones((2, 10), dtype="int32"),
"padding_mask": ops.ones((2, 10), dtype="int32"),
}

def test_backbone_basics(self):
self.run_backbone_test(
cls=GemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 10, 16),
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=GemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)
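
Assuming a standard pytest setup, the new Gemma 2 cases in this file can be exercised in isolation with a keyword filter; the @pytest.mark.large cases may additionally require the repository's large-test flag to be enabled.

import pytest

# Run only the Gemma2Backbone tests from this file (a sketch, not CI configuration).
pytest.main(
    ["keras_nlp/src/models/gemma/gemma_backbone_test.py", "-k", "Gemma2Backbone"]
)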
8 changes: 8 additions & 0 deletions keras_nlp/src/models/gemma/gemma_causal_lm.py
@@ -223,9 +223,17 @@ def call_with_cache(
cache_update_index=cache_update_index,
)
caches.append(next_cache)

cache = ops.stack(caches, axis=1)
hidden_states = x = self.backbone.layer_norm(x)
logits = self.backbone.token_embedding(x, reverse=True)

if self.backbone.final_logit_soft_cap is not None:
logits = ops.divide(logits, self.backbone.final_logit_soft_cap)
logits = ops.multiply(
ops.tanh(logits), self.backbone.final_logit_soft_cap
)

return logits, hidden_states, cache

def _build_cache(self, token_ids):
51 changes: 51 additions & 0 deletions keras_nlp/src/models/gemma/gemma_decoder_block.py
@@ -32,6 +32,12 @@ def __init__(
head_dim,
num_query_heads,
num_key_value_heads,
query_head_dim_normalize=True,
use_post_ffw_norm=False,
use_post_attention_norm=False,
logit_soft_cap=None,
use_sliding_window_attention=False,
sliding_window_size=4096,
layer_norm_epsilon=1e-6,
dropout=0,
**kwargs,
@@ -45,17 +51,34 @@ def __init__(
self.head_dim = head_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout
self.query_head_dim_normalize = query_head_dim_normalize
self.use_post_ffw_norm = use_post_ffw_norm
self.use_post_attention_norm = use_post_attention_norm
self.logit_soft_cap = logit_soft_cap
self.use_sliding_window_attention = use_sliding_window_attention
self.sliding_window_size = sliding_window_size

self.pre_attention_norm = RMSNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="pre_attention_norm",
)

if use_post_attention_norm:
self.post_attention_norm = RMSNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="pre_attention_norm",
)

self.attention = CachedGemmaAttention(
head_dim=head_dim,
num_query_heads=num_query_heads,
num_key_value_heads=num_key_value_heads,
logit_soft_cap=logit_soft_cap,
use_sliding_window_attention=use_sliding_window_attention,
sliding_window_size=sliding_window_size,
query_head_dim_normalize=query_head_dim_normalize,
dropout=dropout,
dtype=self.dtype_policy,
name="attention",
@@ -71,6 +94,13 @@ def __init__(
name="pre_ffw_norm",
)

if use_post_ffw_norm:
self.post_ffw_norm = RMSNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="post_ffw_norm",
)

self.gating_ffw = keras.layers.EinsumDense(
equation="btd,df->btf",
output_shape=(None, self.intermediate_dim // 2),
@@ -96,13 +126,22 @@ def build(self, input_shape):
self.pre_attention_norm.build(input_shape)
self.attention.build(input_shape)

if self.use_post_attention_norm:
shape = self.attention.compute_output_shape(input_shape)
self.post_attention_norm.build(shape)

shape = input_shape
self.pre_ffw_norm.build(shape)
self.gating_ffw.build(shape)
self.gating_ffw_2.build(shape)

shape = self.gating_ffw.compute_output_shape(shape)
self.ffw_linear.build(shape)

if self.use_post_ffw_norm:
shape = self.ffw_linear.compute_output_shape(shape)
self.post_ffw_norm.build(shape)

self.built = True

def compute_output_shape(self, input_shape):
@@ -157,6 +196,9 @@ def call(
attention_mask=attention_mask,
)

if self.use_post_attention_norm:
attention = self.post_attention_norm(attention)

if self.dropout:
attention = self.attention_dropout(attention)

@@ -168,6 +210,9 @@ def call(
x = keras.activations.gelu(x1, approximate=True) * x2
x = self.ffw_linear(x)

if self.use_post_ffw_norm:
x = self.post_ffw_norm(x)

x = x + attention_x

if cache is not None:
@@ -185,6 +230,12 @@ def get_config(self):
"num_key_value_heads": self.num_key_value_heads,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
"use_post_ffw_norm": self.use_post_ffw_norm,
"use_post_attention_norm": self.use_post_attention_norm,
"logit_soft_cap": self.logit_soft_cap,
"use_sliding_window_attention": self.use_sliding_window_attention,
"sliding_window_size": self.sliding_window_size,
"query_head_dim_normalize": self.query_head_dim_normalize,
}
)
return config
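
To summarize the control flow added to this block, here is a condensed sketch of the residual structure with the new optional post-norms. It is plain pseudocode over opaque callables, not the actual layer implementation; names and the function itself are hypothetical.

def gemma2_block(x, attn, ffw, pre_attention_norm, pre_ffw_norm,
                 post_attention_norm=None, post_ffw_norm=None):
    # Attention sub-block: pre-norm -> attention -> optional post-norm -> residual.
    residual = x
    h = attn(pre_attention_norm(x))
    if post_attention_norm is not None:  # new in Gemma 2
        h = post_attention_norm(h)
    x = residual + h

    # Feedforward sub-block: pre-norm -> gated FFW -> optional post-norm -> residual.
    residual = x
    h = ffw(pre_ffw_norm(x))
    if post_ffw_norm is not None:        # new in Gemma 2
        h = post_ffw_norm(h)
    return residual + h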