Skip to content

Commit

Permalink
fix(mega): move to click and fix rest of pyproject scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
nyxkrage committed Jan 4, 2024
1 parent a55ad9c commit 30520f5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 83 deletions.
96 changes: 18 additions & 78 deletions mergekit/scripts/megamerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
import yaml
import os
import re
from typing import Optional
from typing_extensions import Annotated
import typer
import click
import logging
import os

from mergekit.merge import MergeOptions, run_merge
from mergekit.common import parse_kmb
from mergekit.config import MergeConfiguration
from mergekit.options import add_merge_options

# Regex that matches huggingface path
hf_path = r"^[a-zA-Z0-9\-]+/[a-zA-Z0-9\-\._]+(?:\+.+)$"
Expand Down Expand Up @@ -40,58 +38,20 @@ def merge(m, options, force):
)
del merges[m]

@click.command("mergekit-mega")
@click.argument("config_file")
@click.option(
"--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
"--force", "-f", type=bool, default=False, is_flag=True, help="overwrite existing merge results instead of skipping them"
)
@add_merge_options
def main(
config_file: Annotated[str, typer.Argument(help="YAML configuration file")],
lora_merge_cache: Annotated[
Optional[str],
typer.Option(help="Path to store merged LORA models", metavar="PATH"),
] = None,
transformers_cache: Annotated[
Optional[str],
typer.Option(
help="Override storage path for downloaded models", metavar="PATH"
),
] = None,
cuda: Annotated[
bool, typer.Option(help="Perform matrix arithmetic on GPU")
] = False,
low_cpu_memory: Annotated[
bool,
typer.Option(
help="Store results and intermediate values on GPU. Useful if VRAM > RAM"
),
] = False,
copy_tokenizer: Annotated[
bool, typer.Option(help="Copy a tokenizer to the output")
] = True,
allow_crimes: Annotated[
bool, typer.Option(help="Allow mixing architectures")
] = False,
out_shard_size: Annotated[
Optional[int],
typer.Option(
help="Number of parameters per output shard [default: 5B]",
parser=parse_kmb,
show_default=False,
metavar="NUM",
),
] = parse_kmb("5B"),
verbose: Annotated[bool, typer.Option("-v", help="Verbose logging")] = False,
trust_remote_code: Annotated[
bool, typer.Option(help="Trust remote code when merging LoRAs")
] = False,
clone_tensors: Annotated[
bool,
typer.Option(
help="Clone tensors before saving, to allow multiple occurrences of the same layer"
),
] = False,
force: Annotated[
bool, typer.Option(help="Force overwrite of existing output")
] = False,
lazy_unpickle: Annotated[
bool, typer.Option(help="Experimental lazy unpickler for lower memory usage")
] = False,
merge_options: MergeOptions,
config_file: str,
force: bool,
verbose: bool,
):
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

Expand All @@ -117,31 +77,11 @@ def main(
if not re.match(hf_path, mdl["model"]):
merges[d["out_path"]]["deps"].append(mdl["model"])

options = MergeOptions(
lora_merge_cache=lora_merge_cache,
transformers_cache=transformers_cache,
cuda=cuda,
low_cpu_memory=low_cpu_memory,
copy_tokenizer=copy_tokenizer,
allow_crimes=allow_crimes,
out_shard_size=out_shard_size,
trust_remote_code=trust_remote_code,
clone_tensors=clone_tensors,
lazy_unpickle=lazy_unpickle,
)

print("Merging:\n" + '\n'.join(merges))
logging.info("Merging: " + ', '.join(merges))

while len(merges) != 0:
m = list(merges.keys())[0]
merge(m, options, force)



def _main():
# just a wee li'l stub for setuptools
typer.run(main)

merge(m, merge_options, force)

if __name__ == "__main__":
_main()
main()
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ repository = "https://github.com/cg123/mergekit"


[project.scripts]
mergekit-yaml = "mergekit.scripts.run_yaml:_main"
mergekit-mega = "mergekit.scripts.megamerge:_main"
mergekit-legacy = "mergekit.scripts.legacy:_main"
mergekit-layershuffle = "mergekit.scripts.layershuffle:_main"
bakllama = "mergekit.scripts.bakllama:_main"
mergekit-yaml = "mergekit.scripts.run_yaml:main"
mergekit-mega = "mergekit.scripts.megamerge:main"
mergekit-legacy = "mergekit.scripts.legacy:main"
mergekit-layershuffle = "mergekit.scripts.layershuffle:main"
bakllama = "mergekit.scripts.bakllama:main"

[tool.setuptools]
packages = ["mergekit"]
Expand Down

0 comments on commit 30520f5

Please sign in to comment.