import os
import argparse

import torch
from onnxruntime.quantization import quantize_dynamic, QuantType

from gliner import GLiNER

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default="logs/model_12000")
    parser.add_argument('--save_path', type=str, default='model/')
    # NOTE: `type=bool` is a known argparse footgun (any non-empty string,
    # including "False", is truthy). BooleanOptionalAction (Python 3.9+)
    # gives proper --quantize / --no-quantize flags instead.
    parser.add_argument('--quantize', action=argparse.BooleanOptionalAction, default=True)
    args = parser.parse_args()

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    onnx_save_path = os.path.join(args.save_path, "model.onnx")

    print("Loading a model...")
    gliner_model = GLiNER.from_pretrained(args.model_path, load_tokenizer=True)

    # Dummy input used to trace the model during ONNX export.
    text = "ONNX is an open-source format designed to enable the interoperability of AI models across various frameworks and tools."
    labels = ['format', 'model', 'tool', 'cat']

    inputs, _ = gliner_model.prepare_model_inputs([text], labels)

    if gliner_model.config.span_mode == 'token_level':
        # Token-level models only need token-level inputs.
        all_inputs = (
            inputs['input_ids'],
            inputs['attention_mask'],
            inputs['words_mask'],
            inputs['text_lengths'],
        )
        input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths']
        dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "words_mask": {0: "batch_size", 1: "sequence_length"},
            "text_lengths": {0: "batch_size", 1: "value"},
            "logits": {0: "position", 1: "batch_size", 2: "sequence_length", 3: "num_classes"},
        }
    else:
        # Span-level models additionally take span indices and a span mask.
        all_inputs = (
            inputs['input_ids'],
            inputs['attention_mask'],
            inputs['words_mask'],
            inputs['text_lengths'],
            inputs['span_idx'],
            inputs['span_mask'],
        )
        input_names = ['input_ids', 'attention_mask', 'words_mask', 'text_lengths', 'span_idx', 'span_mask']
        dynamic_axes = {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "words_mask": {0: "batch_size", 1: "sequence_length"},
            "text_lengths": {0: "batch_size", 1: "value"},
            "span_idx": {0: "batch_size", 1: "num_spans", 2: "idx"},
            "span_mask": {0: "batch_size", 1: "num_spans"},
            "logits": {0: "batch_size", 1: "sequence_length", 2: "num_spans", 3: "num_classes"},
        }

    print('Converting the model...')
    torch.onnx.export(
        gliner_model.model,
        all_inputs,
        f=onnx_save_path,
        input_names=input_names,
        output_names=["logits"],
        dynamic_axes=dynamic_axes,
        opset_version=14,
    )

    if args.quantize:
        quantized_save_path = os.path.join(args.save_path, "model_quantized.onnx")
        # Quantize the ONNX model weights to 8-bit integers to shrink the
        # file and speed up CPU inference.
        print("Quantizing the model...")
        quantize_dynamic(
            onnx_save_path,       # input model
            quantized_save_path,  # output model
            weight_type=QuantType.QUInt8,
        )

    print("Done!")
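
# --- Example usage (illustrative; the script filename is an assumption) ---
# Convert and quantize, assuming this file is saved as convert_to_onnx.py:
#
#   python convert_to_onnx.py --model_path logs/model_12000 --save_path model/
#   python convert_to_onnx.py --no-quantize   # skip the quantization step
#
# A hedged sketch of loading the exported model back through gliner's ONNX
# support; the `load_onnx_model` / `onnx_model_file` keyword names may differ
# across gliner versions, so check the release you have installed:
#
#   from gliner import GLiNER
#   model = GLiNER.from_pretrained(
#       "model/",
#       load_onnx_model=True,
#       load_tokenizer=True,
#       onnx_model_file="model_quantized.onnx",  # or "model.onnx"
#   )
#   print(model.predict_entities("ONNX is an open-source format.", ["format"]))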