Add option to use scheduled sampling in CopyNet #309
Conversation
Thanks @JohnGiorgi! I think this is a good addition. I just have a couple of suggestions:
- I think we should have a test
- I think the default for `scheduled_sampling_ratio` should be `None`. And then when it is `None` we shouldn't call `torch.rand()`. That way there is no performance penalty for this feature.
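
A rough sketch of what the second suggestion could look like (hypothetical helper name; as the snippets later in this thread show, the PR ultimately keeps a float default of 0.0 and checks `> 0.0` instead):

```python
import torch
from typing import Optional


# Hypothetical sketch only: with a default of None, torch.rand() is never
# called when scheduled sampling is disabled, so the feature adds no
# per-timestep overhead.
def should_use_predictions(
    scheduled_sampling_ratio: Optional[float], training: bool
) -> bool:
    if not training or scheduled_sampling_ratio is None:
        return False
    return torch.rand(1).item() < scheduled_sampling_ratio
```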
@epwalsh Awesome, thanks for the feedback.
I could also update
# Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio
# during training.
# shape: (batch_size,)
input_choices = last_predictions
Ah, I realized this implementation doesn't work, because `last_predictions` is never updated. I would have had to take the index of the token with the highest probability for this timestep under the model. Something like:

`last_predictions = torch.max(torch.cat((generation_scores, copy_scores), -1), -1)`

@epwalsh does this make sense?
Hmm, yeup, good catch. To avoid duplicate computation you could use `all_scores` from the `_get_ll_contrib()` method. And note that you will need to take into account this mask. So I suggest returning `all_scores` and `mask` from `_get_ll_contrib` so you can use them here.
Gotcha. Could I just return `log_probs` from `_get_ll_contrib()`? It's computed like: `log_probs = util.masked_log_softmax(all_scores, mask)`.
Yes, good point.
Awesome, just pushed that change.
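
For readers following along, a minimal, self-contained sketch of the approach settled on above: `_get_ll_contrib()` additionally returns `log_probs` (the masked log-softmax over the concatenated generation and copy scores), and the argmax of those log-probabilities becomes `last_predictions` for the next timestep. The shapes, mask construction, and dummy tensors below are assumptions for illustration, not the model's actual code.

```python
import torch
from allennlp.nn import util

# Assumed sizes, for illustration only.
batch_size, target_vocab_size, source_length = 4, 10, 7

generation_scores = torch.randn(batch_size, target_vocab_size)
copy_scores = torch.randn(batch_size, source_length)
copy_mask = torch.ones(batch_size, source_length, dtype=torch.bool)

# Concatenate generation and copy scores and build the matching mask,
# mirroring what this thread describes _get_ll_contrib() as doing.
all_scores = torch.cat((generation_scores, copy_scores), dim=-1)
mask = torch.cat(
    (torch.ones(batch_size, target_vocab_size, dtype=torch.bool), copy_mask),
    dim=-1,
)

# The tensor proposed above as an additional return value of _get_ll_contrib().
log_probs = util.masked_log_softmax(all_scores, mask)

# shape: (batch_size,) -- the model's prediction for this timestep, fed back in
# as the next decoder input when scheduled sampling chooses predictions over
# gold tokens.
last_predictions = torch.argmax(log_probs, dim=-1)
```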
def test_model_can_train_with_scheduled_sampling_ratio(self):
    train_model_from_file(
        self.param_file,
        self.TEST_DIR,
        overrides="{'model.scheduled_sampling_ratio':0.5}",
    )
@epwalsh Added the same test for scheduled sampling to `simple_seq2seq`.
if (
    self.training
    and self._scheduled_sampling_ratio > 0.0
    and torch.rand(1).item() < self._scheduled_sampling_ratio
):
@epwalsh Added a similar condition to `simple_seq2seq` to avoid the call to `torch.rand` when `_scheduled_sampling_ratio` is `0.0`.
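
To make the combined behavior concrete, here is a small, hypothetical helper that captures the pattern in the snippet above: during training, with probability `scheduled_sampling_ratio`, feed the model's own predictions back in; otherwise use the gold tokens, and skip the `torch.rand` call entirely when the ratio is 0.0. The function name and signature are illustrative, not code from the PR.

```python
import torch


def choose_decoder_inputs(
    gold_tokens: torch.Tensor,       # shape: (batch_size,)
    last_predictions: torch.Tensor,  # shape: (batch_size,)
    scheduled_sampling_ratio: float,
    training: bool,
) -> torch.Tensor:
    # Only sample during training, and only draw a random number when the
    # ratio is non-zero, so the feature costs nothing when disabled.
    if (
        training
        and scheduled_sampling_ratio > 0.0
        and torch.rand(1).item() < scheduled_sampling_ratio
    ):
        return last_predictions
    # Teacher forcing: feed the gold target tokens.
    return gold_tokens
```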
This LGTM! Can you just update the CHANGELOG? Then I think this is good to go.
Cool! Changelog updated 👍
This PR adds the ability to use scheduled sampling in `CopyNetSeq2Seq` by supplying an argument for `scheduled_sampling_ratio` that's greater than zero. It is essentially a copy/paste from `SimpleSeq2Seq`.

This helps reduce the differences in the `SimpleSeq2Seq` and `CopyNetSeq2Seq` model arguments. It is also backwards compatible with a default value of 0 (no scheduled sampling, i.e. teacher forcing).
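
For reference, enabling the new option from an existing CopyNet training config could look roughly like this (the config path, output directory, and 0.25 value are placeholders, following the pattern used in the test above):

```python
from allennlp.commands.train import train_model_from_file

# Hypothetical usage: override the new scheduled_sampling_ratio argument of an
# existing CopyNetSeq2Seq config without editing the config file itself.
train_model_from_file(
    "experiments/copynet_seq2seq.jsonnet",   # placeholder config path
    "/tmp/copynet_scheduled_sampling",       # placeholder serialization dir
    overrides="{'model.scheduled_sampling_ratio': 0.25}",
)
```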