Skip to content

Commit

Permalink
fix: turn missing_mask into torch.float;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Sep 21, 2023
1 parent acfe255 commit 18945e8
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pypots/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
"""

X = self.X[idx].to(torch.float32)
missing_mask = ~torch.isnan(X)
missing_mask = (~torch.isnan(X)).to(torch.float32)
X = torch.nan_to_num(X)
sample = [
torch.tensor(idx),
Expand Down Expand Up @@ -280,7 +280,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
self.file_handle = self._open_file_handle()

X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32)
missing_mask = ~torch.isnan(X)
missing_mask = (~torch.isnan(X)).to(torch.float32)
X = torch.nan_to_num(X)
sample = [
torch.tensor(idx),
Expand Down
2 changes: 1 addition & 1 deletion pypots/imputation/gpvae/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
self.file_handle = self._open_file_handle()

X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32)
missing_mask = ~torch.isnan(X)
missing_mask = (~torch.isnan(X)).to(torch.float32)
X = torch.nan_to_num(X)

sample = [
Expand Down
2 changes: 1 addition & 1 deletion tests/imputation/gpvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_2_parameters(self):
and self.gp_vae.best_model_dict is not None
)

@pytest.mark.xdist_group(name="imputation-GPVAE")
@pytest.mark.xdist_group(name="imputation-gpvae")
def test_3_saving_path(self):
# whether the root saving dir exists, which should be created by save_log_into_tb_file
assert os.path.exists(
Expand Down

0 comments on commit 18945e8

Please sign in to comment.