Skip to content

Commit

Permalink
Fix removed numpy types
Browse files Browse the repository at this point in the history
  • Loading branch information
bwroblew committed Jan 23, 2023
1 parent 0fc7262 commit 15d4c2f
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Fixed the use of types removed in `numpy 1.24.0` ([#6495](https://github.com/pyg-team/pytorch_geometric/pull/6495))
- Fixed keyword parameters in `examples/mnist_voxel_grid.py` ([#6478](https://github.com/pyg-team/pytorch_geometric/pull/6478))
- Unified `LightningNodeData` and `LightningLinkData` code paths ([#6473](https://github.com/pyg-team/pytorch_geometric/pull/6473))
- Allow indices with any integer type in `RGCNConv` ([#6463](https://github.com/pyg-team/pytorch_geometric/pull/6463))
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def index_select(self, idx: IndexType) -> 'Dataset':
elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
return self.index_select(idx.flatten().tolist())

elif isinstance(idx, np.ndarray) and idx.dtype == np.bool:
elif isinstance(idx, np.ndarray) and idx.dtype == bool:
idx = idx.flatten().nonzero()[0]
return self.index_select(idx.flatten().tolist())

Expand Down
3 changes: 1 addition & 2 deletions torch_geometric/datasets/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from collections import Counter
from typing import Callable, List, Optional

import numpy as np
import torch

from torch_geometric.data import (
Expand Down Expand Up @@ -169,7 +168,7 @@ def process(self):
labels_df = pd.read_csv(task_file, sep='\t')
labels_set = set(labels_df[label_header].values.tolist())
labels_dict = {lab: i for i, lab in enumerate(list(labels_set))}
nodes_dict = {np.unicode(key): val for key, val in nodes_dict.items()}
nodes_dict = {str(key): val for key, val in nodes_dict.items()}

train_labels_df = pd.read_csv(train_file, sep='\t')
train_indices, train_labels = [], []
Expand Down

0 comments on commit 15d4c2f

Please sign in to comment.