diff --git a/pypots/classification/grud/data.py b/pypots/classification/grud/data.py index fc23132e..3287a6f6 100644 --- a/pypots/classification/grud/data.py +++ b/pypots/classification/grud/data.py @@ -63,6 +63,8 @@ def __init__( self.empirical_mean = torch.sum( self.missing_mask * self.X, dim=[0, 1] ) / torch.sum(self.missing_mask, dim=[0, 1]) + # fill nan with 0, in case some features have no observations + self.empirical_mean = torch.nan_to_num(self.empirical_mean, 0) def _fetch_data_from_array(self, idx: int) -> Iterable: """Fetch data according to index.