-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Compression API supports strategy QAT #3271
Conversation
d8c9cbb
to
6af10e8
Compare
6af10e8
to
96b116e
Compare
…nto support-qat
8d89e0d
to
18fac4a
Compare
batch.pop("length") | ||
if "seq_len" in batch: | ||
batch.pop("seq_len") | ||
elif "start_positions" in batch and "end_positions" in batch: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这种分支代码太多,是不是可以通过配置List来解决
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢提醒,已经修改
dtype="int64") # input_ids | ||
] | ||
|
||
input_spec = generate_input_spec(self.model, self.train_dataset) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看了一下函数中的写法,不考 start_positions 和 end_positions 这个应该是针对UIE写法,不过 start_positions 和 end_positions 这两个字段也不会出现在UIE的模型中
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
start_positions和end_positions这个在qa任务和UIE里都有,在原文中抽取的任务会有的。这个函数是想通过forward的参数、dataloader的数据来判断input_spec的个数,需要排除掉labels/start_positions和end_positions
…nto support-qat
2beaa20
to
fb219bc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This reverts commit 217a25c.
PR types
New features
PR changes
APIs
Description
Compression API supports strategy QAT
DONE:
custom_dynabert_calc_loss
,loss function直接提供在Trainer 初始化时;Usage: