model: support arch DbrxForCausalLM #6515

Merged 81 commits from hp/model/support-dbrx into master on Apr 13, 2024.
Changes shown from 78 of the 81 commits.

Commits
1d8de31
model: dbrx convert to gguf
phymbert Apr 6, 2024
ed582c1
llama: support dbrx
phymbert Apr 6, 2024
3e3d2d1
gguf-py: remove wrong clip -> clamp
phymbert Apr 6, 2024
3937100
model: dbrx, trust remote code
phymbert Apr 6, 2024
c0beb3c
llama: add label for model 132B
phymbert Apr 6, 2024
0921033
model: dbrx fix python linter in convert-hf-to-gguf.py
phymbert Apr 6, 2024
e4f8ee4
llama: support dbrx fix norm type
phymbert Apr 6, 2024
a7f9a3e
dbrx: minor
phymbert Apr 6, 2024
e3c1e81
convert: dbrx: fix mixed up and down expert tensors
phymbert Apr 6, 2024
0a35f58
convert: dbrx: fix mixed up and down expert tensors
phymbert Apr 6, 2024
c8e6f90
doc: dbrx: add the model as supported
phymbert Apr 6, 2024
916b918
convert: dbrx: fix remove wrong ATTN_OUT_NORM tensor, add output laye…
phymbert Apr 6, 2024
03da419
llama: dbrx: remove wrong attn output layer in model arch
phymbert Apr 6, 2024
76f266b
scripts: get-wikitext-2 add unzip
phymbert Apr 6, 2024
9c7dedb
llama: dbrx: no attention output layer
phymbert Apr 6, 2024
fe80898
model: dbrx: fix missing embedding tensor, mix with output layer
phymbert Apr 6, 2024
4f12a58
llama: dbrx: remove not existing condition on empty output layer
phymbert Apr 6, 2024
6985629
Merge remote-tracking branch 'origin/master' into hp/model/support-dbrx
phymbert Apr 6, 2024
7e7cd53
llama: dbrx: remove unnecessary optional tensor on FFN_GATE_EXPS
phymbert Apr 6, 2024
52c4033
llama: increase maximum experts allowed
phymbert Apr 7, 2024
06a59ab
model: dbrx: convert add n_ff
phymbert Apr 7, 2024
305ac3b
llama: dbrx: quantize fix n_attention_wv tensor name
phymbert Apr 7, 2024
b6522a9
model: dbrx: convert fix tokenizer
phymbert Apr 7, 2024
dccb012
llama: dbrx: quantize fix n_attention_wv tensor name
phymbert Apr 7, 2024
61be4b9
model: convert-hf-to-gguf.py add _set_vocab_tiktoken gpt2 backed on l…
phymbert Apr 7, 2024
1fb6d95
model: convert-hf-to-gguf.py fix classname conflict with qwen2
phymbert Apr 7, 2024
200ce21
model: dbrx: convert-hf-to-gguf.py fix fix ftype missing, fix tensor …
phymbert Apr 7, 2024
9e17dad
model: dbrx: convert-hf-to-gguf.py add chat template
phymbert Apr 7, 2024
d7546fd
llama: quantize: remove wrong look for tensor qkv name as it was badl…
phymbert Apr 7, 2024
3a9dc2e
model: dbrx: convert-hf-to-gguf.py fix 'token_embd.weight' has wrong …
phymbert Apr 7, 2024
8154617
model: dbrx: convert-hf-to-gguf.py support python 3.8
phymbert Apr 7, 2024
2449ef4
llama: dbrx: no weight suffix in ffn_gate_exps, ffn_up_exps and ffn_d…
phymbert Apr 7, 2024
1bd9427
llama: quantize: remove wrong look for tensor qkv name as it was badl…
phymbert Apr 7, 2024
e9987c6
llama: dbrx: fix tensor qkv number of elements
phymbert Apr 7, 2024
d151d8f
model: dbrx: convert reshape expert tensors to 3D
phymbert Apr 7, 2024
f062b83
model: dbrx: convert experts to f16
phymbert Apr 7, 2024
dbfd591
model: dbrx: fix tensor names mapping broken
phymbert Apr 7, 2024
7dd84b0
model: dbrx: fix expert reshape
phymbert Apr 7, 2024
c9bddbf
model: dbrx: fix expert reshape
phymbert Apr 7, 2024
e2c9199
model: dbrx: fix again sic expert reshape
phymbert Apr 7, 2024
50b4373
model: dbrx: weird fix expert reshape
phymbert Apr 7, 2024
0ab1bae
llama: dbrx: output norm dim
phymbert Apr 7, 2024
830e46d
llama: dbrx: fix last normalization
phymbert Apr 7, 2024
2897aa6
llama: dbrx: revert
phymbert Apr 7, 2024
993f836
llama: dbrx: move norm2 after attention, fix build kv
phymbert Apr 7, 2024
b01b062
llama: dbrx: fix build kv att out
phymbert Apr 7, 2024
74e6d87
llama: dbrx: fix build kv att out tensor name
phymbert Apr 7, 2024
f8f97e7
llama: dbrx: hardcode nn.LayerNorm epsilon
phymbert Apr 7, 2024
71f9e47
llama: dbrx: Try another rope type
phymbert Apr 7, 2024
52c6276
llama: dbrx: fix k scale
phymbert Apr 8, 2024
8e22688
llama: dbrx: move norm epsilon to convert. Fix missing normalization.
phymbert Apr 8, 2024
35dce3e
llama: dbrx: rename tensor to actual meaning. Fix normalization in gr…
phymbert Apr 8, 2024
506cc2e
llama: dbrx: convert remove previous reverse
phymbert Apr 8, 2024
eb0847e
llama: dbrx: load norm eps in hparams
phymbert Apr 8, 2024
81f308a
llama: dbrx: fix experts tensor layout
phymbert Apr 8, 2024
21fb24a
model: dbrx: convert-hf-to-gguf.py fix experts tensors shapes
phymbert Apr 8, 2024
f20c04f
llama: factorize moe graph implementation between grok, mixtral and dbrx
phymbert Apr 8, 2024
48909ed
model: dbrx convert permute experts directly torch, log shape
phymbert Apr 8, 2024
18a84fe
llama: dbrx: fix experts 3D tensor layout (again)
phymbert Apr 8, 2024
9968952
llama: dbrx: fix experts 3D tensor layout (again)
phymbert Apr 8, 2024
e66f1e3
llama: dbrx: document changes, permute only FFN_DOWN_EXPS. Add a chec…
phymbert Apr 8, 2024
f30a73b
llama: dbrx: rename layer_out_norm to attn_out_norm
phymbert Apr 8, 2024
ea8b58c
llama: dbrx: first add the residuals and then do the norm
phymbert Apr 8, 2024
55943a2
model: dbrx: convert fix mixed ffn_gate_exps and ffn_down_exps
phymbert Apr 8, 2024
c7b9a2e
llama: dbrx: fix ggml context of the attention outputs weight
phymbert Apr 8, 2024
ac82aa0
gguf-py: revert spaces
phymbert Apr 8, 2024
ac75fbd
gguf-py: dbrx: reverse again the MOE tensors mapping:
phymbert Apr 9, 2024
e5631cf
Merge remote-tracking branch 'origin/master' into hp/model/support-dbrx
phymbert Apr 9, 2024
6f813dc
Merge remote-tracking branch 'origin/master' into hp/model/support-dbrx
phymbert Apr 10, 2024
74529e5
llama: dbrx: use the MOE naming convention for model type
phymbert Apr 10, 2024
06527c6
Merge remote-tracking branch 'origin/master' into hp/model/support-dbrx
phymbert Apr 11, 2024
fc89fee
model: convert-hf-to-gguf.py remove tiktoken
phymbert Apr 11, 2024
bdc4efe
Is silu activation function applied to MODEL_TENSOR.FFN_GATE_EXP here…
phymbert Apr 12, 2024
542585f
Is silu activation function applied to MODEL_TENSOR.FFN_GATE_EXP here…
phymbert Apr 12, 2024
ecbfb1b
Wrong input was being fed to moe layer. This needs to be corrected
phymbert Apr 12, 2024
647a11b
eval-callback: also print last n elements of each dimension
phymbert Apr 12, 2024
03bdc36
minor spaces
phymbert Apr 12, 2024
8e6758f
convert: update comment of MOE tensors mapping
phymbert Apr 12, 2024
f1256dc
llama: rename build_moe to build_moe_ffn and fix grok is using gelu i…
phymbert Apr 12, 2024
e517585
convert-hf-to-gguf.py: fix python linter
phymbert Apr 12, 2024
9f77484
minor: fix indent in llama_build_graph
phymbert Apr 13, 2024
1 change: 1 addition & 0 deletions README.md
@@ -94,6 +94,7 @@ Typically finetunes of the base models below are supported as well.
- [x] LLaMA 2 🦙🦙
- [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
- [X] Falcon
- [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
95 changes: 95 additions & 0 deletions convert-hf-to-gguf.py
@@ -1427,6 +1427,101 @@ def write_tensors(self):
self.gguf_writer.add_tensor(new_name, data)


@Model.register("DbrxForCausalLM")
class DbrxModel(Model):
model_arch = gguf.MODEL_ARCH.DBRX

def set_gguf_parameters(self):
ffn_config = self.hparams["ffn_config"]
attn_config = self.hparams["attn_config"]
self.gguf_writer.add_name(self.hparams["model_type"])
self.gguf_writer.add_block_count(self.hparams["n_layers"])

self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(ffn_config["ffn_hidden_size"])

self.gguf_writer.add_head_count(self.hparams["n_heads"])
self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"])

self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])

self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
self.gguf_writer.add_file_type(self.ftype)

self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])

self.gguf_writer.add_layer_norm_eps(1e-5)

self.gguf_writer.add_file_type(self.ftype)
print(f"gguf: file type = {self.ftype}")

def write_tensors(self):
block_count = self.hparams.get("n_layers")
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
for name, data_torch in self.get_tensors():
n_expert = self.hparams["ffn_config"]["moe_num_experts"]
n_ff = self.hparams["ffn_config"]["ffn_hidden_size"]
n_embd = self.hparams["d_model"]

# Specific behavior for experts tensors: suffix .weight, view as 3D and transpose
# original implementation expects (n_expert, n_ff, n_embd) for all experts weights
# But llama.cpp moe graph works differently
# AND the dimensions in ggml are typically in the reverse order of the pytorch dimensions
# so (n_expert, n_ff, n_embd) in pytorch is {n_embd, n_ff, n_expert} in ggml_tensor
exp_tensor_names = {"ffn.experts.mlp.w1": None, # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert}
"ffn.experts.mlp.w2": (0, 2, 1), # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff, n_embd, n_expert}
"ffn.experts.mlp.v1": None} # LLM_TENSOR_FFN_UP_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert}
experts = False
for exp_tensor_name in exp_tensor_names.keys():
if name.find(exp_tensor_name) != -1 and name.find(".weight") == -1:
experts = True
data_torch = data_torch.view(n_expert, n_ff, n_embd)
if (permute_tensor := exp_tensor_names[exp_tensor_name]) is not None:
data_torch = data_torch.permute(*permute_tensor)
break

old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

data = data_torch.squeeze().numpy()

# map tensor names
# In MoE models the ffn tensors are typically most of the model weights,
# and need to be quantizable. Quantize expects tensor names to be suffixed by .weight.
# Every other model has the weight names ending in .weight,
# let's assume that is the convention which is not the case for dbrx:
# https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
new_name = tensor_map.get_name(name if not experts else name + ".weight", try_suffixes=(".weight",))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# Most of the codebase that takes in 1D tensors only handles F32 tensors
# and most of the outputs tensors are F32.
if data_dtype != np.float32 and n_dims == 1:
print(f"Can not map tensor {name!r}: all 1D tensors must be F32")
sys.exit()

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and n_dims > 1:
data = data.astype(np.float16)

print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

self.gguf_writer.add_tensor(new_name, data)

@Model.register("MiniCPMForCausalLM")
class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM
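For reference, the expert-tensor handling described in the comments of DbrxModel.write_tensors above can be reproduced in isolation. The sketch below is not part of the PR; the sizes are small dummy values, and it only shows the shape logic: the merged expert weights are viewed as 3D (n_expert, n_ff, n_embd), and only w2 (FFN_DOWN_EXPS) gets the (0, 2, 1) permute, while w1 and v1 keep their layout.

# Standalone sketch of the DBRX expert reshape/permute (dummy sizes, not PR code)
import torch

n_expert, n_ff, n_embd = 4, 8, 6

# the checkpoint stores each merged expert weight flat; modeled here as 2D
w1 = torch.randn(n_expert * n_ff, n_embd)   # ffn.experts.mlp.w1 -> FFN_GATE_EXPS
w2 = torch.randn(n_expert * n_ff, n_embd)   # ffn.experts.mlp.w2 -> FFN_DOWN_EXPS
v1 = torch.randn(n_expert * n_ff, n_embd)   # ffn.experts.mlp.v1 -> FFN_UP_EXPS

w1_3d = w1.view(n_expert, n_ff, n_embd)                    # no permute
v1_3d = v1.view(n_expert, n_ff, n_embd)                    # no permute
w2_3d = w2.view(n_expert, n_ff, n_embd).permute(0, 2, 1)   # (n_expert, n_embd, n_ff)

print(w1_3d.shape, v1_3d.shape, w2_3d.shape)
# torch.Size([4, 8, 6]) torch.Size([4, 8, 6]) torch.Size([4, 6, 8])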
26 changes: 18 additions & 8 deletions examples/eval-callback/eval-callback.cpp
@@ -28,14 +28,27 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
}

static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
GGML_ASSERT(n > 0);
float sum = 0;
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
printf(" [\n");
for (int64_t i2 = 0; i2 < ne[2] && i2 < n; i2++) {
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
if (i2 == n && ne[2] > 2*n) {
printf(" ..., \n");
i2 = ne[2] - n;
}
printf(" [\n");
for (int64_t i1 = 0; i1 < ne[1] && i1 < n; i1++) {
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
if (i1 == n && ne[1] > 2*n) {
printf(" ..., \n");
i1 = ne[1] - n;
}
printf(" [");
for (int64_t i0 = 0; i0 < ne[0] && i0 < n; i0++) {
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
if (i0 == n && ne[0] > 2*n) {
printf("..., ");
i0 = ne[0] - n;
}
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v;
if (type == GGML_TYPE_F16) {
@@ -51,17 +64,14 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
} else {
GGML_ASSERT(false);
}
printf("%8.4f", v);
printf("%12.4f", v);
sum += v;
if (i0 < ne[0] - 1 && i0 < n - 1) printf(", ");
if (i0 < ne[0] - 1) printf(", ");
}
if (ne[0] > n) printf(", ...");
printf("],\n");
}
if (ne[1] > n) printf(" ...\n");
printf(" ],\n");
}
if (ne[2] > n) printf(" ...\n");
printf(" ]\n");
printf(" sum = %f\n", sum);
}
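The eval-callback change above now prints the first n and the last n entries of each dimension, with an ellipsis in between once a dimension is longer than 2*n, instead of only the first n. A rough Python illustration of that truncation rule (not part of the PR; output shown is approximate):

# Illustration of the first/last-n printing rule from the eval-callback change
import numpy as np

def print_row_truncated(row: np.ndarray, n: int = 3) -> None:
    vals = row.tolist()
    # keep the first n and last n entries when the row is longer than 2*n
    shown = vals[:n] + ["..."] + vals[-n:] if len(vals) > 2 * n else vals
    print("[" + ", ".join(f"{v:12.4f}" if isinstance(v, float) else v for v in shown) + "]")

print_row_truncated(np.linspace(0.0, 1.0, 10))
# prints something like: [ 0.0000, 0.1111, 0.2222, ..., 0.7778, 0.8889, 1.0000]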
15 changes: 15 additions & 0 deletions gguf-py/gguf/constants.py
@@ -126,6 +126,7 @@ class MODEL_ARCH(IntEnum):
MAMBA = auto()
XVERSE = auto()
COMMAND_R = auto()
DBRX = auto()


class MODEL_TENSOR(IntEnum):
@@ -195,6 +196,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -642,6 +644,19 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
],
MODEL_ARCH.DBRX: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
# TODO
}

58 changes: 33 additions & 25 deletions gguf-py/gguf/tensor_mapping.py
@@ -10,7 +10,7 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 gpt-j mpt refact qwen
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
@@ -48,7 +48,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
@@ -60,7 +60,7 @@ class TensorNameMap:
"transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"transformer.norm_f", # mpt
"transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
"language_model.encoder.final_layernorm", # persimmon
"model.final_layernorm", # persimmon
@@ -96,6 +96,7 @@ class TensorNameMap:
"model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
),

# Attention norm 2
@@ -108,6 +109,7 @@ class TensorNameMap:
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen
"transformer.blocks.{bid}.attn.Wqkv", # mpt
"transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
"transformer.h.{bid}.self_attention.query_key_value", # falcon
"h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
@@ -152,30 +154,32 @@ class TensorNameMap:

# Attention output
MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"model.layers.{bid}.self_attn.dense", # persimmon
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
"gpt_neox.layers.{bid}.attention.dense", # gptneox
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"model.layers.{bid}.self_attn.dense", # persimmon
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
),

# Attention output norm
MODEL_TENSOR.ATTN_OUT_NORM: (
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
"encoder.layers.{bid}.norm1", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
),

# Rotary embeddings
@@ -202,9 +206,10 @@ class TensorNameMap:
),

MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"transformer.decoder_layer.{bid}.router" # Grok
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
),

# Feed-forward up
@@ -233,6 +238,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
),

# AWQ-activation gate
@@ -251,8 +257,9 @@ class TensorNameMap:
),

MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear" # Grok (merged)
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
),

# Feed-forward down
@@ -280,6 +287,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
),

MODEL_TENSOR.ATTN_Q_NORM: (
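For reference, the new dbrx entries in tensor_mapping.py are resolved through the same gguf.get_tensor_name_map() helper the converter calls. The sketch below is not from the PR and assumes gguf-py from this checkout is importable; the block count (40) and the printed name are illustrative assumptions, not verified output.

# Resolving a dbrx HF tensor name to its GGUF name via the new mapping entries
import gguf

tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.DBRX, 40)  # 40 layers assumed

# dbrx expert tensors ship without a ".weight" suffix, so the converter appends
# one before the lookup and lets get_name() strip and re-append it.
hf_name = "transformer.blocks.0.ffn.experts.mlp.w2" + ".weight"
print(tensor_map.get_name(hf_name, try_suffixes=(".weight",)))
# expected: the FFN_DOWN_EXP name for block 0 (e.g. "blk.0.ffn_down_exps.weight")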