Skip to content

Commit

Permalink
Merge branch 'main' into jais-arch
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 authored Jan 13, 2024
2 parents 3853a47 + deda4aa commit 6947fcf
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 14 deletions.
38 changes: 37 additions & 1 deletion mergekit/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,37 @@ def num_layers_config_key(self) -> str:
"mlp.fc2.bias",
"mlp.fc2.weight",
],
num_layers_key="n_layer",
)


# Static tensor-name layout for a "PhiForCausalLM" checkpoint whose parameters
# are stored under the "model." prefix (embeddings at model.embed_tokens,
# per-layer tensors at model.layers.{idx}).
# No num_layers_key is passed here, so the layer-count config key is whatever
# StaticTensorNames uses by default (not visible in this chunk -- confirm).
PHI2_INFO_AGAIN_BUT_DIFFERENT = StaticTensorNames(
name="PhiForCausalLM",
# Tensors listed ahead of the per-layer tensors.
pre_weight_names=["model.embed_tokens.weight"],
# Tensors listed after the per-layer tensors: LM head and final layer norm,
# each with a bias.
post_weight_names=[
"lm_head.bias",
"lm_head.weight",
"model.final_layernorm.bias",
"model.final_layernorm.weight",
],
# Tensors flagged as embedding weights; note lm_head.weight appears both here
# and in post_weight_names.
embed_weight_names=["lm_head.weight", "model.embed_tokens.weight"],
# {idx} is the layer index; each suffix below is a tensor under this prefix.
layer_prefix_format="model.layers.{idx}",
layer_weight_suffixes=[
"input_layernorm.bias",
"input_layernorm.weight",
# Attention: a "dense" tensor plus separate q/k/v projections, all with biases.
"self_attn.dense.bias",
"self_attn.dense.weight",
"self_attn.q_proj.bias",
"self_attn.q_proj.weight",
"self_attn.k_proj.bias",
"self_attn.k_proj.weight",
"self_attn.v_proj.bias",
"self_attn.v_proj.weight",
# MLP: two linear layers (fc1, fc2), each with a bias.
"mlp.fc1.bias",
"mlp.fc1.weight",
"mlp.fc2.bias",
"mlp.fc2.weight",
],
)


Expand All @@ -346,6 +377,12 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
if arch_name == PhiTensorNames.architecture_name:
return PhiTensorNames(config)

if arch_name == PHI2_INFO.name:
if config.model_type == "phi-msft":
return PHI2_INFO
elif config.model_type == "phi":
return PHI2_INFO_AGAIN_BUT_DIFFERENT

supported = [
LLAMA_INFO,
MISTRAL_INFO,
Expand All @@ -355,7 +392,6 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
GPT2_SEQCLASS_INFO,
CHATGLM_INFO,
STABLELM_INFO,
PHI2_INFO,
JAIS_INFO,
]
for arch in supported:
Expand Down
15 changes: 13 additions & 2 deletions mergekit/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,13 @@ def __init__(
self.cuda = cuda
self.low_cpu_memory = low_cpu_memory

def run(self, out_path: str, max_shard_size: int, clone_tensors: bool = False):
def run(
self,
out_path: str,
max_shard_size: int,
clone_tensors: bool = False,
safe_serialization: bool = True,
):
"""
Execute the computation graph and save results to disk.
Expand All @@ -245,8 +251,13 @@ def run(self, out_path: str, max_shard_size: int, clone_tensors: bool = False):
Args:
out_path (str): The path to the directory where the computed tensors will be saved.
max_shard_size (int): The maximum size of each saved shard.
safe_serialization (bool): Write the output tensors using safetensors.
"""
writer = TensorWriter(out_path, max_shard_size=max_shard_size)
writer = TensorWriter(
out_path,
max_shard_size=max_shard_size,
safe_serialization=safe_serialization,
)
for ref, tensor in tqdm.tqdm(self.generate_tensors(), total=len(self.targets)):
if not self.low_cpu_memory:
tensor = tensor.cpu()
Expand Down
39 changes: 29 additions & 10 deletions mergekit/io/tensor_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,19 @@ class TensorWriter:
weight_map = Dict[str, str]
current_shard: Dict[str, torch.Tensor]
current_shard_size: int
safe_serialization: bool

def __init__(
self, out_path: str, max_shard_size: int = 1000 * 1000 * 1000 * 5
self,
out_path: str,
max_shard_size: int = 1000 * 1000 * 1000 * 5,
safe_serialization: bool = True,
) -> None:
os.makedirs(out_path, exist_ok=True)

self.out_path = out_path
self.max_shard_size = max_shard_size
self.safe_serialization = safe_serialization
self.shards_written = 0
self.weight_map = {}
self.current_shard = {}
Expand All @@ -56,34 +61,48 @@ def save_tensor(self, name: str, tensor: torch.Tensor, clone: bool = False):
self.current_shard[name] = tensor
self.current_shard_size += tensor_size

def get_name_components(self):
    """
    Return the ``(prefix, extension)`` pair used to build output file names.

    ``("model", "safetensors")`` when ``self.safe_serialization`` is truthy,
    otherwise ``("pytorch_model", "bin")``.
    """
    return (
        ("model", "safetensors")
        if self.safe_serialization
        else ("pytorch_model", "bin")
    )

def flush_current_shard(self):
    """
    Write the tensors buffered in ``self.current_shard`` to disk as one shard.

    Does nothing when the buffer is empty. Otherwise the shard is written
    exactly once under ``self.out_path`` as ``{prefix}-{n}.{extension}``
    (prefix and extension come from ``get_name_components``; ``n`` is the
    1-based shard number), every buffered tensor name is recorded in
    ``self.weight_map`` against that file name, and the buffer, its size
    counter and the shard counter are reset/advanced.
    """
    if not self.current_shard:
        return

    shard_number = self.shards_written + 1
    logging.info(f"writing shard #{shard_number} to disk")

    # The file name depends on the serialization format, so compute it once
    # from get_name_components instead of hard-coding the safetensors name.
    prefix, extension = self.get_name_components()
    shard_name = f"{prefix}-{shard_number}.{extension}"
    for key in self.current_shard:
        self.weight_map[key] = shard_name

    # Write the shard a single time, in the format the caller selected; an
    # unconditional safetensors write here would put safetensors data into a
    # ".bin" path before torch.save overwrites it.
    shard_path = os.path.join(self.out_path, shard_name)
    if self.safe_serialization:
        safetensors.torch.save_file(
            self.current_shard,
            shard_path,
            metadata={"format": "pt"},
        )
    else:
        torch.save(self.current_shard, shard_path)

    self.current_shard = {}
    self.current_shard_size = 0
    self.shards_written = shard_number

def finalize(self):
self.flush_current_shard()

prefix, extension = self.get_name_components()

# standardize shard names to hf format
total_shards = self.shards_written
name_remap = {}
for idx in range(total_shards):
name_remap[
f"model-{idx+1}.safetensors"
] = f"model-{idx+1:05d}-of-{total_shards:05d}.safetensors"
f"{prefix}-{idx+1}.{extension}"
] = f"{prefix}-{idx+1:05d}-of-{total_shards:05d}.{extension}"

