
[🐛BUG] DIN model: dimension mismatch in interaction.repeat when handling token_seq fields during negative sampling #930

Closed
transposition opened this issue Aug 16, 2021 · 1 comment
Labels: bug (Something isn't working)


@transposition

Describe the bug

def repeat(self, sizes):
    """Repeats each tensor along the batch dim.
    Args:
        sizes (int): repeat times.
    Example:
        >>> a = Interaction({'k': torch.zeros(4)})
        >>> a.repeat(3)
        The batch_size of interaction: 12
            k, torch.Size([12]), cpu
        >>> a = Interaction({'k': torch.zeros(4, 7)})
        >>> a.repeat(3)
        The batch_size of interaction: 12
            k, torch.Size([12, 7]), cpu
    Returns:
        a copied Interaction object with repeated Tensors.
    """
    ret = {}
    for k in self.interaction:
        if len(self.interaction[k].shape) == 1:
            ret[k] = self.interaction[k].repeat(sizes)
        else:
            # Assumes at most 2-D tensors: for a 3-D tensor this repeat list is too short.
            ret[k] = self.interaction[k].repeat([sizes, 1])
    new_pos_len_list = self.pos_len_list * sizes if self.pos_len_list else None
    new_user_len_list = self.user_len_list * sizes if self.user_len_list else None
    return Interaction(ret, new_pos_len_list, new_user_len_list)
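For token_seq fields, self.interaction[k] is three-dimensional (batch size, hist length, seq size) rather than two-dimensional, so the else branch, which passes a two-element repeat list, raises a dimension-mismatch error: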
if len(self.interaction[k].shape) == 1:
    ret[k] = self.interaction[k].repeat(sizes)
else:
    ret[k] = self.interaction[k].repeat([sizes, 1])
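A minimal reproduction of the failure, using plain tensors instead of an Interaction object. The shapes (batch size 4, hist length 5, seq size 3) are hypothetical and chosen only for illustration; the point is that torch.Tensor.repeat needs at least as many repeat sizes as the tensor has dimensions.

```python
import torch

# Hypothetical shapes: batch_size=4, hist_length=5, seq_size=3.
field_2d = torch.zeros(4, 7)
field_3d = torch.zeros(4, 5, 3, dtype=torch.long)

# The else-branch call works for a 2-D field: the batch dim is tiled 3 times.
print(field_2d.repeat([3, 1]).shape)  # torch.Size([12, 7])

# The same call on a 3-D token_seq field fails, because Tensor.repeat needs
# at least as many repeat sizes as the tensor has dimensions.
try:
    field_3d.repeat([3, 1])
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```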

Expected behavior

if len(self.interaction[k].shape) == 1:
    ret[k] = self.interaction[k].repeat(sizes)
else:
    ret[k] = self.interaction[k].repeat([sizes] + [1] * (len(self.interaction[k].shape) - 1))
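A standalone sketch of the proposed repeat list applied to a dict of tensors with 1, 2 and 3 dimensions. repeat_batch and the field names are hypothetical, not part of RecBole's API, and this is not necessarily how #933 implements the fix.

```python
import torch


def repeat_batch(fields, sizes):
    """Repeat every tensor in `fields` `sizes` times along dim 0.

    Hypothetical helper illustrating the proposed fix; not RecBole's API.
    """
    ret = {}
    for k, t in fields.items():
        # Tile the batch dim `sizes` times and keep every trailing dim as-is.
        ret[k] = t.repeat([sizes] + [1] * (t.dim() - 1))
    return ret


fields = {
    "field_1d": torch.arange(4),         # (batch,)
    "field_2d": torch.zeros(4, 7),       # (batch, dim)
    "field_3d": torch.zeros(4, 5, 3),    # (batch, hist_length, seq_size)
}
out = repeat_batch(fields, 3)
for k, t in out.items():
    print(k, tuple(t.shape))
```

Because [sizes] + [1] * 0 is just [sizes], the same expression also covers 1-D tensors, so in this sketch the separate len(shape) == 1 branch is not strictly needed. Note that repeat tiles the whole batch (rows 0..3, then 0..3 again), which matches the docstring's batch_size of 12.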
transposition added the bug (Something isn't working) label on Aug 16, 2021
chenyushuo added a commit to chenyushuo/RecBole that referenced this issue Aug 17, 2021
@chenyushuo
Collaborator

chenyushuo commented Aug 17, 2021

Thank you for finding this bug; we fixed it in #933.

@chenyushuo chenyushuo self-assigned this Aug 17, 2021
2017pxy added a commit that referenced this issue Aug 17, 2021