Add Embedding quantization (#4159)
* add quant emb

* support quant embeddings

* remove useless log
LiuChiachi authored Dec 29, 2022
1 parent 6f5c287 commit 5d542a2
Showing 1 changed file with 63 additions and 11 deletions.
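With this change, an "embeddings" sub-strategy can be chained after "ptq" or "qat" in the compression strategy string: the PTQ step now returns the directories it wrote quantized models to, and the embeddings step then quantizes the embedding tables of each saved inference model. A minimal usage sketch, not taken from the commit (the strategy string and the `model`/`train_ds`/`eval_ds` placeholders are assumptions based on the substring checks in the diff below):

    # Hypothetical invocation; placeholders stand in for a real model/datasets.
    from paddlenlp.trainer import CompressionArguments, Trainer

    model = ...      # assumed: a PaddleNLP PretrainedModel
    train_ds = ...   # assumed: training dataset
    eval_ds = ...    # assumed: evaluation dataset

    compression_args = CompressionArguments(
        output_dir="compress",
        strategy="dynabert+ptq+embeddings",  # assumption: "embeddings" appended to the strategy
    )
    trainer = Trainer(model=model, args=compression_args, train_dataset=train_ds, eval_dataset=eval_ds)
    trainer.compress()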
paddlenlp/trainer/trainer_compress.py
@@ -70,36 +70,50 @@ def compress(self, custom_evaluate=None):
         _dynabert(self, self.model, args.output_dir)
         if "ptq" in args.strategy:
             self.args.input_filename_prefix = "pruned_model"
+            output_dir_list = []
             for width_mult in args.width_mult_list:
                 output_dir_width = os.path.join(args.output_dir, "width_mult_" + str(round(width_mult, 2)))
-                self.quant(output_dir_width, "ptq")
-    elif args.strategy == "ptq":
-        # Input model is an inference model
+                output_dir_list += self.quant(output_dir_width, "ptq")
+            if "embeddings" in args.strategy:
+                for output_dir in output_dir_list:
+                    self.quant(os.path.join(output_dir, args.output_filename_prefix), "embeddings")
+    elif "ptq" in args.strategy:
+        # When the input model is an inference model
         if args.input_infer_model_path is not None:
            model_dir = os.path.dirname(args.input_infer_model_path)
            self.args.input_filename_prefix = os.path.basename(args.input_infer_model_path)
-            self.quant(model_dir, args.strategy)
-        # Input model is load from Trainer API in dygraph.
+            output_dir_list = self.quant(model_dir, "ptq")
         else:
+            # When the input model is a dygraph model,
+            # export it and then do 'ptq'.
             # Prefix of `export_model` is 'model'
             self.args.input_filename_prefix = "model"
             input_spec = generate_input_spec(self.model, self.train_dataset)
             input_dir = args.output_dir
             export_model(model=self.model, input_spec=input_spec, path=input_dir)
-            self.quant(input_dir, args.strategy)
-    elif args.strategy == "qat":
+            output_dir_list = self.quant(input_dir, "ptq")
+        if "embeddings" in args.strategy:
+            for output_dir in output_dir_list:
+                self.quant(os.path.join(output_dir, args.output_filename_prefix), "embeddings")
+    elif "qat" in args.strategy:
         global_try_import_slim()
-        self.quant(args.output_dir, args.strategy)
+        self.quant(args.output_dir, "qat")
+        if "embeddings" in args.strategy:
+            self.quant(os.path.join(args.output_dir, args.output_filename_prefix), "embeddings")
 
 
 def quant(self, model_dir, strategy):
     """
-    Supports Post-Training Quantization now.
+    Supports Post-Training Quantization, Quantization Aware Training and
+    Embedding Quantization.
     """
     if strategy == "ptq":
-        _post_training_quantization_grid_search(self, model_dir)
+        return _post_training_quantization_grid_search(self, model_dir)
     elif strategy == "qat":
         _quant_aware_training_dynamic(self, model_dir)
+    elif strategy == "embeddings":
+        _quant_embeddings(self, model_dir)
 
 
 def generate_input_spec(model, dataset):
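The dispatch above keys each stage off substring containment in `args.strategy`, and the new `return` in `quant` is what feeds the "embeddings" pass its inputs: the "ptq" stage returns the list of directories it saved models into, and "embeddings" runs once per directory. A stripped-down, runnable sketch of that contract (`run_strategy` and its arguments are hypothetical, not PaddleNLP APIs):

    def run_strategy(strategy, quant):
        output_dir_list = []
        if "ptq" in strategy:
            output_dir_list = quant("exported_model_dir", "ptq")  # returns saved-model dirs
        if "embeddings" in strategy:
            for output_dir in output_dir_list:
                quant(output_dir, "embeddings")  # quantize embeddings of each saved model

    # "ptq" yields one dir per grid-search combination; "embeddings" consumes them.
    run_strategy("ptq+embeddings", lambda d, s: [d + "/mse4_1"] if s == "ptq" else print("quant embeddings in", d))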
@@ -138,7 +152,7 @@ def _dynabert(self, model, output_dir):
     ofa_model = _dynabert_training(
         self, ofa_model, model, teacher_model, train_dataloader, eval_dataloader, args.num_train_epochs
     )
-
+    self.reset_optimizer_and_scheduler()
     # Each width_mult best model would be exported.
     _dynabert_export(self, ofa_model)
 
@@ -540,6 +554,7 @@ def _post_training_quantization_grid_search(self, model_dir):
     exe = paddle.static.Executor(place)
 
     args.output_filename_prefix = "int8"
+    output_dir_list = []
 
     def _post_training_quantization(algo, batch_size, batch_nums):
         try:
@@ -587,11 +602,13 @@ def _batch_generator_func():
             optimize_model=False,
         )
         post_training_quantization.quantize()
+        save_model_path = os.path.join(model_dir, algo + "_".join([str(batch_size), str(batch_nums)]))
         post_training_quantization.save_quantized_model(
-            save_model_path=os.path.join(model_dir, algo + "_".join([str(batch_size), str(batch_nums)])),
+            save_model_path=save_model_path,
             model_filename=args.output_filename_prefix + ".pdmodel",
             params_filename=args.output_filename_prefix + ".pdiparams",
         )
+        output_dir_list.append(save_model_path)
 
     logger.info("Post training quantization starts.")
     for algo in args.algo_list:
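A worked example of the save-path naming introduced above: the algorithm name is concatenated with "<batch_size>_<batch_nums>", so the grid search writes one sibling directory per combination (the directory and algo values below are illustrative):

    import os

    model_dir = "compress/width_mult_0.75"  # assumed PTQ input directory
    batch_size, batch_nums = 4, 1
    for algo in ["mse", "avg", "hist"]:  # example algo_list entries
        save_model_path = os.path.join(model_dir, algo + "_".join([str(batch_size), str(batch_nums)]))
        print(save_model_path)
    # compress/width_mult_0.75/mse4_1
    # compress/width_mult_0.75/avg4_1
    # compress/width_mult_0.75/hist4_1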
@@ -601,6 +618,7 @@ def _batch_generator_func():
 
     paddle.disable_static()
     logger.info("Post training quantization ends and quantized models are saved.")
+    return output_dir_list
 
 
 def _quant_aware_training_dynamic(self, input_dir):
@@ -725,6 +743,35 @@ def _quant_aware_training_dynamic(self, input_dir):
     logger.info("Quant aware training ends and quantized models are saved.")
 
 
+def _quant_embeddings(self, input_prefix):
+    import paddleslim.quant as quant
+
+    self.args.output_filename_prefix = "quant_emb"
+
+    paddle.enable_static()
+    place = paddle.set_device(self.args.device)
+    exe = paddle.static.Executor(place)
+    main_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(input_prefix, exe)
+
+    config = {"quantize_op_types": ["lookup_table_v2"], "lookup_table_v2": {"quantize_type": "log"}}
+
+    quant_emb_program = quant.quant_embedding(main_program, place, config)
+
+    input_dir = os.path.dirname(input_prefix)
+
+    paddle.fluid.io.save_inference_model(
+        input_dir,
+        feed_target_names,
+        fetch_targets,
+        exe,
+        quant_emb_program,
+        model_filename=self.args.output_filename_prefix + ".pdmodel",
+        params_filename=self.args.output_filename_prefix + ".pdiparams",
+        export_for_deployment=True,
+        program_only=False,
+    )
+
+
 def auto_model_dynabert_forward(
     self,
     input_ids,
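`_quant_embeddings` loads a saved inference model, asks PaddleSlim to rewrite its `lookup_table_v2` (embedding lookup) ops with log-scale quantization, and saves the result next to the input model as quant_emb.pdmodel/quant_emb.pdiparams. A minimal sketch of loading the resulting model for inference, mirroring the load call above (the path prefix is a hypothetical example):

    import paddle

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.set_device("cpu"))
    # The prefix resolves to quant_emb.pdmodel / quant_emb.pdiparams.
    program, feed_names, fetch_targets = paddle.static.load_inference_model(
        "compress/width_mult_0.75/mse4_1/quant_emb", exe
    )
    # Run with real inputs: exe.run(program, feed=dict(zip(feed_names, inputs)), fetch_list=fetch_targets)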
@@ -865,5 +912,10 @@ def soft_cross_entropy(inp, target):
     return -1.0 * paddle.mean(paddle.sum(inp_likelihood * target_prob, axis=-1))
 
 
+def reset_optimizer_and_scheduler(self):
+    self.optimizer, self.lr_scheduler = None, None
+
+
 Trainer.compress = compress
 Trainer.quant = quant
+Trainer.reset_optimizer_and_scheduler = reset_optimizer_and_scheduler

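The assignments above attach the module-level functions to `Trainer` as methods, which is why each takes `self` as its first parameter. A self-contained sketch of the pattern (the class below is a stand-in, not PaddleNLP's Trainer):

    class Trainer:  # stand-in for paddlenlp.trainer.Trainer
        pass

    def reset_optimizer_and_scheduler(self):
        self.optimizer, self.lr_scheduler = None, None

    Trainer.reset_optimizer_and_scheduler = reset_optimizer_and_scheduler

    t = Trainer()
    t.reset_optimizer_and_scheduler()  # bound method call: `self` is `t`
    print(t.optimizer, t.lr_scheduler)  # None None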