for old_name, new_name in name_remap.items():
os.rename(
Expand All @@ -95,7 +114,7 @@ def finalize(self):
self.weight_map[key] = name_remap[self.weight_map[key]]

with open(
os.path.join(self.out_path, "model.safetensors.index.json"),
os.path.join(self.out_path, f"{prefix}.{extension}.index.json"),
"w",
encoding="utf-8",
) as file:
Expand Down
2 changes: 2 additions & 0 deletions mergekit/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class MergeOptions(BaseModel):
random_seed: Optional[int] = None
lazy_unpickle: bool = False
write_model_card: bool = True
safe_serialization: bool = True


def run_merge(
Expand Down Expand Up @@ -111,6 +112,7 @@ def run_merge(
out_path,
max_shard_size=options.out_shard_size,
clone_tensors=options.clone_tensors,
safe_serialization=options.safe_serialization,
)

cfg_out = method.model_out_config(
Expand Down
1 change: 1 addition & 0 deletions mergekit/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"random_seed": "Seed for reproducible use of randomized merge methods",
"lazy_unpickle": "Experimental lazy unpickler for lower memory usage",
"write_model_card": "Output README.md containing details of the merge",
"safe_serialization": "Save output in safetensors. Do this, don't poison the world with more pickled models.",
}


Expand Down
6 changes: 5 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@ def run_and_check_merge(
config: MergeConfiguration,
check_nan: bool = True,
validate: Optional[Callable[[str], None]] = None,
index_json_name: Optional[str] = None,
):
if index_json_name is None:
index_json_name = "model.safetensors.index.json"

with tempfile.TemporaryDirectory() as tmpdir:
run_merge(config, out_path=tmpdir, options=MergeOptions())
assert os.path.exists(
os.path.join(tmpdir, "model.safetensors.index.json")
os.path.join(tmpdir, index_json_name)
), "No index file for merge"
assert os.path.exists(
os.path.join(tmpdir, "config.json")
Expand Down
26 changes: 26 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import tempfile

import torch

from mergekit.io import TensorWriter


class TestTensorWriter:
    """Checks the shard and index file names TensorWriter leaves on disk."""

    def test_safetensors(self):
        """safe_serialization=True produces a safetensors shard and its index."""
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=True)
            tensor_writer.save_tensor("steve", torch.randn(4))
            tensor_writer.finalize()

            for expected_name in (
                "model-00001-of-00001.safetensors",
                "model.safetensors.index.json",
            ):
                assert os.path.exists(os.path.join(out_dir, expected_name))

    def test_pickle(self):
        """safe_serialization=False produces a pytorch_model .bin shard and its index."""
        with tempfile.TemporaryDirectory() as out_dir:
            tensor_writer = TensorWriter(out_dir, safe_serialization=False)
            tensor_writer.save_tensor("timothan", torch.randn(4))
            tensor_writer.finalize()

            for expected_name in (
                "pytorch_model-00001-of-00001.bin",
                "pytorch_model.bin.index.json",
            ):
                assert os.path.exists(os.path.join(out_dir, expected_name))

0 comments on commit 6947fcf

Please sign in to comment.