DGLError Expected data to have %d rows, got %d. occurs at large batch size #4512

Open
sidazhou opened this issue Sep 5, 2022 · 10 comments
Labels: bug:confirmed (Something isn't working)

sidazhou commented Sep 5, 2022

🐛 Bug

DGLError('Expected data to have %d rows, got %d.') occurs at a large batch_size and does not occur at a smaller batch_size. The larger the batch_size, the larger the difference in rows. It feels like a rounding error somewhere.

To Reproduce

BATCH_SIZE = 1000 #  <---- works fine
# BATCH_SIZE = 5000 # <---- DGL errors

sampler = dgl.dataloading.NeighborSampler([4, 4])
_, _, mfgs = sampler.sample_blocks(train_pos_g, seed_ids[:BATCH_SIZE])

print(mfgs[0].srcdata['feat'].shape)
# torch.Size([10239, 128]) <---- works fine
# torch.Size([48913, 128]) <---- DGL errors

model(mfgs, mfgs[0].srcdata['feat']) # <---- errors
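
(For reference, a runnable stand-in for train_pos_g and seed_ids could look like the sketch below; the random graph, feature width, and seed construction are illustrative assumptions, not the original data.)

import dgl
import torch

# Hypothetical stand-ins for the original data (illustrative only):
num_nodes = 100_000
src = torch.randint(0, num_nodes, (500_000,))
dst = torch.randint(0, num_nodes, (500_000,))
train_pos_g = dgl.graph((src, dst), num_nodes=num_nodes)
train_pos_g.ndata['feat'] = torch.randn(num_nodes, 128)

# Seeds drawn with replacement, so a large slice is likely to contain
# duplicated IDs, which (as the discussion below shows) is what triggers the error.
seed_ids = torch.randint(0, num_nodes, (10_000,))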

Expected behavior

It should not raise a DGLError.

Environment

  • DGL 0.8.2
  • PyTorch 1.11.0+cpu
  • Python 3.10.5

Additional context

model:

# model is the default 2-layer GraphSAGE from the tutorials
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
class Model(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type='mean')
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        # The first num_dst_nodes() rows of the source features are the destination nodes.
        h_dst = x[:mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[:mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h

Error stack:

    ---------------------------------------------------------------------------
    DGLError                                  Traceback (most recent call last)
    Input In [58], in <cell line: 13>()
          9 print(mfgs[0].srcdata['feat'].shape)
         10 # torch.Size([10239, 128]) works fine
         11 # torch.Size([48913, 128]) DGL errors
    ---> 13 model(mfgs, mfgs[0].srcdata['feat'])

    File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
       1106 # If we don't have any hooks, we want to skip the rest of the logic in
       1107 # this function, and just call forward.
       1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1109         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1110     return forward_call(*input, **kwargs)
       1111 # Do not call functions when jit is used
       1112 full_backward_hooks, non_full_backward_hooks = [], []

    Input In [1], in Model.forward(self, mfgs, x)
        101 h = F.relu(h)
        102 h_dst = h[:mfgs[1].num_dst_nodes()]
    --> 103 h = self.conv2(mfgs[1], (h, h_dst))
        104 return h

    File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
       1106 # If we don't have any hooks, we want to skip the rest of the logic in
       1107 # this function, and just call forward.
       1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1109         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1110     return forward_call(*input, **kwargs)
       1111 # Do not call functions when jit is used
       1112 full_backward_hooks, non_full_backward_hooks = [], []

    File /opt/conda/lib/python3.10/site-packages/dgl/nn/pytorch/conv/sageconv.py:235, in SAGEConv.forward(self, graph, feat, edge_weight)
        233 if self._aggre_type == 'mean':
        234     graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
    --> 235     graph.update_all(msg_fn, fn.mean('m', 'neigh'))
        236     h_neigh = graph.dstdata['neigh']
        237     if not lin_before_mp:

    File /opt/conda/lib/python3.10/site-packages/dgl/heterograph.py:4900, in DGLHeteroGraph.update_all(self, message_func, reduce_func, apply_node_func, etype)
       4898         key = list(ndata.keys())[0]
       4899         ndata[key] = F.replace_inf_with_zero(ndata[key])
    -> 4900     self._set_n_repr(dtid, ALL, ndata)
       4901 else:   # heterogeneous graph with number of relation types > 1
       4902     if not core.is_builtin(message_func) or not core.is_builtin(reduce_func):

    File /opt/conda/lib/python3.10/site-packages/dgl/heterograph.py:4136, in DGLHeteroGraph._set_n_repr(self, ntid, u, data)
       4132         raise DGLError('Pinned graph requires the node data to be pinned as well. '
       4133                        'Please pin the node data before assignment.')
       4135 if is_all(u):
    -> 4136     self._node_frames[ntid].update(data)
       4137 else:
       4138     self._node_frames[ntid].update_row(u, data)

    File /opt/conda/lib/python3.10/_collections_abc.py:994, in MutableMapping.update(self, other, **kwds)
        992 if isinstance(other, Mapping):
        993     for key in other:
    --> 994         self[key] = other[key]
        995 elif hasattr(other, "keys"):
        996     for key in other.keys():

    File /opt/conda/lib/python3.10/site-packages/dgl/frame.py:584, in Frame.__setitem__(self, name, data)
        574 def __setitem__(self, name, data):
        575     """Update the whole column.
        576 
        577     Parameters
       (...)
        582         The column data.
        583     """
    --> 584     self.update_column(name, data)

    File /opt/conda/lib/python3.10/site-packages/dgl/frame.py:661, in Frame.update_column(self, name, data)
        659 col = Column.create(data)
        660 if len(col) != self.num_rows:
    --> 661     raise DGLError('Expected data to have %d rows, got %d.' %
        662                    (self.num_rows, len(col)))
        663 self._columns[name] = col

    DGLError: Expected data to have 5000 rows, got 4998.

sidazhou changed the title from "DGLError Expected data to have %d rows, got %d. occurs at large batch__size" to "DGLError Expected data to have %d rows, got %d. occurs at large batch size" on Sep 5, 2022

sidazhou (Author) commented Sep 6, 2022

Narrowed it down to a mismatch between output_nodes.shape and mfgs[1].dstnodes().shape: 100 vs. 99. Why is this? Surely it's a bug?
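
(A minimal way to check that mismatch, assuming the same sampler and inputs as in the original report:)

sampler = dgl.dataloading.NeighborSampler([4, 4])
input_nodes, output_nodes, mfgs = sampler.sample_blocks(train_pos_g, seed_ids[:BATCH_SIZE])

# Number of seeds handed to the sampler vs. destination nodes in the last MFG:
print(output_nodes.shape[0])      # e.g. 100
print(mfgs[-1].num_dst_nodes())   # e.g. 99 when the seeds contain one duplicated ID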

sidazhou (Author) commented Sep 7, 2022

So the issue seems to occur when seed_nodes contains duplicated IDs. Is this a bug or a feature?

rudongyu (Collaborator) commented Sep 8, 2022

It seems that to_block during sampling removes duplicated nodes, which causes an inconsistency between the number of destination nodes and the size of the destination node features. @BarclayII I guess we should check for possible duplicates in seed_nodes before sampling. What do you think?
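
(A sketch of such a check, using a hypothetical helper rather than an existing DGL API:)

import torch

def check_unique_seeds(seed_nodes):
    # seed_nodes: 1-D tensor of seed node IDs about to be passed to sample_blocks().
    num_dup = seed_nodes.numel() - torch.unique(seed_nodes).numel()
    if num_dup > 0:
        raise ValueError(
            'seed_nodes contains %d duplicated IDs; to_block() collapses them, '
            'so the MFGs will have fewer destination nodes than seeds.' % num_dup)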

sidazhou (Author) commented Sep 8, 2022

Surely it's a bug, right? The dataloader is yielding mfgs that cannot be used as input to model().

rudongyu moved this to 🏠 Backlog in DGL Project Tracker on Sep 8, 2022
rudongyu added the bug:confirmed label on Sep 8, 2022
RManLuo commented Sep 16, 2022

Hi, I am also facing this problem. The seed_nodes I pass in contain some duplicated IDs, but I need the embeddings for these duplicates. Is there any solution for now?
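
(One possible interim workaround, not an official DGL feature: deduplicate the seeds before sampling, then map the results back with the inverse index from torch.unique. A minimal sketch, assuming a homogeneous graph g, the 2-layer model from the original report, a duplicated seed tensor seed_ids, and that the last MFG's destination nodes follow the order of the deduplicated seeds as in the standard mini-batch loop:)

import torch
import dgl

# Deduplicate the seeds, keeping the inverse mapping back to the original positions.
unique_ids, inverse = torch.unique(seed_ids, return_inverse=True)

sampler = dgl.dataloading.NeighborSampler([4, 4])
_, _, mfgs = sampler.sample_blocks(g, unique_ids)

# Forward pass on the deduplicated seeds only.
h_unique = model(mfgs, mfgs[0].srcdata['feat'])

# Gather rows back so every duplicated seed gets its embedding.
h = h_unique[inverse]   # shape: (len(seed_ids), h_feats)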

RManLuo commented Sep 19, 2022

I tried using dgl.dataloading.MultiLayerFullNeighborSampler to sample blocks for a set of seed_nodes that contains duplicated items. If I sample on the CPU, the returned MFG contains inconsistent results. However, if I sample on the GPU, the duplicated seed nodes are not removed. I think the sampling results on different devices should be the same.

import torch
import dgl

src = torch.LongTensor(
    [0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10,
     1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11])
dst = torch.LongTensor(
    [1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11,
     0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10])
g = dgl.graph((src, dst))

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

# Sample in CPU
idx = torch.LongTensor([8,8])
src_nodes, dst_nodes, mfgs = sampler.sample_blocks(g, idx)
print(dst_nodes) # tensor([8, 8])
print(mfgs[-1].num_dst_nodes()) # 1
print(mfgs[-1].dstdata) # {'_ID': tensor([8, 8])}
# Inconsistent

# Sample in GPU
device = torch.device('cuda:0')
src_nodes, dst_nodes, mfgs = sampler.sample_blocks(g.to(device), idx.to(device))
print(dst_nodes) # tensor([8, 8], device='cuda:0')
print(mfgs[-1].num_dst_nodes()) # 2
print(mfgs[-1].dstdata) # {'_ID': tensor([8, 8], device='cuda:0')}
# Consistent

FAF-D2 commented Oct 20, 2022

BTW, it seems that sampling on the GPU still can't solve the problem of duplicated nodes in a heterograph:

import torch
import dgl
import dgl.nn.pytorch as dglnn

src = torch.LongTensor(
    [0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10,
     1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11])
dst = torch.LongTensor(
    [1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11,
     0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10])
graph_data = {
        ('user', 'plays', 'game') : (src, dst),
        ('user', 'follows', 'user'): (torch.LongTensor([0, 1, 2, 3]), torch.LongTensor([5, 6, 7, 8]))
    }
g = dgl.heterograph(graph_data)
g.nodes['user'].data['h'] = torch.ones(g.num_nodes('user'), 16)
g.nodes['game'].data['h'] = torch.ones(g.num_nodes('game'), 16)

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

uid = torch.LongTensor([0, 0, 2, 2, 4, 4])
device = torch.device('cuda:0')
src_nodes, dst_nodes, mfgs = sampler.sample_blocks(g.to(device), {'user': uid.to(device)})
print(dst_nodes)  # {'user': tensor([0, 0, 2, 2, 4, 4], device='cuda:0')}
print(mfgs[-1].num_dst_nodes()) # 6

conv1 = dglnn.HeteroGraphConv({
                    'plays': dglnn.SAGEConv(16, 32, 'gcn'),
                    'follows': dglnn.SAGEConv(16, 32, 'gcn')
                }, 'sum').to(device)
conv2 = dglnn.HeteroGraphConv({
                    'plays': dglnn.SAGEConv(32, 32, 'gcn'),
                    'follows': dglnn.SAGEConv(32, 32, 'gcn')
                }, 'sum').to(device)

out = mfgs[0].srcdata['h']

print(mfgs[0].num_dst_nodes()) # 3
print(len(out['game'])) # 0
print(len(out['user'])) # 3

out = conv1(mfgs[0], out)

print(mfgs[1].num_dst_nodes()) # 6
print(len(out['game'])) # 0
print(len(out['user'])) # 3

out = conv2(mfgs[1], out) # Error

czkkkkkk (Collaborator) commented Oct 20, 2022

Sorry, we currently don't support duplicate values in the seed nodes for the sampler. We've added it to our backlog to be prioritized over other feature requests in our roadmap.

yueliu1999 commented:

I also ran into this error on the paper100M dataset. Has this bug been fixed yet? Are there any other potential solutions?

czkkkkkk (Collaborator) commented:

It hasn't been solved yet. We suggest that users explicitly deduplicate the seed nodes.
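
(For the heterogeneous case above, that explicit deduplication would be per node type. A sketch, reusing g and sampler from the earlier heterograph comment:)

import torch

seeds = {'user': torch.LongTensor([0, 0, 2, 2, 4, 4])}
# Deduplicate each node type's seed tensor before handing the dict to the sampler.
unique_seeds = {ntype: torch.unique(ids) for ntype, ids in seeds.items()}
src_nodes, dst_nodes, mfgs = sampler.sample_blocks(g, unique_seeds)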
