Skip to content

Commit

Permalink
mypy fix
Browse files Browse the repository at this point in the history
  • Loading branch information
anpolol committed Dec 4, 2023
1 parent 7a9c89c commit fcf5d7a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions stable_gnn/fairness.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,15 @@ def _cuae(self, y_true: List, y_pred: List, sensitive_features: List) -> Dict[st
ans = {"df": df, "diff": total_diff, "ratio": total_ratio, "variation": variation}
return ans

def _zeros_ones_to_classes(self, x: List, length: int = 3):
def _zeros_ones_to_classes(self, x: np.array, length: int = 3) -> np.array:
n = int(len(x) / length)
p = []
for i in range(n):
z = x[i * length : i * length + length]
p.append(z.argmax())
return np.array(p, dtype=int)

def _answer_creator(self, x: List, y: List, grouper: List):
def _answer_creator(self, x: List, y: List, grouper: List) -> np.array:
x = np.array(x) # array of 1
y = np.array(y) # array of 0
grouper = np.array(grouper)
Expand Down
4 changes: 2 additions & 2 deletions stable_gnn/model_gc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import List, Optional, Tuple

import bamt.networks as Nets
import bamt.Networks as Nets
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from bamt.preprocessors import Preprocessor
from bamt.Preprocessors import Preprocessor
from pgmpy.estimators import K2Score
from sklearn import preprocessing
from torch import device
Expand Down

0 comments on commit fcf5d7a

Please sign in to comment.