Skip to content

Commit

Permalink
adding a helper to guess the hidden layers of the model
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Nov 19, 2023
1 parent fcf673d commit c1c01d6
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 1 deletion.
33 changes: 32 additions & 1 deletion linear_relational/lib/layer_matching.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Callable, Union
import re
from collections import defaultdict
from typing import Callable, Iterable, Union

from torch import nn

Expand Down Expand Up @@ -59,3 +61,32 @@ def _layer_matcher_to_callable(
# for some reason mypy doesn't like directly returning the lambda without assigning to a var first
return matcher_callable
return layer_matcher


LAYER_GUESS_RE = r"^([^\d]+)\.([\d]+)(.*)$"


def guess_hidden_layer_matcher(model: nn.Module) -> str:
"""
Guess the hidden layer matcher for a given model. This is a best guess and may not always be correct.
"""
return _guess_hidden_layer_matcher_from_layers(dict(model.named_modules()).keys())


# broken into a separate function for easier testing
def _guess_hidden_layer_matcher_from_layers(layers: Iterable[str]) -> str:
counts_by_guess: dict[str, int] = defaultdict(int)
for layer in layers:
if re.match(LAYER_GUESS_RE, layer):
guess = re.sub(LAYER_GUESS_RE, r"\1.{num}\3", layer)
counts_by_guess[guess] += 1
if len(counts_by_guess) == 0:
raise ValueError(
"Could not guess hidden layer matcher, please provide a layer_matcher"
)

# score is higher for guesses that match more often, are and shorter in length
guess_scores = [
(guess, count + 1 / len(guess)) for guess, count in counts_by_guess.items()
]
return max(guess_scores, key=lambda x: x[1])[0]
92 changes: 92 additions & 0 deletions tests/lib/test_layer_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from transformers import GPT2LMHeadModel

from linear_relational.lib.layer_matching import (
_guess_hidden_layer_matcher_from_layers,
guess_hidden_layer_matcher,
)


def test_guess_hidden_layer_matcher(model: GPT2LMHeadModel) -> None:
    # GPT2 keeps its transformer blocks under transformer.h.<i>.
    expected_matcher = "transformer.h.{num}"
    assert guess_hidden_layer_matcher(model) == expected_matcher


def test_guess_hidden_layer_matcher_from_layers() -> None:
    # Three numbered "x.y.<i>" blocks plus non-numbered noise entries.
    layers = ["x.e", "x.lm_head"]
    for block_index in range(3):
        layers.append(f"x.y.{block_index}")
        layers.append(f"x.y.{block_index}.attn")
    assert _guess_hidden_layer_matcher_from_layers(layers) == "x.y.{num}"


def test_guess_hidden_layer_matcher_from_layers_guess_llama_matcher() -> None:
    # Every llama decoder block exposes the same set of submodules, so build
    # the four numbered blocks from one suffix table instead of a literal.
    per_block_suffixes = [
        "",
        ".self_attn",
        ".self_attn.q_proj",
        ".self_attn.k_proj",
        ".self_attn.v_proj",
        ".self_attn.o_proj",
        ".self_attn.rotary_emb",
        ".mlp",
        ".mlp.gate_proj",
        ".mlp.up_proj",
        ".mlp.down_proj",
        ".mlp.act_fn",
        ".input_layernorm",
        ".post_attention_layernorm",
    ]
    layers = ["", "model", "model.embed_tokens", "model.layers"]
    for block_index in range(4):
        layers.extend(
            f"model.layers.{block_index}{suffix}" for suffix in per_block_suffixes
        )
    layers.extend(["model.norm", "lm_head"])
    assert _guess_hidden_layer_matcher_from_layers(layers) == "model.layers.{num}"

0 comments on commit c1c01d6

Please sign in to comment.