convert-hf-to-gguf.py: add --get-outfile command and refactor
mofosyne committed May 23, 2024
1 parent 9b82476 commit 53096ef
Showing 4 changed files with 114 additions and 54 deletions.
71 changes: 59 additions & 12 deletions convert-hf-to-gguf.py
@@ -60,6 +60,7 @@ class Model:
    tensor_map: gguf.TensorNameMap
    tensor_names: set[str] | None
    fname_out: Path
    fname_default: Path
    gguf_writer: gguf.GGUFWriter

    # subclasses should define this!
@@ -91,10 +92,27 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
            else:
                logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
                self.ftype = gguf.LlamaFileType.MOSTLY_BF16
        ftype_up: str = self.ftype.name.partition("_")[2].upper()
        ftype_lw: str = ftype_up.lower()

        # Generate default filename based on model specification and available metadata
        version_string = None  # TODO: Add metadata support
        expert_count = self.hparams["num_local_experts"] if "num_local_experts" in self.hparams else None
        encodingScheme = {
            gguf.LlamaFileType.ALL_F32    : "F32",
            gguf.LlamaFileType.MOSTLY_F16 : "F16",
            gguf.LlamaFileType.MOSTLY_BF16: "BF16",
            gguf.LlamaFileType.MOSTLY_Q8_0: "Q8_0",
        }[self.ftype]
        self.fname_default = f"{gguf.naming_convention(dir_model.name, version_string, expert_count, self.parameter_count(), encodingScheme)}"

        # Filename Output
        if fname_out is not None:
            # custom defined filename and path was provided
            self.fname_out = fname_out
        else:
            # output in the same directory as the model by default
            self.fname_out = dir_model.parent / self.fname_default

        # allow templating the file name with the output ftype, useful with the "auto" ftype
        self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
        self.gguf_writer = gguf.GGUFWriter(self.fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)

    @classmethod
@@ -240,6 +258,25 @@ def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:

        return False

    def parameter_count(self):
        total_model_parameters = 0
        for name, data_torch in self.get_tensors():
            # Got A Tensor

            # We don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
                continue

            # Calculate Tensor Volume
            sum_weights_in_tensor = 1
            for dim in data_torch.shape:
                sum_weights_in_tensor *= dim

            # Add Tensor Volume To Running Count
            total_model_parameters += sum_weights_in_tensor

        return total_model_parameters

    def write_tensors(self):
        max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")

@@ -2551,14 +2588,24 @@ def parse_args() -> argparse.Namespace:
        "--verbose", action="store_true",
        help="increase output verbosity",
    )
    parser.add_argument(
        "--get-outfile", action="store_true",
        help="get calculated default outfile name"
    )

    return parser.parse_args()


def main() -> None:
    args = parse_args()

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    elif args.get_outfile:
        # Avoid printing anything besides the dump output
        logging.basicConfig(level=logging.WARNING)
    else:
        logging.basicConfig(level=logging.INFO)

    dir_model = args.model

@@ -2587,19 +2634,19 @@ def main() -> None:
        "auto": gguf.LlamaFileType.GUESSED,
    }

    if args.outfile is not None:
        fname_out = args.outfile
    else:
        # output in the same directory as the model by default
        fname_out = dir_model / 'ggml-model-{ftype}.gguf'

    logger.info(f"Loading model: {dir_model.name}")

    hparams = Model.load_hparams(dir_model)

    with torch.inference_mode():
        model_class = Model.from_model_architecture(hparams["architectures"][0])
        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)
        encodingScheme = ftype_map[args.outtype]
        model_architecture = hparams["architectures"][0]
        model_class = Model.from_model_architecture(model_architecture)
        model_instance = model_class(dir_model, encodingScheme, args.outfile, args.bigendian, args.use_temp_file, args.no_lazy)

        if args.get_outfile:
            print(f"{model_instance.fname_default}")  # noqa: NP100
            return

        logger.info("Set model parameters")
        model_instance.set_gguf_parameters()
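Note: the new Model.parameter_count() method simply multiplies out each tensor's dimensions and sums the results, skipping a few non-weight tensors by suffix. A minimal standalone sketch of the same arithmetic (tensor names and shapes here are hypothetical, standing in for self.get_tensors()):

    tensor_shapes = {
        "token_embd.weight": (32000, 4096),
        "blk.0.attn_q.weight": (4096, 4096),
        "wte.rotary_emb.inv_freq": (64,),  # dropped by the suffix filter below
    }

    total_model_parameters = 0
    for name, shape in tensor_shapes.items():
        # Skip tensors that are not model weights
        if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
            continue
        # Tensor volume is the product of its dimensions
        sum_weights_in_tensor = 1
        for dim in shape:
            sum_weights_in_tensor *= dim
        total_model_parameters += sum_weights_in_tensor

    print(total_model_parameters)  # 32000*4096 + 4096*4096 = 147849216

Since main() now prints fname_default and returns before any tensor is converted, an invocation such as "python convert-hf-to-gguf.py --get-outfile <model-dir>" (placeholder model directory) emits only the calculated default filename.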
61 changes: 19 additions & 42 deletions convert.py
@@ -1320,35 +1320,17 @@ def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:

def model_parameter_count(model: LazyModel) -> int:
    total_model_parameters = 0
    for i, (name, lazy_tensor) in enumerate(model.items()):
    for name, lazy_tensor in model.items():
        # Got A Tensor
        sum_weights_in_tensor = 1
        # Tensor Volume
        for dim in lazy_tensor.shape:
            sum_weights_in_tensor *= dim
        # Add Tensor Volume To Running Count
        total_model_parameters += sum_weights_in_tensor
    return total_model_parameters


def model_parameter_count_rounded_notation(model_params_count: int) -> str:
    if model_params_count > 1e12 :
        # Trillions Of Parameters
        scaled_model_params = model_params_count * 1e-12
        scale_suffix = "T"
    elif model_params_count > 1e9 :
        # Billions Of Parameters
        scaled_model_params = model_params_count * 1e-9
        scale_suffix = "B"
    elif model_params_count > 1e6 :
        # Millions Of Parameters
        scaled_model_params = model_params_count * 1e-6
        scale_suffix = "M"
    else:
        # Thousands Of Parameters
        scaled_model_params = model_params_count * 1e-3
        scale_suffix = "K"

    return f"{round(scaled_model_params)}{scale_suffix}"


def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
    return {name: tensor.astype(output_type.type_for_tensor(name, tensor))
            for (name, tensor) in model.items()}
@@ -1529,29 +1511,24 @@ def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) ->


def default_convention_outfile(file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> str:
    quantization = {
        GGMLFileType.AllF32: "F32",
        GGMLFileType.MostlyF16: "F16",
        GGMLFileType.MostlyQ8_0: "Q8_0",
    }[file_type]

    parameters = model_parameter_count_rounded_notation(model_params_count)

    expert_count = ""
    if params.n_experts is not None:
        expert_count = f"{params.n_experts}x"

    version = ""
    if metadata is not None and metadata.version is not None:
        version = f"-{metadata.version}"

    name = "ggml-model"
    name = None
    if metadata is not None and metadata.name is not None:
        name = metadata.name
    elif params.path_model is not None:
        name = params.path_model.name

    return f"{name}{version}-{expert_count}{parameters}-{quantization}"
    version = metadata.version if metadata is not None and metadata.version is not None else None

    expert_count = params.n_experts if params.n_experts is not None else None

    encodingScheme = {
        GGMLFileType.AllF32: "F32",
        GGMLFileType.MostlyF16: "F16",
        GGMLFileType.MostlyQ8_0: "Q8_0",
    }[file_type]

    return gguf.naming_convention(name, version, expert_count, model_params_count, encodingScheme)


def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> Path:
Expand Down Expand Up @@ -1612,9 +1589,9 @@ def main(args_in: list[str] | None = None) -> None:
if args.get_outfile:
model_plus = load_some_model(args.model)
params = Params.load(model_plus)
model = convert_model_names(model_plus.model, params, args.skip_unknown)
model = convert_model_names(model_plus.model, params, args.skip_unknown)
model_params_count = model_parameter_count(model_plus.model)
ftype = pick_output_type(model, args.outtype)
ftype = pick_output_type(model, args.outtype)
print(f"{default_convention_outfile(ftype, params, model_params_count, metadata)}") # noqa: NP100
return

@@ -1632,7 +1609,7 @@ def main(args_in: list[str] | None = None) -> None:
        model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)

    model_params_count = model_parameter_count(model_plus.model)
    logger.info(f"model parameters count : {model_params_count} ({model_parameter_count_rounded_notation(model_params_count)})")
    logger.info(f"model parameters count : {model_params_count} ({gguf.model_parameter_count_rounded_notation(model_params_count)})")

    if args.dump:
        do_dump_model(model_plus)
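The convert.py refactor keeps the output name's shape intact: the removed inline f-string and the new gguf.naming_convention() call assemble the same name-version-expert-parameter-encoding pieces (the helper additionally replaces spaces in the model name with dashes). A quick sketch with hypothetical inputs:

    import gguf

    name = "llama-7b-hf"           # from metadata.name or params.path_model.name
    version = None                 # no metadata.version available
    expert_count = None            # dense model: params.n_experts is None
    model_params_count = 6_738_415_616

    print(gguf.naming_convention(name, version, expert_count, model_params_count, "F16"))
    # -> llama-7b-hf-7B-F16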
1 change: 1 addition & 0 deletions gguf-py/gguf/__init__.py
@@ -5,3 +5,4 @@
from .quants import *
from .tensor_mapping import *
from .vocab import *
from .utility import *
35 changes: 35 additions & 0 deletions gguf-py/gguf/utility.py
@@ -0,0 +1,35 @@
from __future__ import annotations


def model_parameter_count_rounded_notation(model_params_count: int) -> str:
    if model_params_count > 1e15:
        # Quadrillions Of Parameters
        scaled_model_params = model_params_count * 1e-15
        scale_suffix = "Q"
    elif model_params_count > 1e12:
        # Trillions Of Parameters
        scaled_model_params = model_params_count * 1e-12
        scale_suffix = "T"
    elif model_params_count > 1e9:
        # Billions Of Parameters
        scaled_model_params = model_params_count * 1e-9
        scale_suffix = "B"
    elif model_params_count > 1e6:
        # Millions Of Parameters
        scaled_model_params = model_params_count * 1e-6
        scale_suffix = "M"
    else:
        # Thousands Of Parameters
        scaled_model_params = model_params_count * 1e-3
        scale_suffix = "K"
    return f"{round(scaled_model_params)}{scale_suffix}"


def naming_convention(model_name: str, version_string: str, expert_count_int: int, model_params_count: int, encodingScheme: str) -> str:
    # Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gguf-naming-convention
    name = model_name.strip().replace(' ', '-') if model_name is not None else "ggml-model"
    version = f"-{version_string}" if version_string is not None else ""
    expert_count_chunk = f"{expert_count_int}x" if expert_count_int is not None else ""
    parameters = model_parameter_count_rounded_notation(model_params_count)
    encodingScheme = encodingScheme.upper()
    return f"{name}{version}-{expert_count_chunk}{parameters}-{encodingScheme}"
