Skip to content

Commit

Permalink
fix(mega): use config file name for final merge (#78)
Browse files Browse the repository at this point in the history
Based on feedback from Henky, updates mergekit-mega to require one of
the merges to omit the name key, that will then use the name of the
config file for the name.
Also fixes a bug I introduced when allowing the base model to be an
intermediate model where it wouldn't update the merge models to the
correct path
And removes some unnecessary print statements that I forgot to remove
  • Loading branch information
nyxkrage authored Jan 6, 2024
1 parent ebcaa04 commit 1011ef3
Showing 1 changed file with 31 additions and 10 deletions.
41 changes: 31 additions & 10 deletions mergekit/scripts/megamerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def dfs(node, visited, stack):
return None


def merge(m, merge_options, force, out_path):
def merge(m: str, merge_options: MergeOptions, force: bool, out_path: Path):
"""
Merges a model and its dependencies
Expand Down Expand Up @@ -86,23 +86,23 @@ def merge(m, merge_options, force, out_path):
del merges[m]


def add_model_deps(model, name, out_path):
def add_model_deps(model: str, name: str, out_path: Path):
"""
Adds a model to `name`s dependencies if it is not already there and is a merge
"""
model_lora = model.split("+")
# name must not have a slash to avoid path traversal
# therefore, we can use it to check if its a merge from the config
print(model_lora)
if "/" not in model_lora[0]:
print(model_lora)
# avoid duplicate deps
if model_lora[0] not in merges[name]["deps"]:
merges[name]["deps"].append(model_lora[0])
model = str(out_path / model_lora[0])
if len(model_lora) == 2:
model += "+" + model_lora[1]

return model


@click.command("mergekit-mega")
@click.argument("config_file")
Expand All @@ -116,7 +116,15 @@ def add_model_deps(model, name, out_path):
type=bool,
default=False,
is_flag=True,
help="overwrite existing merge results instead of skipping them",
help="Overwrite existing merge results instead of skipping them",
)
@click.option(
"--require-nameless",
"-R",
type=bool,
default=False,
is_flag=True,
help="Enforces exactly one unnamed merge in the YAML, which will inherit the input file's name.",
)
@add_merge_options
def main(
Expand All @@ -125,6 +133,7 @@ def main(
out_path: str,
force: bool,
verbose: bool,
require_nameless: bool,
):
"""
Main entrypoint for mergekit-mega command see module docstring for more info
Expand All @@ -133,27 +142,39 @@ def main(
logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

out_path = Path(out_path)
final_found = False

with open(config_file, "r", encoding="utf-8") as f:
data = yaml.load_all(f, Loader=yaml.FullLoader)

for d in data:
if "name" not in d:
if final_found:
logging.error("Only one merge must not have a name")
sys.exit(1)
# this sets the name of the final merge to the config file name without the extension
d["name"] = os.path.basename(config_file).rsplit(".", maxsplit=1)[0]
final_found = True

if "/" in d["name"]:
logging.error("name must not contain a slash")
sys.exit(1)

merges[d["name"]] = d
merges[d["name"]]["deps"] = []
if "base_model" in d:
add_model_deps(d["base_model"], d["name"], out_path)
if "/" not in d["base_model"]:
d["base_model"] = str(out_path / d["base_model"])
d["base_model"] = add_model_deps(d["base_model"], d["name"], out_path)
if "slices" in d:
for slc in d["slices"]:
for src in slc["sources"]:
add_model_deps(src["model"], d["name"], out_path)
src["model"] = add_model_deps(src["model"], d["name"], out_path)
if "models" in d:
for mdl in d["models"]:
add_model_deps(mdl["model"], d["name"], out_path)
mdl["model"] = add_model_deps(mdl["model"], d["name"], out_path)

if require_nameless and not final_found:
logging.error("No final merge found")
sys.exit(1)

logging.info("Merging: %s", ", ".join(merges))

Expand Down

0 comments on commit 1011ef3

Please sign in to comment.