Describe the bug
```python
def repeat(self, sizes):
    """Repeats each tensor along the batch dim.

    Args:
        sizes (int): repeat times.

    Example:
        >>> a = Interaction({'k': torch.zeros(4)})
        >>> a.repeat(3)
        The batch_size of interaction: 12
        k, torch.Size([12]), cpu

        >>> a = Interaction({'k': torch.zeros(4, 7)})
        >>> a.repeat(3)
        The batch_size of interaction: 12
        k, torch.Size([12, 7]), cpu

    Returns:
        a copyed Interaction object with repeated Tensors.
    """
    ret = {}
    for k in self.interaction:
        if len(self.interaction[k].shape) == 1:
            ret[k] = self.interaction[k].repeat(sizes)
        else:
            ret[k] = self.interaction[k].repeat([sizes, 1])
    new_pos_len_list = self.pos_len_list * sizes if self.pos_len_list else None
    new_user_len_list = self.user_len_list * sizes if self.user_len_list else None
    return Interaction(ret, new_pos_len_list, new_user_len_list)
```
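For `token_seq` fields, `self.interaction[k]` is three-dimensional (batch size, hist length, seq size), not two-dimensional. The `else` branch below therefore passes only two repeat factors for a three-dimensional tensor, which `torch.Tensor.repeat` rejects, since it requires at least as many repeat factors as the tensor has dimensions:

```python
if len(self.interaction[k].shape) == 1:
    ret[k] = self.interaction[k].repeat(sizes)
else:
    ret[k] = self.interaction[k].repeat([sizes, 1])
```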
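To make the failure concrete without needing a RecBole setup, the sketch below models the shape rule of `torch.Tensor.repeat` in plain Python: the number of repeat factors must be at least the tensor's rank, and any extra factors add leading dimensions. The helpers `repeated_shape`, `old_repeat_args` and `new_repeat_args` are hypothetical names used only for illustration; they mirror the original branch and the proposed fix, and do not call torch itself.

```python
from typing import List, Sequence, Tuple


def repeated_shape(shape: Sequence[int], repeats: Sequence[int]) -> Tuple[int, ...]:
    """Model the output shape of Tensor.repeat(repeats) for a tensor of `shape`.

    Assumption: mirrors torch's documented rule. Fewer repeat factors than
    dimensions is an error; extra factors act on new leading size-1 dims.
    """
    if len(repeats) < len(shape):
        raise RuntimeError("fewer repeat factors than tensor dimensions")
    # Pad the shape with leading 1s so it lines up with `repeats`.
    padded = [1] * (len(repeats) - len(shape)) + list(shape)
    return tuple(d * r for d, r in zip(padded, repeats))


def old_repeat_args(shape: Sequence[int], sizes: int) -> List[int]:
    # Original logic: 1-D -> [sizes], any other rank -> [sizes, 1]
    return [sizes] if len(shape) == 1 else [sizes, 1]


def new_repeat_args(shape: Sequence[int], sizes: int) -> List[int]:
    # Proposed logic: repeat dim 0, keep every other dimension unchanged
    return [sizes] + [1] * (len(shape) - 1)


# A token_seq field: (batch size, hist length, seq size)
token_seq_shape = (4, 50, 5)
print(repeated_shape(token_seq_shape, new_repeat_args(token_seq_shape, 3)))
```

With the original arguments, `repeated_shape((4, 50, 5), [3, 1])` raises, matching the reported bug; the 1-D and 2-D cases from the docstring still work either way.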
Expected behavior
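The `else` branch should repeat along the batch dimension and keep every remaining dimension unchanged, whatever the tensor's rank:

```python
if len(self.interaction[k].shape) == 1:
    ret[k] = self.interaction[k].repeat(sizes)
else:
    # One repeat factor per dimension: `sizes` for the batch dim, 1 for the rest
    ret[k] = self.interaction[k].repeat([sizes] + [1] * (len(self.interaction[k].shape) - 1))
```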
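The proposed expression also covers the 1-D case, because `[sizes] + [1] * 0` is just `[sizes]`, so the `if` branch could be folded in. The standalone sketch below applies the same idea to a plain dict of tensors; `repeat_along_batch` is a hypothetical helper, not part of RecBole's `Interaction` API, and it leaves out the `pos_len_list`/`user_len_list` handling.

```python
from typing import Dict

import torch


def repeat_along_batch(tensors: Dict[str, torch.Tensor], sizes: int) -> Dict[str, torch.Tensor]:
    """Tile every tensor `sizes` times along dim 0, leaving other dims unchanged."""
    out = {}
    for key, t in tensors.items():
        # `sizes` for the batch dimension, 1 for every remaining dimension.
        # For a 1-D tensor this reduces to t.repeat(sizes).
        out[key] = t.repeat(sizes, *([1] * (t.dim() - 1)))
    return out


batch = {
    "user_id": torch.arange(4),
    "feat": torch.zeros(4, 7),
    # token_seq-like field: (batch size, hist length, seq size)
    "item_token_seq": torch.zeros(4, 50, 5, dtype=torch.long),
}
repeated = repeat_along_batch(batch, 3)
for name, value in repeated.items():
    print(name, tuple(value.shape))
```

Note that `Tensor.repeat` tiles the whole batch, so `user_id` becomes `0, 1, 2, 3, 0, 1, 2, 3, ...` rather than repeating each row consecutively (which is what `repeat_interleave` would do).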
FIX: fix issue RUCAIBox#930. (commit eac3f02)
Thanks for finding this bug; we fixed it in #933.
Merge pull request #933 from chenyushuo/data (commit 487c8ab): FIX: fix issue #930. (chenyushuo)