Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compression API supports strategy QAT #3271

Merged
merged 7 commits into from
Oct 17, 2022

Conversation

LiuChiachi
Copy link
Contributor

@LiuChiachi LiuChiachi commented Sep 14, 2022

PR types

New features

PR changes

APIs

Description

Compression API supports strategy QAT
DONE:

  • 支持QAT
  • 优化了compress()接口,不需要输入custom_dynabert_calc_loss,loss function直接提供在Trainer 初始化时;
  • 更新了文档

Usage:

cd PaddleNLP/model_zoo/ernie-3.0

python compress_seq_cls.py \
    --dataset   "clue cluewsc2020"   \
    --model_name_or_path ernie-3.0-nano-zh \
    --per_device_train_batch_size 32 \
    --output_dir ./test \
    --per_device_eval_batch_size 32 \
    --num_train_epochs 5 \
    --width_mult_list 2/3 \
    --strategy 'qat' \
    --batch_size_list 4 \
    --algo_list 'abs_max' \



python compress_token_cls.py      --dataset   "msra_ner"   \
    --model_name_or_path best_models/MSRA_NER/   \
    --output_dir ./  --remove_unused_columns False   \
    --max_seq_length 32    \
    --per_device_train_batch_size 32   \
    --per_device_eval_batch_size 32  \
    --learning_rate 0.00005    \
    --remove_unused_columns False  \
    --num_train_epochs 1 \
    --batch_size_list 4 \
    --algo_list 'abs_max' \
    --strategy 'qat'


python compress_qa.py \
    --dataset "clue cmrc2018" \
    --width_mult_list 2/3 \
    --model_name_or_path best_models/CMRC2018  \
    --output_dir ./ \
    --max_seq_length 32 \
    --learning_rate 0.00003 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 24 \
    --per_device_eval_batch_size 24 \
    --max_answer_length 50 \
    --strategy 'qat' \
    --batch_size_list 4 \
    --algo_list 'abs_max' \

@LiuChiachi LiuChiachi force-pushed the support-qat branch 3 times, most recently from d8c9cbb to 6af10e8 Compare September 29, 2022 12:49
@LiuChiachi LiuChiachi marked this pull request as ready for review September 29, 2022 12:49
@LiuChiachi LiuChiachi requested a review from wawltor October 14, 2022 04:10
batch.pop("length")
if "seq_len" in batch:
batch.pop("seq_len")
elif "start_positions" in batch and "end_positions" in batch:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种分支代码太多,是不是可以通过配置List来解决

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感谢提醒,已经修改

dtype="int64") # input_ids
]

input_spec = generate_input_spec(self.model, self.train_dataset)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看了一下函数中的写法,不考 start_positions 和 end_positions 这个应该是针对UIE写法,不过 start_positions 和 end_positions 这两个字段也不会出现在UIE的模型中

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

start_positions和end_positions这个在qa任务和UIE里都有,在原文中抽取的任务会有的。这个函数是想通过forward的参数、dataloader的数据来判断input_spec的个数,需要排除掉labels/start_positions和end_positions

Copy link
Collaborator

@wawltor wawltor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@LiuChiachi LiuChiachi merged commit 217a25c into PaddlePaddle:develop Oct 17, 2022
joey12300 added a commit to joey12300/PaddleNLP that referenced this pull request Oct 18, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants