Use idx->idx dict instead of full permutation matrix
cg123 committed Jan 4, 2024
1 parent 7ceeb07 commit f2a49e8
Showing 2 changed files with 56 additions and 24 deletions.
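The gist of the change: each model's token permutation was stored as a dense 0/1 matrix (a torch.IntTensor of shape [len(vocab_out), vocab_size]) and applied by matrix multiplication; it is now a plain dict mapping each output token index to the matching input index, with -1 marking tokens the model lacks. A minimal sketch of the two representations (toy vocabularies, not code from the commit):

import torch

vocab_out = {"<s>": 0, "hello": 1, "world": 2, "new_tok": 3}
model_vocab = {"<s>": 0, "world": 1, "hello": 2}

# Old: dense one-hot permutation matrix, len(vocab_out) x len(model_vocab).
dense = torch.zeros(len(vocab_out), len(model_vocab), dtype=torch.int32)
for tok, new_idx in vocab_out.items():
    if tok in model_vocab:
        dense[new_idx, model_vocab[tok]] = 1

# New: idx -> idx dict, one entry per output token, -1 marks "missing".
sparse = {new_idx: model_vocab.get(tok, -1) for tok, new_idx in vocab_out.items()}
assert sparse == {0: 0, 1: 2, 2: 1, 3: -1}

For a pair of 32k-token vocabularies the dense matrix alone is 32768 x 32768 int32 entries, roughly 4 GiB per model and almost entirely zeros; the dict holds one small entry per output token.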
24 changes: 14 additions & 10 deletions mergekit/merge_methods/tokenizer_permute.py
@@ -28,7 +28,7 @@ class TokenizerPermutationMerge(MergeMethod):
     def __call__(
         self,
         input_tensors: Dict[TensorReference, torch.Tensor],
-        embed_permutations: Dict[ModelReference, torch.IntTensor],
+        embed_permutations: Dict[ModelReference, Dict[int, int]],
         config: ConfigReader,
         **_kwargs,
     ) -> torch.Tensor:
@@ -47,27 +47,31 @@ def __call__(
             models.append(tr.model)
 
             x = input_tensors[tr]
-            p = embed_permutations[tr.model].to(dtype=x.dtype, device=x.device)
-            temp_dtype = torch.float32 if x.device.type == "cpu" else x.dtype
-            if p.shape[1] == x.shape[0]:
-                xp = (p.to(dtype=temp_dtype) @ x.to(dtype=temp_dtype)).to(x.dtype)
-            else:
-                raise RuntimeError("Shape mismatch")
+            p = embed_permutations[tr.model]
+
+            xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype, device=x.device)
+            mask = torch.zeros((len(p),), dtype=torch.bool, device=x.device)
+            for out_idx in p:
+                in_idx = p[out_idx]
+                if in_idx < 0:
+                    continue
+
+                xp[out_idx, :] = x[in_idx, :]
+                mask[out_idx] = 1
 
             expanded.append(xp)
-            masks.append(p.sum(dim=-1, keepdim=True) > 0)
+            masks.append(mask)
 
             is_base = tr.model == config.base_model
             if use_slerp:
                 t = config.parameter("t", required=True)
                 weight = (1.0 - t) if is_base else t
             else:
                 weight = config.parameter("weight", model=tr.model, default=1.0)
 
             weights.append(weight)
 
         expanded = torch.stack(expanded, dim=0)
-        masks = torch.stack(masks, dim=0)
+        masks = torch.stack(masks, dim=0).unsqueeze(-1)
         weights = (
             torch.tensor(weights, dtype=expanded.dtype, device=expanded.device)
             .unsqueeze(-1)
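With the dict representation, the dense matmul in tokenizer_permute.py becomes a row-by-row copy, and token validity is tracked in an explicit boolean mask instead of being recovered from the matrix's row sums (which is also why the .unsqueeze(-1) moved to the torch.stack site). A standalone sketch of the new expansion, with toy shapes and assumed values rather than code from the commit:

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # 3 input tokens, dim 4
p = {0: 0, 1: 2, 2: 1, 3: -1}  # output idx -> input idx for 4 output tokens

xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype, device=x.device)
mask = torch.zeros((len(p),), dtype=torch.bool, device=x.device)
for out_idx, in_idx in p.items():
    if in_idx < 0:
        continue  # token absent from this model: row stays zero, mask stays False
    xp[out_idx, :] = x[in_idx, :]
    mask[out_idx] = True

assert xp.shape == (4, 4)
assert mask.tolist() == [True, True, True, False]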
56 changes: 42 additions & 14 deletions mergekit/tokenizer.py
@@ -15,6 +15,7 @@
 
 import json
 import logging
+import tempfile
 from typing import Dict, Optional, Tuple
 
 import tokenizers
@@ -42,15 +43,23 @@ def get_vocab_size(model_path: str, trust_remote_code: bool) -> Optional[int]:
 def get_stripped_tokenizer(
     path: str, trust_remote_code: bool = False
 ) -> transformers.PreTrainedTokenizerFast:
+    """
+    Return a tokenizer for a model that only contains used tokens.
+    Strips any tokens with indices >= model.vocab_size.
+    """
     tokenizer = transformers.AutoTokenizer.from_pretrained(
         path, trust_remote_code=trust_remote_code, use_fast=True
     )
-    vocab_size = get_vocab_size(path) or len(tokenizer.get_vocab())
+    vocab_size = get_vocab_size(path, trust_remote_code=trust_remote_code) or len(
+        tokenizer.get_vocab()
+    )
 
     unused_toks = [
         tok for tok, idx in tokenizer.get_vocab().items() if idx >= vocab_size
     ]
+    if not unused_toks:
+        # we're good, ship it
+        return tokenizer
 
     if not tokenizer.is_fast:
@@ -91,12 +100,18 @@ def _keep_merge(m):
 def build_union_tokenizer(
     base_tok: transformers.PreTrainedTokenizerBase,
     tokenizers: Dict[ModelReference, transformers.PreTrainedTokenizerBase],
+    trust_remote_code: bool = False,
 ) -> transformers.PreTrainedTokenizerBase:
     out_added_tokens = {}
     out_vocab = {}
 
+    warned_added_tokens = set()
+
     for model, tokenizer in tokenizers.items():
-        vocab_size = get_vocab_size(model) or tokenizer.vocab_size
+        vocab_size = (
+            get_vocab_size(model, trust_remote_code=trust_remote_code)
+            or tokenizer.vocab_size
+        )
         added_tokens = tokenizer.added_tokens_decoder
 
         vocab = tokenizer.get_vocab()
@@ -115,14 +130,22 @@
 
         for tok, info in tokenizer.added_tokens_decoder.items():
             if tok in out_added_tokens:
-                if out_added_tokens[tok] != info:
+                if (out_added_tokens[tok] != info) and tok not in warned_added_tokens:
                     logging.warning(
                         f"Token '{tok}' added with multiple different settings, using first"
                     )
+                    warned_added_tokens.add(tok)
+
                 continue
             out_added_tokens[tok] = info
 
-    res = base_tok
+    # HACK: save base tokenizer to temp dir and reload to avoid mutating base_tok
+    with tempfile.TemporaryDirectory() as p:
+        base_tok.save_pretrained(p, legacy_format=False, safe_serialization=True)
+        res = transformers.AutoTokenizer.from_pretrained(
+            p, use_fast=True, trust_remote_code=trust_remote_code
+        )
+
     orig_base_vocab = base_tok.get_vocab()
     for tok in out_vocab:
         if tok in out_added_tokens:
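The tempfile round-trip above replaces res = base_tok, which aliased the base tokenizer so that the token additions further down mutated the caller's object. Condensed into a hypothetical helper (clone_tokenizer is not a name from the commit; a simplified sketch of the same trick):

import tempfile

import transformers

def clone_tokenizer(base_tok, trust_remote_code=False):
    # Round-trip through a temporary directory to get an independent copy,
    # so mutating the copy cannot touch the caller's tokenizer.
    with tempfile.TemporaryDirectory() as p:
        base_tok.save_pretrained(p, legacy_format=False, safe_serialization=True)
        return transformers.AutoTokenizer.from_pretrained(
            p, use_fast=True, trust_remote_code=trust_remote_code
        )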
@@ -148,19 +171,20 @@ def build_tokenizer(
     if base_model is None:
         raise RuntimeError("No models referenced")
 
-    tokenizer_out = get_stripped_tokenizer(
+    #
+    tokenizer_base = get_stripped_tokenizer(
         base_model.path, trust_remote_code=trust_remote_code
     )
 
     # load all tokenizers
     logging.info("Loading tokenizers")
-    tokenizers = {base_model: tokenizer_out}
+    tokenizers = {base_model: tokenizer_base}
     for model in config.referenced_models():
         if model == base_model:
             continue
 
         try:
-            model_tok = get_stripped_tokenizer(
+            model_tok = transformers.AutoTokenizer.from_pretrained(
                 model.path, trust_remote_code=trust_remote_code
             )
         except Exception:
@@ -174,9 +198,11 @@
     # build final vocabulary
     if config.tokenizer_source == "base":
         # it done
-        pass
+        tokenizer_out = tokenizer_base
     elif config.tokenizer_source == "union":
-        tokenizer_out = build_union_tokenizer(tokenizer_out, tokenizers)
+        tokenizer_out = build_union_tokenizer(
+            tokenizer_base, tokenizers, trust_remote_code=trust_remote_code
+        )
     elif config.tokenizer_source.startswith("model:"):
         tokenizer_out = transformers.AutoTokenizer.from_pretrained(
             config.tokenizer_source.removeprefix("model:"),
@@ -199,9 +225,11 @@
         if vocab_size is None:
             vocab_size = len(model_vocab)
 
-        p = torch.zeros(len(vocab_out), vocab_size, dtype=torch.int32)
-        for tok in model_vocab:
-            if tok not in vocab_out:
+        p = {}
+        for tok in vocab_out:
+            new_idx = vocab_out[tok]
+            if tok not in model_vocab:
+                p[new_idx] = -1
                 continue
 
             orig_idx = model_vocab[tok]
@@ -211,8 +239,8 @@
                 )
                 continue
 
-            new_idx = vocab_out[tok]
-            p[new_idx, orig_idx] = 1
+            p[new_idx] = orig_idx
 
         permutations[model] = p
 
     return tokenizer_out, permutations
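Isolated from build_tokenizer, the new permutation construction reads roughly as below (toy values; note the asymmetry: a token absent from a model maps to -1, while a token whose original index falls outside the model's embedding matrix is skipped without an entry):

vocab_out = {"a": 0, "b": 1, "c": 2}    # merged vocabulary
model_vocab = {"a": 0, "b": 1, "c": 5}  # one model's vocabulary
vocab_size = 2  # usable rows in that model's embedding matrix

p = {}
for tok in vocab_out:
    new_idx = vocab_out[tok]
    if tok not in model_vocab:
        p[new_idx] = -1  # token missing from this model entirely
        continue
    orig_idx = model_vocab[tok]
    if orig_idx >= vocab_size:
        # the tokenizer defines the token but the embedding has no row for it
        continue
    p[new_idx] = orig_idx

assert p == {0: 0, 1: 1}  # "c" dropped: original index 5 >= vocab_size

The -1 entries are what drive the continue in tokenizer_permute.py's copy loop, leaving the corresponding embedding row zero and its mask entry False.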
