Skip to content

Commit

Permalink
fix: fill missing values in empirical_mean with 0, make grud more rob…
Browse files Browse the repository at this point in the history
…ust;
  • Loading branch information
WenjieDu committed Jun 26, 2024
1 parent b3c419a commit 51771c4
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pypots/classification/grud/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(
self.empirical_mean = torch.sum(
self.missing_mask * self.X, dim=[0, 1]
) / torch.sum(self.missing_mask, dim=[0, 1])
# fill nan with 0, in case some features have no observations
self.empirical_mean = torch.nan_to_num(self.empirical_mean, 0)

def _fetch_data_from_array(self, idx: int) -> Iterable:
"""Fetch data according to index.
Expand Down

0 comments on commit 51771c4

Please sign in to comment.