Really fix need for --allow-crimes
cg123 committed Jan 4, 2024
1 parent 7fbe2b4 commit 9bd0313
Showing 1 changed file with 17 additions and 20 deletions.
mergekit/architecture.py: 17 additions & 20 deletions
@@ -14,7 +14,7 @@
 # along with this program. If not, see http://www.gnu.org/licenses/.
 
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import ClassVar, List, Optional
 
 from pydantic import BaseModel
 from transformers import PretrainedConfig
@@ -108,12 +108,13 @@ def num_layers(self, config: PretrainedConfig) -> int:
 )
 
 
-class MixtralTensorNames(ArchitectureInfo):
-    architecture_name: str = "MixtralForCausalLM"
+class MixtralTensorNames(ArchitectureInfo, BaseModel):
+    ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM"
+    num_local_experts: int
 
-    def __init__(self, config: PretrainedConfig):
-        self.num_local_experts = config.num_local_experts
+    @classmethod
+    def from_config(cls, config: PretrainedConfig):
+        return MixtralTensorNames(num_local_experts=config.num_local_experts)
 
     def pre_weights(self) -> List[str]:
         return MISTRAL_INFO.pre_weights()
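The hunk above swaps the hand-written __init__ for a pydantic model with a ClassVar constant and a from_config classmethod. A minimal standalone sketch of the same pattern, assuming pydantic is installed; the class name and the stub config below are illustrative, not mergekit code:

from typing import ClassVar

from pydantic import BaseModel


class ExampleTensorNames(BaseModel):
    # ClassVar keeps the constant out of the model's fields.
    ARCHITECTURE_NAME: ClassVar[str] = "ExampleForCausalLM"
    num_local_experts: int

    @classmethod
    def from_config(cls, config) -> "ExampleTensorNames":
        # Copy only the values the class needs off the config object.
        return cls(num_local_experts=config.num_local_experts)


class StubConfig:
    num_local_experts = 8


print(ExampleTensorNames.from_config(StubConfig()).num_local_experts)  # 8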
@@ -266,29 +267,25 @@ def layer_weight_formats(self) -> List[str]:
 )
 
 
-class PhiTensorNames(ArchitectureInfo):
-    architecture_name: str = "MixFormerSequentialForCausalLM"
+class PhiTensorNames(ArchitectureInfo, BaseModel):
+    ARCHITECTURE_NAME: ClassVar[str] = "MixFormerSequentialForCausalLM"
+    n_layer: int
 
-    def __init__(self, config: PretrainedConfig):
-        self.config = config
-
-    def __eq__(self, rhs: "PhiTensorNames"):
-        if not isinstance(rhs, PhiTensorNames):
-            return False
-        return self.num_layers() == rhs.num_layers()
+    @classmethod
+    def from_config(cls, config: PretrainedConfig):
+        return PhiTensorNames(n_layer=config.n_layer)

     def pre_weights(self) -> List[str]:
         return ["layers.0.wte.weight"]
 
     def post_weights(self) -> List[str]:
-        fake_layer_idx = self.config.n_layer + 1
+        fake_layer_idx = self.n_layer
         return [
             f"layers.{fake_layer_idx}.{suffix}"
             for suffix in ["linear.bias", "linear.weight", "ln.bias", "ln.weight"]
         ]
 
     def embed_weights(self) -> List[str]:
-        fake_layer_idx = self.config.n_layer + 1
+        fake_layer_idx = self.n_layer
         return [
             "layers.0.wte.weight",
             f"layers.{fake_layer_idx}.linear.weight",
@@ -351,10 +348,10 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
         raise RuntimeError("More than one architecture in config?")
 
     arch_name = config.architectures[0]
-    if arch_name == PhiTensorNames.architecture_name:
-        return PhiTensorNames(config)
-    if arch_name == MixtralTensorNames.architecture_name:
-        return MixtralTensorNames(config)
+    if arch_name == PhiTensorNames.ARCHITECTURE_NAME:
+        return PhiTensorNames.from_config(config)
+    if arch_name == MixtralTensorNames.ARCHITECTURE_NAME:
+        return MixtralTensorNames.from_config(config)
 
     supported = [
         LLAMA_INFO,
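Because ARCHITECTURE_NAME is annotated as ClassVar, pydantic does not treat it as a model field: it stays readable on the class itself, which is what the dispatch in get_architecture_info does, and it is not a constructor argument. A short sketch of that behaviour with an illustrative class name:

from typing import ClassVar

from pydantic import BaseModel


class NamesLike(BaseModel):
    ARCHITECTURE_NAME: ClassVar[str] = "SomeForCausalLM"
    n_layer: int


# Readable on the class without instantiating anything...
assert NamesLike.ARCHITECTURE_NAME == "SomeForCausalLM"
# ...while instances only need the real fields.
assert NamesLike(n_layer=24).n_layer == 24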
