Support for deform_conv2d operation from PyTorch #1889
Thank you for filing this feature request! Could you provide a minimal code snippet that uses deform_conv2d and reproduces the issue? Meanwhile, I would also recommend adding support for this op on your end using composite operators: https://coremltools.readme.io/docs/composite-operators Thanks!
Thank you for your reply! As requested, here's a minimal code snippet that uses deform_conv2d:

```python
import torch
from torchvision.ops import deform_conv2d
import coremltools as ct


class DeformConv2DModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.kh, self.kw = 3, 3
        # 5 output channels, 3 input channels, 3x3 kernel
        self.weight = torch.nn.Parameter(torch.rand(5, 3, self.kh, self.kw))

    def forward(self, x, offset, mask):
        out = deform_conv2d(x, offset, self.weight, mask=mask)
        return out


# Define the model
model = DeformConv2DModel()

# Create random inputs: offset has 2*kh*kw channels, mask has kh*kw channels,
# both at the 8x8 output resolution (10 - 3 + 1, stride 1, no padding)
input_tensor = torch.rand(4, 3, 10, 10)
offset = torch.rand(4, 2 * model.kh * model.kw, input_tensor.shape[2] - 2, input_tensor.shape[3] - 2)
mask = torch.rand(4, model.kh * model.kw, input_tensor.shape[2] - 2, input_tensor.shape[3] - 2)

# Trace the model
traced_model = torch.jit.trace(model, (input_tensor, offset, mask))

# Convert to Core ML (this is the step that fails: deform_conv2d has no converter)
coreml_model = ct.convert(
    traced_model,
    inputs=[
        ct.TensorType(name="input", shape=input_tensor.shape),
        ct.TensorType(name="offset", shape=offset.shape),
        ct.TensorType(name="mask", shape=mask.shape),
    ],
    source="pytorch",
)
```
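The `- 2` in the snippet above is the usual convolution output-size arithmetic for a 3x3 kernel with stride 1, no padding and no dilation. The helper below makes the expected shapes explicit. `deform_conv2d_shapes` is a hypothetical name, not a torchvision function, and it assumes a single offset group, for which torchvision expects `2*kh*kw` offset channels and `kh*kw` mask channels.

```python
def deform_conv2d_shapes(n, c_in, h, w, c_out, kh, kw,
                         stride=1, padding=0, dilation=1):
    """Expected tensor shapes for deform_conv2d, assuming one offset group."""
    # Standard convolution output size along one spatial axis.
    def out_size(size, k):
        return (size + 2 * padding - dilation * (k - 1) - 1) // stride + 1

    h_out, w_out = out_size(h, kh), out_size(w, kw)
    return {
        "weight": (c_out, c_in, kh, kw),
        # One (dy, dx) pair per kernel tap, per output location.
        "offset": (n, 2 * kh * kw, h_out, w_out),
        # One modulation scalar per kernel tap, per output location.
        "mask": (n, kh * kw, h_out, w_out),
        "output": (n, c_out, h_out, w_out),
    }


shapes = deform_conv2d_shapes(n=4, c_in=3, h=10, w=10, c_out=5, kh=3, kw=3)
print(shapes["offset"], shapes["mask"], shapes["output"])
# (4, 18, 8, 8) (4, 9, 8, 8) (4, 5, 8, 8)
```

These match the `offset` and `mask` tensors built in the repro, and the 4x5x8x8 output that `deform_conv2d` returns for them.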
Hello, I was wondering if there's any update on support for the deform_conv2d operation? Thank you!
Same request for deform_conv2d here! Any update?
+1
The PyTorch documentation for this op doesn't contain a lot of details. The PyTorch forward implementation for
Thank you @TobyRoseman for looking into this. Based on my understanding, the PyTorch implementation references the following two papers. I've summarized the formulas below.
Deformable ConvNets:
Deformable ConvNets v2:
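For reference, these are the sampling formulas in the notation of the two papers. Here R is the regular kernel grid (for a 3x3 kernel, the nine positions from (-1, -1) to (1, 1)), p_k is the k-th tap of that grid, and K is the number of taps. In the v2 paper the modulation scalar comes from a sigmoid; torchvision applies `mask` as given, so any sigmoid has to be applied by the caller.

```latex
% Deformable ConvNets (v1): learned offsets \Delta p_n shift each kernel tap
y(p_0) = \sum_{p_n \in \mathcal{R}} w(p_n) \cdot x(p_0 + p_n + \Delta p_n)

% Deformable ConvNets v2: adds a modulation scalar \Delta m_k \in [0, 1] per tap
y(p) = \sum_{k=1}^{K} w_k \cdot x(p + p_k + \Delta p_k) \cdot \Delta m_k

% Fractional sampling locations use bilinear interpolation (v1 paper)
x(p) = \sum_{q} G(q, p) \cdot x(q), \qquad
G(q, p) = g(q_x, p_x) \cdot g(q_y, p_y), \qquad
g(a, b) = \max(0,\, 1 - |a - b|)
```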
Thanks @Volutionn for the concise information. So the
That's what I understand about the
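The term x(p + p_k + Δp_k) is evaluated at fractional positions, which is where the bilinear kernel G(q, p) comes in. Below is a minimal NumPy sketch of that kernel for a single 2-D channel. `bilinear_sample` is a hypothetical helper, and treating neighbours outside the image as zero is an assumption about border handling, not something verified against torchvision's source.

```python
import numpy as np


def bilinear_sample(img, y, x):
    """Sample a 2-D array at a fractional (y, x) location.

    Implements x(p) = sum_q G(q, p) x(q) with g(a, b) = max(0, 1 - |a - b|):
    only the four integer neighbours of (y, x) can get a non-zero weight, and
    neighbours outside the image contribute zero (assumed zero padding).
    """
    h, w = img.shape
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    value = 0.0
    for qy in (y0, y0 + 1):
        for qx in (x0, x0 + 1):
            if 0 <= qy < h and 0 <= qx < w:
                weight = max(0.0, 1 - abs(qy - y)) * max(0.0, 1 - abs(qx - x))
                value += weight * img[qy, qx]
    return value


img = np.array([[0.0, 1.0],
                [2.0, 3.0]])
print(bilinear_sample(img, 0.0, 1.0))   # 1.0 (integer location returns the pixel)
print(bilinear_sample(img, 0.5, 0.5))   # 1.5 (average of the four pixels)
print(bilinear_sample(img, -1.0, 0.0))  # 0.0 (fully outside the image)
```

With zero offsets every sample lands on an integer pixel, so DCN v1 reduces to an ordinary convolution; the v2 mask then simply scales each sampled value.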
Hello everyone, I tried using deform_conv2d in coremltools v7.1 and got this error:
RuntimeError: PyTorch convert function for op 'torchvision::deform_conv2d' not implemented.
Are there any workarounds? I don't think it can be implemented simply as a custom layer with the MIL Builder. However, if anyone has ideas, I'm happy to work on an implementation; I just need a starting point. The reason I care is that deform_conv2d is usually a drop-in replacement that improves the loss by 20-30% on a CNN. It's pretty amazing, and it would be a great advantage to have it supported in deployed CoreML models.
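One possible starting point (not an official coremltools workaround): deform_conv2d can be decomposed into (1) a "deformable im2col" that bilinearly samples the input at the offset locations, and (2) a single matrix multiply with the flattened kernel. Both steps use only ordinary tensor ops (floor, clip, gather, elementwise multiply, matmul), which is the kind of graph a composite-operator port would need to express. The NumPy sketch below only illustrates the decomposition. `deform_conv2d_np` is a hypothetical name, and the following are assumptions that should be checked against `torchvision.ops.deform_conv2d` before relying on it: stride 1, no padding, dilation 1, one offset group, no bias, and offsets stored as interleaved (dy, dx) pairs per kernel tap.

```python
import numpy as np


def deform_conv2d_np(x, offset, weight, mask=None):
    """Sketch of modulated deformable convolution (DCNv2) as gather + matmul.

    Assumed: stride 1, no padding, dilation 1, one offset group, no bias,
    dy at even and dx at odd offset channels, taps in row-major order.
    """
    n, c_in, h, w = x.shape
    c_out, _, kh, kw = weight.shape
    h_out, w_out = h - kh + 1, w - kw + 1
    k = kh * kw
    if mask is None:
        mask = np.ones((n, k, h_out, w_out))

    # Fixed part of each sampling location: output position plus kernel tap p_k.
    oy, ox = np.meshgrid(np.arange(h_out), np.arange(w_out), indexing="ij")
    ty, tx = np.meshgrid(np.arange(kh), np.arange(kw), indexing="ij")
    base_y = oy[None, None] + ty.reshape(1, k, 1, 1)  # (1, k, h_out, w_out)
    base_x = ox[None, None] + tx.reshape(1, k, 1, 1)
    # Add the learned offsets; result has shape (n, k, h_out, w_out).
    py = base_y + offset[:, 0::2]
    px = base_x + offset[:, 1::2]

    # Step 1: deformable im2col. Bilinear interpolation is written as four
    # gathers (one per neighbour) whose out-of-bounds contributions are zeroed.
    xt = x.transpose(0, 2, 3, 1)  # (n, h, w, c_in)
    batch = np.arange(n).reshape(n, 1, 1, 1)
    y0, x0 = np.floor(py).astype(int), np.floor(px).astype(int)
    cols = np.zeros((n, k, h_out, w_out, c_in))
    for qy in (y0, y0 + 1):
        for qx in (x0, x0 + 1):
            g = np.maximum(0.0, 1 - np.abs(qy - py)) * np.maximum(0.0, 1 - np.abs(qx - px))
            inside = (qy >= 0) & (qy < h) & (qx >= 0) & (qx < w)
            vals = xt[batch, np.clip(qy, 0, h - 1), np.clip(qx, 0, w - 1)]
            cols += (g * inside)[..., None] * vals
    cols *= mask[..., None]  # DCNv2 modulation, one scalar per tap

    # Step 2: one matrix multiply with the kernel flattened to (c_out, c_in*k).
    cols = cols.transpose(0, 2, 3, 4, 1).reshape(n, h_out * w_out, c_in * k)
    out = cols @ weight.reshape(c_out, c_in * k).T  # (n, h_out*w_out, c_out)
    return out.reshape(n, h_out, w_out, c_out).transpose(0, 3, 1, 2)
```

A quick sanity check is that zero offsets and an all-ones mask must reproduce a plain (cross-correlation style) convolution, and a constant mask of 0.5 must halve it.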
Just a bump. Is this issue dead? Please advise.
@bitanath I imagine it's just a complex operator to add. I tried to implement it using composite operators, but without any success. Hopefully this hasn't been abandoned on @TobyRoseman's side; I've been hoping for it for almost a year. Agreed, it would be amazing to have it. Let's be patient; it's normal for this to take time.
@Volutionn I implemented the https://github.com/dneprDroid/DeformConv2d-Metal
It's GPU-accelerated and supports both iOS and macOS.
Wow, you made my day! That's amazing, thanks a lot for sharing it @dneprDroid 🙏🏻
Thanks a lot @dneprDroid! This is pretty neat! I was also trying to implement this in Metal from the published CUDA implementation, but it seemed too hard. Would you also release the code for the shaders, purely for learning purposes? Or did I miss them somewhere? Regardless, thanks a lot for this library! It's awesome!! Much appreciated ❤️
Both
Name of layer type:
deform_conv2d
Is this a PyTorch or a TensorFlow layer type:
PyTorch
Your version of coremltools:
7.0b1
Your version of PyTorch/TensorFlow:
PyTorch 2.0.1
Impact of supporting this layer type. Why is adding support for this layer type important? Is it necessary to support a popular model or use case?
Deformable convolution, as implemented in the torchvision.ops.deform_conv2d operator in PyTorch, is a key technique that allows convolutional neural networks to adapt to complex spatial transformations in the input data. It improves model performance in tasks that require understanding spatial hierarchies and relationships, such as object detection, image segmentation, and image restoration.
Because this operation is not supported, I currently cannot convert my model.
I was wondering if there are any plans to implement support for the deform_conv2d operation in a future release of CoreML?
If support for deform_conv2d is not planned, could you provide any advice or workarounds for dealing with this issue? Any guidance would be greatly appreciated.
Thank you for your time and for the excellent work you do on the CoreML project!