Add a second Phi-2 tensor layout and optional pickle (non-safetensors) output

Summary of the changes in this patch:

* mergekit/architecture.py
  - Passes num_layers_key="n_layer" to the StaticTensorNames definition that
    ends just above the new one (presumably PHI2_INFO; its header is outside
    the hunk).
  - Adds PHI2_INFO_AGAIN_BUT_DIFFERENT, a "PhiForCausalLM" layout that uses
    "model.embed_tokens" / "model.layers.{idx}" / "self_attn.*" names.
  - get_architecture_info() returns PHI2_INFO when config.model_type is
    "phi-msft" and PHI2_INFO_AGAIN_BUT_DIFFERENT when it is "phi";
    PHI2_INFO is removed from the `supported` list.
* mergekit/graph.py
  - run() gains safe_serialization (default True) and forwards it to
    TensorWriter.
* mergekit/io/tensor_writer.py
  - TensorWriter gains safe_serialization (default True). When it is False,
    shards are written with torch.save as "pytorch_model-NNNNN-of-NNNNN.bin"
    and the index as "pytorch_model.bin.index.json"; otherwise the existing
    "model-NNNNN-of-NNNNN.safetensors" / "model.safetensors.index.json"
    names are kept.
* mergekit/merge.py, mergekit/options.py
  - Adds MergeOptions.safe_serialization (default True), passes it to run(),
    and adds its help text.
* tests/common.py, tests/test_io.py
  - run_and_check_merge() gains index_json_name (defaults to
    "model.safetensors.index.json"); new TensorWriter tests cover both
    output formats.

NOTE(review): a "PhiForCausalLM" config whose model_type is neither
"phi-msft" nor "phi" now falls through to the `supported` loop, which no
longer contains PHI2_INFO -- confirm that the resulting error path is the
intended behaviour.
NOTE(review): the Args section of run()'s docstring documents out_path,
max_shard_size and safe_serialization but not clone_tensors.
NOTE(review): run_and_check_merge() still calls run_merge() with a default
MergeOptions(), so index_json_name has no matching way to request pickle
output from this helper -- confirm whether that is needed.

diff --git a/mergekit/architecture.py b/mergekit/architecture.py
index f8c620e4..a18aad29 100644
--- a/mergekit/architecture.py
+++ b/mergekit/architecture.py
@@ -335,6 +335,37 @@ def num_layers_config_key(self) -> str:
         "mlp.fc2.bias",
         "mlp.fc2.weight",
     ],
+    num_layers_key="n_layer",
+)
+
+
+PHI2_INFO_AGAIN_BUT_DIFFERENT = StaticTensorNames(
+    name="PhiForCausalLM",
+    pre_weight_names=["model.embed_tokens.weight"],
+    post_weight_names=[
+        "lm_head.bias",
+        "lm_head.weight",
+        "model.final_layernorm.bias",
+        "model.final_layernorm.weight",
+    ],
+    embed_weight_names=["lm_head.weight", "model.embed_tokens.weight"],
+    layer_prefix_format="model.layers.{idx}",
+    layer_weight_suffixes=[
+        "input_layernorm.bias",
+        "input_layernorm.weight",
+        "self_attn.dense.bias",
+        "self_attn.dense.weight",
+        "self_attn.q_proj.bias",
+        "self_attn.q_proj.weight",
+        "self_attn.k_proj.bias",
+        "self_attn.k_proj.weight",
+        "self_attn.v_proj.bias",
+        "self_attn.v_proj.weight",
+        "mlp.fc1.bias",
+        "mlp.fc1.weight",
+        "mlp.fc2.bias",
+        "mlp.fc2.weight",
+    ],
 )
 
 
@@ -346,6 +377,12 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
     if arch_name == PhiTensorNames.architecture_name:
         return PhiTensorNames(config)
 
+    if arch_name == PHI2_INFO.name:
+        if config.model_type == "phi-msft":
+            return PHI2_INFO
+        elif config.model_type == "phi":
+            return PHI2_INFO_AGAIN_BUT_DIFFERENT
+
     supported = [
         LLAMA_INFO,
         MISTRAL_INFO,
@@ -355,7 +392,6 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
         GPT2_SEQCLASS_INFO,
         CHATGLM_INFO,
         STABLELM_INFO,
-        PHI2_INFO,
         JAIS_INFO,
     ]
     for arch in supported:
