FEA: Add SGL in General models #1004
Conversation
- ``layers (int)`` : The number of layers in SGL. Defaults to ``3``.
- ``lamb (float)`` : The temperature in softmax. Defaults to ``0.5``.
- ``embedding_dim (int)`` : The embedding size of users and items. Defaults to ``64``.
- ``ratio (float)`` : The dropout ratio. Defaults to ``0.1``.
- ``reg (float)`` : The L2 regularization weight. Defaults to ``1e-05``.
Rename parameters to be consistent with other general models.
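For example, following the naming conventions of the other general models (the targets below are suggestions, not the final names): ``layers`` → ``n_layers``, ``lamb`` → ``ssl_tau``, ``embedding_dim`` → ``embedding_size``, ``ratio`` → ``drop_ratio``.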
self.node_num = self.user_num + self.item_num
self.user_embedding = torch.nn.Embedding(self.n_users, self.embed_dim, device=self.device)
self.item_embedding = torch.nn.Embedding(self.n_items, self.embed_dim, device=self.device)
self.dataset = dataset
Defined but not used.
self.apply(xavier_uniform_initialization)
self.update = True

def comp(self):
Let the function name make sense, such as graph_augment()?
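As a rough illustration of what a graph_augment() would do, here is a minimal, self-contained sketch of SGL-style edge dropout on the interaction matrix; the function name and signature are assumptions, not the PR's actual API:

```python
import numpy as np
import scipy.sparse as sp

def graph_augment(inter_matrix: sp.csr_matrix, drop_ratio: float) -> sp.csr_matrix:
    """Edge dropout: keep each interaction independently with probability 1 - drop_ratio."""
    rows, cols = inter_matrix.nonzero()
    keep = np.random.rand(len(rows)) >= drop_ratio
    data = np.ones(int(keep.sum()), dtype=np.float32)
    return sp.csr_matrix((data, (rows[keep], cols[keep])), shape=inter_matrix.shape)
```

Calling this twice per epoch would yield the two independently augmented views that the contrastive loss compares.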
#calc user side
u_emd1 = F.normalize(user_sub1[user_list], dim=1)
u_emd2 = F.normalize(user_sub2[user_list], dim=1)
#all_emd2 = F.normalize(self.user_sub2,dim=1)
Remove useless code.
#calc item side
i_emd1 = F.normalize(item_sub1[pos_item_list], dim=1)
i_emd2 = F.normalize(item_sub2[pos_item_list], dim=1)
#all_item = F.normalize(self.item_sub2,dim=1)
Remove useless code.
v1 = torch.sum(u_emd1 * u_emd2, dim=1)
v2 = u_emd1.matmul(u_emd2.T)
v1 = torch.exp(v1 / self.ssl_lamb)
v2 = torch.sum(torch.exp(v2 / self.ssl_lamb), dim=1)
ssl_user = -torch.sum(torch.log(v1 / v2))
According to the original paper, SGL treats the views of all other nodes as negative pairs, not just the other nodes in the batch.
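A minimal sketch of the all-node variant the reviewer describes, assuming user_sub1/user_sub2 are the full embedding tables produced by the two augmented graphs; the helper name info_nce_all_nodes and its signature are hypothetical:

```python
import torch
import torch.nn.functional as F

def info_nce_all_nodes(emb1, emb2, batch_idx, tau):
    """InfoNCE where the second view of *every* node serves as a negative,
    rather than only the views of the other nodes in the batch."""
    z1 = F.normalize(emb1[batch_idx], dim=1)  # anchor views for the batch
    z2 = F.normalize(emb2[batch_idx], dim=1)  # positive views for the batch
    z_all = F.normalize(emb2, dim=1)          # second views of ALL nodes
    pos = torch.exp((z1 * z2).sum(dim=1) / tau)
    neg = torch.exp(z1 @ z_all.T / tau).sum(dim=1)  # [batch, num_nodes] -> [batch]
    return -torch.log(pos / neg).sum()
```

The same helper would cover the item side below, e.g. info_nce_all_nodes(item_sub1, item_sub2, pos_item_list, self.ssl_lamb).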
v3 = torch.sum(i_emd1 * i_emd2, dim=1)
v4 = i_emd1.matmul(i_emd2.T)
v3 = torch.exp(v3 / self.ssl_lamb)
v4 = torch.sum(torch.exp(v4 / self.ssl_lamb), dim=1)
ssl_item = -torch.sum(torch.log(v3 / v4))
Same as above.
self.reg_loss = EmbLoss()
self.train_graph = self.csr2tensor(self.create_adjust_matrix(is_sub=False))
self.apply(xavier_uniform_initialization)
self.update = True
Let the parameter name make sense.
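For instance, the flag could describe what it triggers (graph_needs_update is only a suggested name, not the PR's actual attribute):

```python
# rebuild the two augmented sub-graphs before the next training epoch
self.graph_needs_update = True
```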