Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add op hardshrink #7887

Merged
merged 55 commits into from
Apr 15, 2022
Merged

Add op hardshrink #7887

merged 55 commits into from
Apr 15, 2022

Conversation

marigoold
Copy link
Contributor

此PR完成了:

  • 增加了Hardshrink激活函数,及其对应的文档、单测代码、global测试代码

image

@@ -451,6 +451,60 @@ def extra_repr(self):
return inplace_str


class Hardshrink(Module):
r"""
The interface is consistent with PyTorch.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几行都去掉,可以在torch文档基础增改一下,不要照搬

[](user_op::KernelComputeContext* ctx) { \
return HardShrinkFunctor<dtype>(ctx->Attr<double>("lambd")); \
}, \
"out", "in"); \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lazy这里支持inplace还需要参考这部分 https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/kernels/activation_kernels.h#L309-L313。

另外一个PR也记得改一下

def __init__(self, lambd: float = 0.5, inplace: bool = False):
self.inplace = inplace
self.lambd = lambd
super().__init__()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

先super调用父类init,再初始化inplace 和 lambd

@github-actions
Copy link
Contributor

github-actions bot commented Apr 9, 2022

CI failed when running job: cuda-misc. PR label automerge has been removed

@github-actions github-actions bot removed the automerge label Apr 9, 2022
@github-actions
Copy link
Contributor

github-actions bot commented Apr 9, 2022

Speed stats:

@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7887/

@marigoold marigoold enabled auto-merge (squash) April 13, 2022 07:00
@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7887/

@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7887/

@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7887/

@github-actions
Copy link
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

✔️ OneFlow resnet50 time: 128.3ms (= 12834.6ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 139.4ms (= 13937.9ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.09 (= 139.4ms / 128.3ms)

OneFlow resnet50 time: 77.9ms (= 7793.6ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 83.6ms (= 8359.0ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.07 (= 83.6ms / 77.9ms)

OneFlow resnet50 time: 52.4ms (= 10478.2ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 55.0ms (= 11005.6ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.05 (= 55.0ms / 52.4ms)

OneFlow resnet50 time: 43.1ms (= 8623.6ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 50.0ms (= 9995.1ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.16 (= 50.0ms / 43.1ms)

OneFlow resnet50 time: 37.6ms (= 7514.2ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 39.2ms (= 7844.4ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.04 (= 39.2ms / 37.6ms)

OneFlow swin dataloader time: 0.251s (= 50.166s / 200, num_workers=1)
PyTorch swin dataloader time: 0.252s (= 50.321s / 200, num_workers=1)
✔️ Relative speed: 1.003 (= 0.252s / 0.251s)

OneFlow swin dataloader time: 0.073s (= 14.572s / 200, num_workers=4)
PyTorch swin dataloader time: 0.069s (= 13.822s / 200, num_workers=4)
✔️ Relative speed: 0.949 (= 0.069s / 0.073s)

OneFlow swin dataloader time: 0.037s (= 7.415s / 200, num_workers=8)
PyTorch swin dataloader time: 0.037s (= 7.450s / 200, num_workers=8)
✔️ Relative speed: 1.005 (= 0.037s / 0.037s)

✔️ OneFlow resnet50 time: 135.8ms (= 13578.8ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 162.8ms (= 16284.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.20 (= 162.8ms / 135.8ms)

OneFlow resnet50 time: 87.7ms (= 8770.4ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 98.9ms (= 9893.9ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.13 (= 98.9ms / 87.7ms)

OneFlow resnet50 time: 58.7ms (= 11740.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 86.2ms (= 17239.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.47 (= 86.2ms / 58.7ms)

OneFlow resnet50 time: 53.1ms (= 10624.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.1ms (= 13427.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.26 (= 67.1ms / 53.1ms)

OneFlow resnet50 time: 48.7ms (= 9742.6ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 62.5ms (= 12502.8ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.28 (= 62.5ms / 48.7ms)

@github-actions
Copy link
Contributor

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7887/

@github-actions
Copy link
Contributor

Speed stats:
GPU Name: GeForce GTX 1080 

✔️ OneFlow resnet50 time: 128.3ms (= 12833.6ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 140.3ms (= 14032.5ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.09 (= 140.3ms / 128.3ms)

OneFlow resnet50 time: 78.3ms (= 7826.1ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 84.8ms (= 8484.5ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.08 (= 84.8ms / 78.3ms)

OneFlow resnet50 time: 52.9ms (= 10584.7ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 57.3ms (= 11456.2ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.08 (= 57.3ms / 52.9ms)

OneFlow resnet50 time: 45.3ms (= 9062.7ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 47.8ms (= 9556.3ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.05 (= 47.8ms / 45.3ms)

OneFlow resnet50 time: 37.2ms (= 7444.2ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 38.1ms (= 7629.1ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.02 (= 38.1ms / 37.2ms)

OneFlow swin dataloader time: 0.251s (= 50.238s / 200, num_workers=1)
PyTorch swin dataloader time: 0.256s (= 51.226s / 200, num_workers=1)
✔️ Relative speed: 1.020 (= 0.256s / 0.251s)

OneFlow swin dataloader time: 0.065s (= 13.060s / 200, num_workers=4)
PyTorch swin dataloader time: 0.070s (= 14.093s / 200, num_workers=4)
✔️ Relative speed: 1.079 (= 0.070s / 0.065s)

OneFlow swin dataloader time: 0.037s (= 7.484s / 200, num_workers=8)
PyTorch swin dataloader time: 0.036s (= 7.299s / 200, num_workers=8)
✔️ Relative speed: 0.975 (= 0.036s / 0.037s)

✔️ OneFlow resnet50 time: 135.0ms (= 13502.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 155.7ms (= 15565.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.15 (= 155.7ms / 135.0ms)

OneFlow resnet50 time: 88.6ms (= 8857.9ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 98.6ms (= 9863.8ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.11 (= 98.6ms / 88.6ms)

OneFlow resnet50 time: 59.8ms (= 11951.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 76.2ms (= 15236.7ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.27 (= 76.2ms / 59.8ms)

OneFlow resnet50 time: 51.9ms (= 10375.9ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 66.5ms (= 13308.1ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.28 (= 66.5ms / 51.9ms)

OneFlow resnet50 time: 48.5ms (= 9698.3ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 71.7ms (= 14340.7ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.48 (= 71.7ms / 48.5ms)

@marigoold marigoold merged commit e82f520 into master Apr 15, 2022
@marigoold marigoold deleted the add_op_hardshrink branch April 15, 2022 00:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants