From 0ddc826b34dcf6e424004a2d1f2ddbc9a3548336 Mon Sep 17 00:00:00 2001
From: Gonzalo <456459+grzuy@users.noreply.github.com>
Date: Mon, 16 Oct 2023 16:16:29 -0300
Subject: [PATCH 1/2] Support tie_word_embeddings

---
 lib/bumblebee/text/albert.ex | 11 +++++++++--
 lib/bumblebee/text/bart.ex   |  8 +++++---
 lib/bumblebee/text/bert.ex   | 11 +++++++++--
 3 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/lib/bumblebee/text/albert.ex b/lib/bumblebee/text/albert.ex
index b5969b81..2e6865e6 100644
--- a/lib/bumblebee/text/albert.ex
+++ b/lib/bumblebee/text/albert.ex
@@ -496,7 +496,14 @@ defmodule Bumblebee.Text.Albert do
   end
 
   defimpl Bumblebee.HuggingFace.Transformers.Model do
-    def params_mapping(_spec) do
+    def params_mapping(spec) do
+      language_modeling_head_output =
+        if Map.get(spec, :tie_word_embeddings, true) do
+          "albert.embeddings.word_embeddings"
+        else
+          "predictions.decoder"
+        end
+
       %{
         "embedder.token_embedding" => "albert.embeddings.word_embeddings",
         "embedder.position_embedding" => "albert.embeddings.position_embeddings",
@@ -522,7 +529,7 @@ defmodule Bumblebee.Text.Albert do
         "pooler.output" => "albert.pooler",
         "language_modeling_head.dense" => "predictions.dense",
         "language_modeling_head.norm" => "predictions.LayerNorm",
-        "language_modeling_head.output" => "predictions.decoder",
+        "language_modeling_head.output" => language_modeling_head_output,
         "sequence_classification_head.output" => "classifier",
         "token_classification_head.output" => "classifier",
         "multiple_choice_head.output" => "classifier",
diff --git a/lib/bumblebee/text/bart.ex b/lib/bumblebee/text/bart.ex
index 69434455..1b45986c 100644
--- a/lib/bumblebee/text/bart.ex
+++ b/lib/bumblebee/text/bart.ex
@@ -652,9 +652,11 @@ defmodule Bumblebee.Text.Bart do
   end
 
   defimpl Bumblebee.HuggingFace.Transformers.Model do
-    def params_mapping(_spec) do
+    def params_mapping(spec) do
+      tie_word_embeddings = Map.get(spec, :tie_word_embeddings, true)
+
       %{
-        "encoder_embedder.token_embedding" => "model.encoder.embed_tokens",
+        "encoder_embedder.token_embedding" => (if tie_word_embeddings, do: "model.shared", else: "model.encoder.embed_tokens"),
         "encoder_embedder.position_embedding" => "model.encoder.embed_positions",
         "encoder_embedder.norm" => "model.encoder.layernorm_embedding",
         "encoder.blocks.{n}.self_attention.query" => "model.encoder.layers.{n}.self_attn.q_proj",
@@ -667,7 +669,7 @@ defmodule Bumblebee.Text.Bart do
         "encoder.blocks.{n}.ffn.intermediate" => "model.encoder.layers.{n}.fc1",
         "encoder.blocks.{n}.ffn.output" => "model.encoder.layers.{n}.fc2",
         "encoder.blocks.{n}.output_norm" => "model.encoder.layers.{n}.final_layer_norm",
-        "decoder_embedder.token_embedding" => "model.decoder.embed_tokens",
+        "decoder_embedder.token_embedding" => (if tie_word_embeddings, do: "model.shared", else: "model.decoder.embed_tokens"),
         "decoder_embedder.position_embedding" => "model.decoder.embed_positions",
         "decoder_embedder.norm" => "model.decoder.layernorm_embedding",
         "decoder.blocks.{n}.self_attention.query" => "model.decoder.layers.{n}.self_attn.q_proj",
diff --git a/lib/bumblebee/text/bert.ex b/lib/bumblebee/text/bert.ex
index 2823a2eb..fe295a23 100644
--- a/lib/bumblebee/text/bert.ex
+++ b/lib/bumblebee/text/bert.ex
@@ -624,7 +624,14 @@ defmodule Bumblebee.Text.Bert do
   end
 
   defimpl Bumblebee.HuggingFace.Transformers.Model do
-    def params_mapping(_spec) do
+    def params_mapping(spec) do
+      language_modeling_head_output =
+        if Map.get(spec, :tie_word_embeddings, true) do
+          "bert.embeddings.word_embeddings"
+        else
+          "cls.predictions.decoder"
+        end
+
       %{
         "embedder.token_embedding" => "bert.embeddings.word_embeddings",
         "embedder.position_embedding" => "bert.embeddings.position_embeddings",
@@ -655,7 +662,7 @@ defmodule Bumblebee.Text.Bert do
         "pooler.output" => "bert.pooler.dense",
         "language_modeling_head.dense" => "cls.predictions.transform.dense",
         "language_modeling_head.norm" => "cls.predictions.transform.LayerNorm",
-        "language_modeling_head.output" => "cls.predictions.decoder",
+        "language_modeling_head.output" => language_modeling_head_output,
         "language_modeling_head.bias" => "cls.predictions",
         "next_sentence_prediction_head.output" => "cls.seq_relationship",
         "sequence_classification_head.output" => "classifier",

From 3b2f7a02af7c0418f20e9fcd851d1e5bae3c3205 Mon Sep 17 00:00:00 2001
From: Gonzalo <456459+grzuy@users.noreply.github.com>
Date: Mon, 16 Oct 2023 17:17:46 -0300
Subject: [PATCH 2/2] mix format

---
 lib/bumblebee/text/bart.ex | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/lib/bumblebee/text/bart.ex b/lib/bumblebee/text/bart.ex
index 1b45986c..1aacfd77 100644
--- a/lib/bumblebee/text/bart.ex
+++ b/lib/bumblebee/text/bart.ex
@@ -656,7 +656,8 @@ defmodule Bumblebee.Text.Bart do
       tie_word_embeddings = Map.get(spec, :tie_word_embeddings, true)
 
       %{
-        "encoder_embedder.token_embedding" => (if tie_word_embeddings, do: "model.shared", else: "model.encoder.embed_tokens"),
+        "encoder_embedder.token_embedding" =>
+          if(tie_word_embeddings, do: "model.shared", else: "model.encoder.embed_tokens"),
         "encoder_embedder.position_embedding" => "model.encoder.embed_positions",
         "encoder_embedder.norm" => "model.encoder.layernorm_embedding",
         "encoder.blocks.{n}.self_attention.query" => "model.encoder.layers.{n}.self_attn.q_proj",
@@ -669,7 +670,8 @@ defmodule Bumblebee.Text.Bart do
         "encoder.blocks.{n}.ffn.intermediate" => "model.encoder.layers.{n}.fc1",
         "encoder.blocks.{n}.ffn.output" => "model.encoder.layers.{n}.fc2",
         "encoder.blocks.{n}.output_norm" => "model.encoder.layers.{n}.final_layer_norm",
-        "decoder_embedder.token_embedding" => (if tie_word_embeddings, do: "model.shared", else: "model.decoder.embed_tokens"),
+        "decoder_embedder.token_embedding" =>
+          if(tie_word_embeddings, do: "model.shared", else: "model.decoder.embed_tokens"),
         "decoder_embedder.position_embedding" => "model.decoder.embed_positions",
         "decoder_embedder.norm" => "model.decoder.layernorm_embedding",
         "decoder.blocks.{n}.self_attention.query" => "model.decoder.layers.{n}.self_attn.q_proj",
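
Note following the series: a minimal standalone sketch of the selection logic these patches introduce, mirroring the BERT branch of params_mapping/1. It uses a bare map in place of the real spec struct, and the module name is hypothetical; only the parameter names and the default value come from the diffs above.

  defmodule TieWordEmbeddingsSketch do
    # Mirrors the new conditional in Bumblebee.Text.Bert's params_mapping/1:
    # when embeddings are tied (the default), the language modeling head output
    # is sourced from the token embedding weights; otherwise it is loaded from
    # the checkpoint's separate decoder layer.
    def language_modeling_head_output(spec) do
      if Map.get(spec, :tie_word_embeddings, true) do
        "bert.embeddings.word_embeddings"
      else
        "cls.predictions.decoder"
      end
    end
  end

  # Usage:
  # TieWordEmbeddingsSketch.language_modeling_head_output(%{})
  # #=> "bert.embeddings.word_embeddings"
  # TieWordEmbeddingsSketch.language_modeling_head_output(%{tie_word_embeddings: false})
  # #=> "cls.predictions.decoder"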