diff --git a/mergekit/scripts/megamerge.py b/mergekit/scripts/megamerge.py index 02ea0f2e..59fc7a42 100644 --- a/mergekit/scripts/megamerge.py +++ b/mergekit/scripts/megamerge.py @@ -53,7 +53,7 @@ def dfs(node, visited, stack): return None -def merge(m, merge_options, force, out_path): +def merge(m: str, merge_options: MergeOptions, force: bool, out_path: Path): """ Merges a model and its dependencies @@ -86,16 +86,14 @@ def merge(m, merge_options, force, out_path): del merges[m] -def add_model_deps(model, name, out_path): +def add_model_deps(model: str, name: str, out_path: Path): """ Adds a model to `name`s dependencies if it is not already there and is a merge """ model_lora = model.split("+") # name must not have a slash to avoid path traversal # therefore, we can use it to check if its a merge from the config - print(model_lora) if "/" not in model_lora[0]: - print(model_lora) # avoid duplicate deps if model_lora[0] not in merges[name]["deps"]: merges[name]["deps"].append(model_lora[0]) @@ -103,6 +101,8 @@ def add_model_deps(model, name, out_path): if len(model_lora) == 2: model += "+" + model_lora[1] + return model + @click.command("mergekit-mega") @click.argument("config_file") @@ -116,7 +116,15 @@ def add_model_deps(model, name, out_path): type=bool, default=False, is_flag=True, - help="overwrite existing merge results instead of skipping them", + help="Overwrite existing merge results instead of skipping them", +) +@click.option( + "--require-nameless", + "-R", + type=bool, + default=False, + is_flag=True, + help="Enforces exactly one unnamed merge in the YAML, which will inherit the input file's name.", ) @add_merge_options def main( @@ -125,6 +133,7 @@ def main( out_path: str, force: bool, verbose: bool, + require_nameless: bool, ): """ Main entrypoint for mergekit-mega command see module docstring for more info @@ -133,10 +142,20 @@ def main( logging.basicConfig(level=logging.INFO if verbose else logging.WARNING) out_path = Path(out_path) + final_found = False + with open(config_file, "r", encoding="utf-8") as f: data = yaml.load_all(f, Loader=yaml.FullLoader) for d in data: + if "name" not in d: + if final_found: + logging.error("Only one merge must not have a name") + sys.exit(1) + # this sets the name of the final merge to the config file name without the extension + d["name"] = os.path.basename(config_file).rsplit(".", maxsplit=1)[0] + final_found = True + if "/" in d["name"]: logging.error("name must not contain a slash") sys.exit(1) @@ -144,16 +163,18 @@ def main( merges[d["name"]] = d merges[d["name"]]["deps"] = [] if "base_model" in d: - add_model_deps(d["base_model"], d["name"], out_path) - if "/" not in d["base_model"]: - d["base_model"] = str(out_path / d["base_model"]) + d["base_model"] = add_model_deps(d["base_model"], d["name"], out_path) if "slices" in d: for slc in d["slices"]: for src in slc["sources"]: - add_model_deps(src["model"], d["name"], out_path) + src["model"] = add_model_deps(src["model"], d["name"], out_path) if "models" in d: for mdl in d["models"]: - add_model_deps(mdl["model"], d["name"], out_path) + mdl["model"] = add_model_deps(mdl["model"], d["name"], out_path) + + if require_nameless and not final_found: + logging.error("No final merge found") + sys.exit(1) logging.info("Merging: %s", ", ".join(merges))