Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mergekit-mega: compound merging using multiple yaml documents in a single merge config #72

Merged
merged 12 commits into from
Jan 5, 2024
Merged
37 changes: 37 additions & 0 deletions examples/mega.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
slices:
- sources:
- model: psmathur/orca_mini_v3_13b
layer_range: [0, 40]
- model: garage-bAInd/Platypus2-13B
layer_range: [0, 40]
merge_method: slerp
base_model: psmathur/orca_mini_v3_13b
parameters:
t:
- filter: self_attn
value: [0, 0.5, 0.3, 0.7, 1]
- filter: mlp
value: [1, 0.5, 0.7, 0.3, 0]
- value: 0.5 # fallback for rest of tensors
dtype: float16
name: gradient-slerp
---
models:
- model: gradient-slerp
parameters:
density: [1, 0.7, 0.1] # density gradient
weight: 1.0
- model: WizardLM/WizardMath-13B-V1.0
parameters:
density: 0.33
weight:
- filter: mlp
value: 0.5
- value: 0
merge_method: ties
base_model: TheBloke/Llama-2-13B-fp16
parameters:
normalize: true
int8_mask: true
dtype: float16
name: gradient-slerp-ties
143 changes: 143 additions & 0 deletions mergekit/scripts/megamerge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#!/usr/bin/env python3

import os
import sys
import logging
from pathlib import Path

import click
import yaml

from mergekit.merge import MergeOptions, run_merge
from mergekit.config import MergeConfiguration
from mergekit.options import add_merge_options

merges = {}


def has_circular_dependency(nodes):
def dfs(node, visited, stack):
visited[node] = True
stack[node] = True

for dependency in nodes[node]["deps"]:
if not visited[dependency]:
if dfs(dependency, visited, stack):
return True
elif stack[dependency]:
return True

stack[node] = False
return False

visited = {key: False for key in nodes}
stack = {key: False for key in nodes}

for node in nodes:
if not visited[node]:
if dfs(node, visited, stack):
return node

return None


def merge(m, merge_options, force, out_path):
# check if output_path exists
if os.path.exists(out_path / m):
if not force:
logging.info("Skipping %s as it already exists", m)
del merges[m]
return
logging.info("Overwriting %s as --force was specified", m)

if len(merges[m]["deps"]) != 0:
for dep in merges[m]["deps"]:
if dep in merges:
merge(dep, merge_options, force, out_path)

logging.info("Merging model %s", m)
merge_config: MergeConfiguration = MergeConfiguration.model_validate(merges[m])
run_merge(
merge_config,
str(out_path / merges[m]["name"]),
options=merge_options,
)
del merges[m]


@click.command("mergekit-mega")
@click.argument("config_file")
@click.argument("out_path")
@click.option(
"--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
"--force",
"-f",
type=bool,
default=False,
is_flag=True,
help="overwrite existing merge results instead of skipping them",
)
@add_merge_options
def main(
merge_options: MergeOptions,
config_file: str,
out_path: str,
force: bool,
verbose: bool,
):
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

out_path = Path(out_path)
with open(config_file, "r", encoding="utf-8") as f:
data = yaml.load_all(f, Loader=yaml.FullLoader)

for d in data:
if "/" in d["name"]:
logging.error("name must not contain a slash")
sys.exit(1)

merges[d["name"]] = d
merges[d["name"]]["deps"] = []
if "slices" in d:
for slc in d["slices"]:
for src in slc["sources"]:
if "model" in src and src["model"] is not None:
model_lora = src["model"].split("+")
# name must not have a slash to avoid path traversal
# therefore, we can use it to check if its a merge from the config
if "/" not in model_lora[0]:
nyxkrage marked this conversation as resolved.
Show resolved Hide resolved
# avoid duplicate deps
if model_lora[0] not in merges[d["name"]]["deps"]:
merges[d["name"]]["deps"].append(model_lora[0])
src["model"] = str(out_path / model_lora[0])
if len(model_lora) == 2:
src["model"] += "+" + model_lora[1]
if "models" in d:
for mdl in d["models"]:
if "model" in mdl and mdl["model"] is not None:
model_lora = mdl["model"].split("+")
# name must not have a slash to avoid path traversal
# therefore, we can use it to check if its a merge from the config
if "/" not in model_lora[0]:
# avoid duplicate deps
if model_lora[0] not in merges[d["name"]]["deps"]:
merges[d["name"]]["deps"].append(model_lora[0])
mdl["model"] = str(out_path / model_lora[0])
if len(model_lora) == 2:
mdl["model"] += "+" + model_lora[1]

logging.info("Merging: %s", ", ".join(merges))

if (dep := has_circular_dependency(merges)) is not None:
logging.error("Circular dependency detected: %s", dep)
sys.exit(1)

while len(merges) != 0:
m = list(merges.keys())[0]
merge(m, merge_options, force)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ repository = "https://github.com/cg123/mergekit"

[project.scripts]
mergekit-yaml = "mergekit.scripts.run_yaml:main"
mergekit-mega = "mergekit.scripts.megamerge:main"
mergekit-legacy = "mergekit.scripts.legacy:main"
mergekit-layershuffle = "mergekit.scripts.layershuffle:main"
bakllama = "mergekit.scripts.bakllama:main"
Expand Down
Loading