Some questions about 3DInfomax #13

Open
happyzhanglol opened this issue Jan 26, 2023 · 1 comment

@happyzhanglol

Dear professor, I have some questions about 3DInfomax.
I want to compute evaluation metrics such as precision, so I adapted the functions you provide in metric.py, such as TruePositiveRate() and TrueNegativeRate(). I tried all the OGB datasets and found that precision, accuracy, and recall were all far from what I expected. I hope you can reply as soon as possible. Thank you, professor.

Here are the metrics on the HIV dataset:
Precision: 0.008995866402983665
Accuracy: 0.9988852739334106
Recall: 0.002496626228094101
F1_score: 0.003908519633114338
ROC_AUC: 0.7427065372467041
PR_AUC: 0.2141391634941101
ogbg-molhiv: 0.742706502636204
BCEWithLogitsLoss: 0.17792926660992883

Here are the metrics on the BBBP dataset:
Precision: 0.44607841968536377
Accuracy: 0.6127931475639343
Recall: 0.005654983688145876
F1_score: 0.011168383993208408
ROC_AUC: 0.6745756268501282
PR_AUC: 0.6546612977981567
ogbg-molbbbp: 0.6745756172839505
BCEWithLogitsLoss: 1.1453146849359785
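
As a quick sanity check on the numbers above: F1 is the harmonic mean of precision and recall, so it can be recomputed from the logged values. The sketch below does this in plain Python; the helper name `f1_from` is made up for illustration and is not part of the 3DInfomax code.

```python
def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0.0 if both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Precision and recall copied from the logs above.
hiv_f1 = f1_from(0.008995866402983665, 0.002496626228094101)
bbbp_f1 = f1_from(0.44607841968536377, 0.005654983688145876)
print(hiv_f1, bbbp_f1)
```

Both results agree with the logged F1_score values to several decimal places, so the reported numbers are at least internally consistent: the low F1 follows directly from the very low recall.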

Here is my metric code:
import torch
from torch import nn, Tensor


class Precision(nn.Module):
    def __init__(self, threshold=0.5) -> None:
        super(Precision, self).__init__()
        self.threshold = threshold

    def forward(self, x1: Tensor, x2: Tensor, pos_mask: Tensor = None) -> Tensor:
        batch_size, _ = x1.size()
        if x1.shape != x2.shape and pos_mask is None:
            x2 = x2[:batch_size]
        # cosine similarity between every row of x1 and every row of x2
        sim_matrix = torch.einsum('ik,jk->ij', x1, x2)
        x1_abs = x1.norm(dim=1)
        x2_abs = x2.norm(dim=1)
        sim_matrix = sim_matrix / torch.einsum('i,j->ij', x1_abs, x2_abs)

        # map the similarity from [-1, 1] to [0, 1], then threshold it
        preds: Tensor = (sim_matrix + 1) / 2 > self.threshold
        if pos_mask is None:  # if we are comparing global with global
            pos_mask = torch.eye(batch_size, device=x1.device)
        # defined outside the branch so it also exists when pos_mask is passed in
        neg_mask = 1 - pos_mask

        num_positives = len(x1)
        num_negatives = len(x1) * (len(x2) - 1)

        # positive pairs predicted as negative (named false_positives in my first version)
        false_negatives = ((preds.long() - pos_mask) * pos_mask).count_nonzero()
        true_positives = num_positives - false_negatives

        # negative pairs predicted as positive (named false_negatives in my first version)
        false_positives = (((~preds).long() - neg_mask) * neg_mask).count_nonzero()
        true_negatives = num_negatives - false_positives

        pre = true_positives / (true_positives + false_positives)
        return pre
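
For comparison, here is a minimal sketch of how precision and recall are usually computed for a binary classification task such as ogbg-molhiv: from the model's logits and the ground-truth labels, not from a similarity matrix between two embedding batches as in the class above. It is written in plain Python to stay self-contained. The function name `binary_precision_recall` and the toy inputs are hypothetical and not part of the 3DInfomax code base, and whether this mismatch explains the numbers above is an assumption, not a confirmed diagnosis.

```python
import math
from typing import Sequence, Tuple


def binary_precision_recall(logits: Sequence[float], labels: Sequence[int],
                            threshold: float = 0.5) -> Tuple[float, float]:
    """Precision and recall from raw logits and 0/1 labels."""
    tp = fp = fn = 0
    for logit, label in zip(logits, labels):
        # Plain sigmoid; adequate for the moderate logits in this sketch.
        prob = 1.0 / (1.0 + math.exp(-logit))
        pred = prob > threshold
        if pred and label == 1:
            tp += 1  # predicted positive, actually positive
        elif pred and label == 0:
            fp += 1  # predicted positive, actually negative
        elif not pred and label == 1:
            fn += 1  # predicted negative, actually positive
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Toy data: two true positives, one false positive, one false negative.
logits = [2.0, 1.5, 0.7, -1.0, -3.0]
labels = [1, 1, 0, 1, 0]
precision, recall = binary_precision_recall(logits, labels)
print(precision, recall)
```

On heavily imbalanced data such as HIV, thresholded metrics like these depend strongly on the chosen threshold, which is presumably one reason the OGB score logged above for ogbg-molhiv is the threshold-free ROC_AUC.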
@A-Gentle-Cat

Hello, I also tried to evaluate the model on bbbp, but the result was even worse than yours. May I take a look at your tune_bbbp.yml file?
