Fixed PolBlogs dataset expected data format (#6715)
Fixed issue #6714

Co-authored-by: Bernardo Marenco <bmarenco@fing.edu.uy>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
3 people authored Feb 15, 2023
1 parent 29e5903 commit 6e259a6
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [2.3.0] - 2023-MM-DD

### Added

- Fixed expected data format in `PolBlogs` dataset ([#6714](https://github.com/pyg-team/pytorch_geometric/issues/6714))
- Added `SimpleConv` to perform non-trainable propagation ([#6718](https://github.com/pyg-team/pytorch_geometric/pull/6718))
- Added a `RemoveDuplicatedEdges` transform ([#6709](https://github.com/pyg-team/pytorch_geometric/pull/6709))
- Added TorchScript support to the `LINKX` model ([#6712](https://github.com/pyg-team/pytorch_geometric/pull/6712))
7 changes: 4 additions & 3 deletions torch_geometric/datasets/polblogs.py
@@ -59,7 +59,7 @@ def __init__(self, root: str, transform: Optional[Callable] = None,

     @property
     def raw_file_names(self) -> List[str]:
-        return ['adjacency.csv', 'labels.csv']
+        return ['adjacency.tsv', 'labels.tsv']

     @property
     def processed_file_names(self) -> str:
@@ -73,10 +73,11 @@ def download(self):
     def process(self):
         import pandas as pd

-        edge_index = pd.read_csv(self.raw_paths[0], header=None)
+        edge_index = pd.read_csv(self.raw_paths[0], header=None, sep='\t',
+                                 usecols=[0, 1])
         edge_index = torch.from_numpy(edge_index.values).t().contiguous()

-        y = pd.read_csv(self.raw_paths[1], header=None)
+        y = pd.read_csv(self.raw_paths[1], header=None, sep='\t')
         y = torch.from_numpy(y.values).view(-1)

         data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0))
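
For illustration, here is a minimal sketch of why the `sep='\t'` and `usecols=[0, 1]` arguments matter. It does not use the real PolBlogs files: the in-memory strings `ADJACENCY_TSV` and `LABELS_TSV` are hypothetical stand-ins, and the extra third column in the edge list is an assumption, chosen only to show what `usecols` discards. NumPy arrays are used in place of `torch` tensors to keep the example dependency-light.

```python
import io

import pandas as pd

# Hypothetical stand-ins for the raw files (not the real PolBlogs data):
# a tab-separated edge list with an assumed extra third column,
# and one integer label per line.
ADJACENCY_TSV = "0\t1\t1\n1\t2\t1\n2\t0\t1\n"
LABELS_TSV = "0\n1\n1\n"

# Reading the tab-separated file with the default comma separator leaves
# every row as a single unsplit string column.
naive = pd.read_csv(io.StringIO(ADJACENCY_TSV), header=None)

# With sep='\t' the columns are split; usecols keeps only source and target.
edge_df = pd.read_csv(io.StringIO(ADJACENCY_TSV), header=None, sep='\t',
                      usecols=[0, 1])
edges = edge_df.values.T  # shape [2, num_edges], like edge_index

label_df = pd.read_csv(io.StringIO(LABELS_TSV), header=None, sep='\t')
labels = label_df.values.reshape(-1)  # flat vector, like y

print(naive.shape, edges.shape, labels.shape)
```

In this sketch the naive read yields one column per row, so the resulting array could not be transposed into a two-row edge index; the tab-aware read yields the expected `[2, num_edges]` layout.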