Compile fails with a network with a SplineConv layer #8318

Open
cfd1 opened this issue Nov 3, 2023 · 3 comments

cfd1 commented Nov 3, 2023

🐛 Describe the bug

PyTorch Geometric fails to compile a network that uses `SplineConv`. Could this please be fixed? I'm happy to help, but I will need quite a bit of guidance.

I have an example that shows this, based on the `mnist_voxel_grid.py` example. To reproduce the error, replace line 92 onwards with this code:

```python
import time  # harmless if the script already imports these
import torch_geometric

# Baseline: the eager-mode model and optimizer defined earlier in the script
print("Training - Baseline")
start_time = time.time()
for epoch in range(1, 3):
    tmp_time = time.time()
    train(epoch)
    test_acc = test()
    print(f'Epoch: {epoch:02d}, Test: {test_acc:.4f} in {time.time() - tmp_time:.1f}s')

print(f"End time: {time.time() - start_time:.1f}s")

# Re-create the model and optimizer, then compile the model;
# fullgraph=True turns any graph break into an error
print("Setting up compiled")
model = Net().to(device)
model = torch_geometric.compile(model, dynamic=False, fullgraph=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

print("Training - Compiled")
start_time = time.time()
for epoch in range(1, 3):
    train(epoch)
    test_acc = test()
    print(f'Epoch: {epoch:02d}, Test: {test_acc:.4f}')

print(f"End time: {time.time() - start_time:.1f}s")
```
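
Both loops above repeat the same train/test/timing pattern. A minimal sketch of a helper that factors it out, assuming `train(epoch)` and `test()` behave as in the example; the name `run_epochs` is hypothetical and not part of PyG:

```python
import time
from typing import Callable, List, Tuple


def run_epochs(
    train: Callable[[int], None],
    test: Callable[[], float],
    num_epochs: int = 2,
    label: str = "Baseline",
) -> List[Tuple[int, float, float]]:
    """Run train/test once per epoch and record (epoch, test_acc, seconds)."""
    print(f"Training - {label}")
    results: List[Tuple[int, float, float]] = []
    start_time = time.perf_counter()
    # Epochs are numbered from 1, matching range(1, 3) in the original loops
    for epoch in range(1, num_epochs + 1):
        epoch_start = time.perf_counter()
        train(epoch)
        test_acc = test()
        elapsed = time.perf_counter() - epoch_start
        results.append((epoch, test_acc, elapsed))
        print(f"Epoch: {epoch:02d}, Test: {test_acc:.4f} in {elapsed:.1f}s")
    print(f"End time: {time.perf_counter() - start_time:.1f}s")
    return results
```

It could be called as `run_epochs(train, test, label="Baseline")` before compiling and `run_epochs(train, test, label="Compiled")` afterwards. Reporting per-epoch times for the compiled run too would make it easier to tell compilation overhead apart from steady-state speed. `time.perf_counter` is used because it is a monotonic clock intended for measuring intervals.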

The error message

```
Training - Baseline
Epoch: 01, Test: 0.9393 in 29.7s
Epoch: 02, Test: 0.9445 in 20.9s
End time: 50.6s
Setting up compiled
Training - Compiled
Traceback (most recent call last):
  File "test.py", line 111, in <module>
    train(epoch)
  File "test.py", line 78, in train
    F.nll_loss(model(data), data.y).backward()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 493, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 132, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 370, in _convert_frame_assert
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 554, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 465, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 432, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2071, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1191, in LOAD_ATTR
    result = BuiltinVariable(getattr).call_function(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 608, in call_function
    result = handler(tx, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 1102, in call_getattr
    obj.var_getattr(tx, name).clone(source=source).add_options(options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/user_defined.py", line 430, in var_getattr
    ).call_function(tx, [], {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2176, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2283, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1568, in CONTAINS_OP
    self.push(right.call_method(self, "__contains__", [left], {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/user_defined.py", line 301, in call_method
    ).call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 307, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 261, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 598, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2176, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2229, in inline_call_
    InliningInstructionTranslator.check_inlineable(func)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2210, in check_inlineable
    unimplemented(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/exc.py", line 143, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: inline in skipfiles: Mapping.__contains__  | __contains__ /usr/lib/python3.10/_collections_abc.py

from user code:
  File "test.py", line 35, in forward
    data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/data/data.py", line 847, in x
    return self['x'] if 'x' in self._store else None

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
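
The last frames show where tracing stopped: `'x' in self._store` (line 847 of `data.py`) resolves to the generic `Mapping.__contains__` in `_collections_abc.py`, which this Dynamo build refuses to inline ("inline in skipfiles"). The stdlib-only sketch below reproduces just that Python dispatch. `PlainStore` and `DirectStore` are hypothetical stand-ins, not PyG's actual storage classes, and whether a direct `__contains__` override would let Dynamo trace further is an untested assumption; SplineConv's custom C++/CUDA kernel would remain a separate obstacle.

```python
from collections.abc import Mapping
from typing import Any, Dict, Iterator


class PlainStore(Mapping):
    """Hypothetical Mapping that does NOT override __contains__."""

    def __init__(self, **kwargs: Any) -> None:
        self._mapping: Dict[str, Any] = dict(kwargs)

    def __getitem__(self, key: str) -> Any:
        return self._mapping[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._mapping)

    def __len__(self) -> int:
        return len(self._mapping)


class DirectStore(PlainStore):
    """Same store, but `in` checks the underlying dict directly."""

    def __contains__(self, key: object) -> bool:
        return key in self._mapping


plain = PlainStore(x=1)
direct = DirectStore(x=1)

# `'x' in plain` uses the inherited Mapping.__contains__ from
# _collections_abc.py, which evaluates plain['x'] and catches KeyError.
print(type(plain).__contains__ is Mapping.__contains__)   # True
print(type(direct).__contains__ is Mapping.__contains__)  # False
print('x' in plain, 'y' in plain, 'x' in direct, 'y' in direct)  # True False True False
```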

Environment

  • PyG version: 2.4.0
  • PyTorch version: 2.1.0a0+32f93b1
  • OS: NVIDIA PyTorch container 23.09
  • Python version: 3.10.12
  • CUDA/cuDNN version: 12.2/8.9.5
  • How you installed PyTorch and PyG (conda, pip, source): NVIDIA PyTorch container, plus `pip install torch_geometric`
  • Any other relevant information (e.g., version of torch-scatter):
    • torch-cluster: 1.6.2
    • torch-scatter: 2.1.2
    • torch-sparse: 0.6.18
    • torch-spline-conv: 1.2.2
cfd1 added the bug label Nov 3, 2023
rusty1s (Member) commented Nov 3, 2023

This is currently expected, as `torch.compile` cannot yet handle custom C++/CUDA code :(

akihironitta (Member)

For reference, pytorch/pytorch#111223 might be relevant.

akihironitta (Member)

#8890 for tracking

akihironitta self-assigned this Dec 23, 2024