diff --git a/rhana/labeler/masker.py b/rhana/labeler/masker.py index c2d1d45..2cf5ed4 100644 --- a/rhana/labeler/masker.py +++ b/rhana/labeler/masker.py @@ -106,6 +106,7 @@ def predict(self, rd, do_rle:bool=False, threshold:bool=0.5): inp = (inp - self.normalize.mean)/self.normalize.std scores = torch.sigmoid(self.learn.model(inp)) masks = scores > threshold + masks = masks[0] # masks = scores > threshold classes = self.learn.classes # classes variable store which label is predicted channel-wise