diff --git a/mergekit/graph.py b/mergekit/graph.py
index 4a055a14..534cb04e 100644
--- a/mergekit/graph.py
+++ b/mergekit/graph.py
@@ -235,7 +235,13 @@ def __init__(
         self.cuda = cuda
         self.low_cpu_memory = low_cpu_memory
 
-    def run(self, out_path: str, max_shard_size: int, clone_tensors: bool = False):
+    def run(
+        self,
+        out_path: str,
+        max_shard_size: int,
+        clone_tensors: bool = False,
+        safe_serialization: bool = True,
+    ):
         """
         Execute the computation graph and save results to disk.
 
@@ -245,8 +251,13 @@ def run(self, out_path: str, max_shard_size: int, clone_tensors: bool = False):
         Args:
             out_path (str): The path to the directory where the computed tensors will be saved.
             max_shard_size (int): The maximum size of each saved shard.
+            safe_serialization (bool): Write the output tensors using safetensors.
         """
-        writer = TensorWriter(out_path, max_shard_size=max_shard_size)
+        writer = TensorWriter(
+            out_path,
+            max_shard_size=max_shard_size,
+            safe_serialization=safe_serialization,
+        )
         for ref, tensor in tqdm.tqdm(self.generate_tensors(), total=len(self.targets)):
             if not self.low_cpu_memory:
                 tensor = tensor.cpu()
diff --git a/mergekit/io/tensor_writer.py b/mergekit/io/tensor_writer.py
index 8138b8c4..b4942a82 100644
--- a/mergekit/io/tensor_writer.py
+++ b/mergekit/io/tensor_writer.py
@@ -29,14 +29,19 @@ class TensorWriter:
     weight_map = Dict[str, str]
     current_shard: Dict[str, torch.Tensor]
     current_shard_size: int
+    safe_serialization: bool
 
     def __init__(
-        self, out_path: str, max_shard_size: int = 1000 * 1000 * 1000 * 5
+        self,
+        out_path: str,
+        max_shard_size: int = 1000 * 1000 * 1000 * 5,
+        safe_serialization: bool = True,
     ) -> None:
         os.makedirs(out_path, exist_ok=True)
 
         self.out_path = out_path
         self.max_shard_size = max_shard_size
+        self.safe_serialization = safe_serialization
         self.shards_written = 0
         self.weight_map = {}
         self.current_shard = {}
@@ -56,20 +61,32 @@ def save_tensor(self, name: str, tensor: torch.Tensor, clone: bool = False):
         self.current_shard[name] = tensor
         self.current_shard_size += tensor_size
 
+    def get_name_components(self):
+        if self.safe_serialization:
+            return "model", "safetensors"
+        return "pytorch_model", "bin"
+
     def flush_current_shard(self):
         if not self.current_shard:
             return
 
         logging.info(f"writing shard #{self.shards_written+1} to disk")
 
-        shard_name = f"model-{self.shards_written+1}.safetensors"
+        prefix, extension = self.get_name_components()
+        shard_name = f"{prefix}-{self.shards_written+1}.{extension}"
         for key in self.current_shard:
             self.weight_map[key] = shard_name
-        safetensors.torch.save_file(
-            self.current_shard,
-            os.path.join(self.out_path, shard_name),
-            metadata={"format": "pt"},
-        )
+
+        shard_path = os.path.join(self.out_path, shard_name)
+        if self.safe_serialization:
+            safetensors.torch.save_file(
+                self.current_shard,
+                shard_path,
+                metadata={"format": "pt"},
+            )
+        else:
+            torch.save(self.current_shard, shard_path)
+
         self.current_shard = {}
         self.current_shard_size = 0
         self.shards_written = self.shards_written + 1
@@ -77,13 +94,15 @@ def flush_current_shard(self):
     def finalize(self):
         self.flush_current_shard()
 
+        prefix, extension = self.get_name_components()
+
         # standardize shard names to hf format
         total_shards = self.shards_written
         name_remap = {}
         for idx in range(total_shards):
             name_remap[
-                f"model-{idx+1}.safetensors"
-            ] = f"model-{idx+1:05d}-of-{total_shards:05d}.safetensors"
+                f"{prefix}-{idx+1}.{extension}"
+            ] = f"{prefix}-{idx+1:05d}-of-{total_shards:05d}.{extension}"
 
         for old_name, new_name in name_remap.items():
             os.rename(
@@ -95,7 +114,7 @@ def finalize(self):
             self.weight_map[key] = name_remap[self.weight_map[key]]
 
         with open(
-            os.path.join(self.out_path, "model.safetensors.index.json"),
+            os.path.join(self.out_path, f"{prefix}.{extension}.index.json"),
             "w",
             encoding="utf-8",
         ) as file:
diff --git a/mergekit/merge.py b/mergekit/merge.py
index 2ac57ff5..1ddbe666 100644
--- a/mergekit/merge.py
+++ b/mergekit/merge.py
@@ -44,6 +44,7 @@ class MergeOptions(BaseModel):
     random_seed: Optional[int] = None
     lazy_unpickle: bool = False
     write_model_card: bool = True
+    safe_serialization: bool = True
 
 
 def run_merge(
@@ -111,6 +112,7 @@ def run_merge(
         out_path,
         max_shard_size=options.out_shard_size,
         clone_tensors=options.clone_tensors,
+        safe_serialization=options.safe_serialization,
     )
 
     cfg_out = method.model_out_config(
diff --git a/mergekit/options.py b/mergekit/options.py
index e7587b64..30880954 100644
--- a/mergekit/options.py
+++ b/mergekit/options.py
@@ -35,6 +35,7 @@
     "random_seed": "Seed for reproducible use of randomized merge methods",
     "lazy_unpickle": "Experimental lazy unpickler for lower memory usage",
     "write_model_card": "Output README.md containing details of the merge",
+    "safe_serialization": "Save output in safetensors. Do this, don't poison the world with more pickled models.",
 }
 
 
diff --git a/tests/common.py b/tests/common.py
index 5b988c3e..0a817110 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -13,11 +13,15 @@ def run_and_check_merge(
     config: MergeConfiguration,
     check_nan: bool = True,
     validate: Optional[Callable[[str], None]] = None,
+    index_json_name: Optional[str] = None,
 ):
+    if index_json_name is None:
+        index_json_name = "model.safetensors.index.json"
+
     with tempfile.TemporaryDirectory() as tmpdir:
         run_merge(config, out_path=tmpdir, options=MergeOptions())
         assert os.path.exists(
-            os.path.join(tmpdir, "model.safetensors.index.json")
+            os.path.join(tmpdir, index_json_name)
         ), "No index file for merge"
         assert os.path.exists(
             os.path.join(tmpdir, "config.json")
diff --git a/tests/test_io.py b/tests/test_io.py
new file mode 100644
index 00000000..f9cf8b9e
--- /dev/null
+++ b/tests/test_io.py
@@ -0,0 +1,26 @@
+import os
+import tempfile
+
+import torch
+
+from mergekit.io import TensorWriter
+
+
+class TestTensorWriter:
+    def test_safetensors(self):
+        with tempfile.TemporaryDirectory() as d:
+            writer = TensorWriter(d, safe_serialization=True)
+            writer.save_tensor("steve", torch.randn(4))
+            writer.finalize()
+
+            assert os.path.exists(os.path.join(d, "model-00001-of-00001.safetensors"))
+            assert os.path.exists(os.path.join(d, "model.safetensors.index.json"))
+
+    def test_pickle(self):
+        with tempfile.TemporaryDirectory() as d:
+            writer = TensorWriter(d, safe_serialization=False)
+            writer.save_tensor("timothan", torch.randn(4))
+            writer.finalize()
+
+            assert os.path.exists(os.path.join(d, "pytorch_model-00001-of-00001.bin"))
+            assert os.path.exists(os.path.join(d, "pytorch_model.bin.index.json"))