fix attention mask
gongel committed Sep 9, 2022
1 parent 06038b6 commit f560a6b
7 changes: 2 additions & 5 deletions model_zoo/gpt/dataset.py
@@ -442,17 +442,14 @@ def _construct_sample(self, tokens):
         labels = tokens[1:]
         tokens = tokens[:-1]
         seq_length = len(tokens)
-        # Attention mask for the attention calulate
-        attention_mask = np.tri(seq_length, seq_length).reshape(
-            (1, seq_length, seq_length))
+        # No padding, so attention_mask is None
+        attention_mask = None
 
         # The pad and eos tokens do not contribute the loss
         loss_mask = np.ones(seq_length, dtype="float32")
         loss_mask[np.where(np.array(tokens) == self.eos_id)] = 0.0
         position_ids = np.arange(0, seq_length, dtype="int64")
 
-        attention_mask = (attention_mask - 1.0) * 1e9
-        attention_mask = attention_mask.astype("float32")
         labels = np.array(labels, dtype="int64")
         return [tokens, loss_mask, attention_mask, position_ids, labels]

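For context: the deleted lines built an explicit causal attention mask per sample and converted it to additive form (0.0 for visible positions, -1e9 for masked ones). Because this dataset produces full-length samples with no padding, the mask encodes nothing beyond plain causality, so it can be dropped and presumably regenerated inside the model when `attention_mask` is `None` (an assumption about the model side; the diff only shows the dataset). A minimal sketch of what the removed lines computed, with a helper name of our own choosing:

```python
import numpy as np

def causal_additive_mask(seq_length: int) -> np.ndarray:
    # Lower-triangular matrix: 1.0 where position i may attend to position j (j <= i).
    mask = np.tri(seq_length, seq_length).reshape((1, seq_length, seq_length))
    # Additive form: 0.0 on visible positions, -1e9 on masked ones, so the
    # result can be added to the attention logits before the softmax.
    return ((mask - 1.0) * 1e9).astype("float32")

# causal_additive_mask(3)[0]:
# [[ 0.e+00 -1.e+09 -1.e+09]
#  [ 0.e+00  0.e+00 -1.e+09]
#  [ 0.e+00  0.e+00  0.e+00]]
```

Besides removing redundant work, passing `None` avoids materializing a float32 array of shape `(1, seq_length, seq_length)` for every sample; at seq_length = 1024, for example, that is about 4 MB per sample.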
