-
Notifications
You must be signed in to change notification settings - Fork 617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SwiGLU optimized fw/bw #490
Conversation
[ghstack-poisoned]
ghstack-source-id: eb1801be830e5b7af5f4913eaf3a0e76c1465a69 Pull Request resolved: #490
[ghstack-poisoned]
ghstack-source-id: b30671b2fc7903973bf6e5ab83542532b91d3d74 Pull Request resolved: #490
[ghstack-poisoned]
ghstack-source-id: c7a4eda1bead77d8a8ba18deaf4067cf23402205 Pull Request resolved: #490
[ghstack-poisoned]
ghstack-source-id: 055d35dff615ebcc5c8380d07a0b580a67260c52 Pull Request resolved: #490
[ghstack-poisoned]
ghstack-source-id: a87a46d345dcb98dc0c53c56575fcda38cd5bccd Pull Request resolved: #490
[ghstack-poisoned]
ghstack-source-id: e8f89ae5e89d7fc7907a6cd32d2fd85e04b08eda Pull Request resolved: #490
[ghstack-poisoned]
ghstack-source-id: b864e2340fdea7f0c6819f9349dbb3b41766c9f1 Pull Request resolved: #490
[ghstack-poisoned]
ghstack-source-id: 1ff447ae98cc07c4e3de9653884175cc4c59b5ec Pull Request resolved: #490
**USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUFusedOp) ``` [ghstack-poisoned]
ghstack-source-id: 7b874c69561bf1756e95ccfad9407e4ea9d18e85 Pull Request resolved: #490
**USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUFusedOp) ``` [ghstack-poisoned]
ghstack-source-id: abc12d1ec3cabbec2ebd5c2fffb72167609f3d85 Pull Request resolved: #490
**USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUFusedOp) ``` [ghstack-poisoned]
ghstack-source-id: 520ade162a45516f01a84e551958e4c54a0fe4e3 Pull Request resolved: #490
**USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUFusedOp) ``` [ghstack-poisoned]
**USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUPackedFusedOp) ``` [ghstack-poisoned]
**USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUPackedFusedOp) ``` [ghstack-poisoned]
**USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUPackedFusedOp) ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUPackedFusedOp) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUPackedFusedOp) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUPackedFusedOp) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUPackedFusedOp) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUPackedFusedOp) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
Codecov ReportBase: 90.60% // Head: 88.38% // Decreases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## gh/danthe3rd/52/base #490 +/- ##
========================================================
- Coverage 90.60% 88.38% -2.23%
========================================================
Files 79 80 +1
Lines 4652 4785 +133
========================================================
+ Hits 4215 4229 +14
- Misses 437 556 +119
Flags with carried forward coverage won't be shown. Click here to find out more.
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3 op=xops.SwiGLUPackedFusedOp) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
**NOTE** We can improve a bit more once this is fixed - NVIDIA/cutlass#674 **USAGE** ```python import xformers.ops as xops # NOTE: Important to use `unbind` from xformers for the bw pass! w1, w2 = xops.unbind( w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), dim=0, ) b1, b2 = xops.unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) y = xops.functional_swiglu(x, w1, b1, w2, b2, w3, b3) ``` **PERFORMANCE (A100 only)** *FW* ``` [-------------------------------------------------------- swiglu_fw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 1377.7 | 1581.4 | 1339.1 f16.ac B=9456, I=1536, H=4096 | 1449.3 | 1735.3 | 1462.9 f16 B=4440, I=1536, H=4096 | 600.4 | 735.6 | 593.9 f16.ac B=4440, I=1536, H=4096 | 709.0 | 843.7 | 717.6 f16 B=4728, I=1536, H=4096 | 638.9 | 776.2 | 635.3 f16.ac B=4728, I=1536, H=4096 | 748.9 | 892.2 | 756.7 f16 B=4728, I=1536, H=1024 | 162.3 | 201.5 | 163.1 f16.ac B=4728, I=1536, H=1024 | 235.2 | 277.4 | 245.5 Times are in microseconds (us). ``` *BW* ``` [-------------------------------------------------------- swiglu_bw ---------------------------------------------------------] | SwiGLUPackedFusedOp[fused.p.cpp] | eager | SwiGLUFusedOp[fused] 1 threads: ------------------------------------------------------------------------------------------------------------------- f16 B=9456, I=1536, H=4096 | 2333.1 | 2696.7 | 2336.1 f16.ac B=9456, I=1536, H=4096 | 2620.8 | 2990.9 | 2840.0 f16 B=4440, I=1536, H=4096 | 1243.2 | 1413.8 | 1240.3 f16.ac B=4440, I=1536, H=4096 | 1448.6 | 1629.0 | 1637.3 f16 B=4728, I=1536, H=4096 | 1298.4 | 1481.5 | 1301.1 f16.ac B=4728, I=1536, H=4096 | 1511.8 | 1705.3 | 1705.4 f16 B=4728, I=1536, H=1024 | 463.3 | 493.9 | 463.0 f16.ac B=4728, I=1536, H=1024 | 582.4 | 614.9 | 672.7 Times are in microseconds (us). ``` [ghstack-poisoned]
ghstack-source-id: 7998ff3210011362be7c379666655e9bc5078dde Pull Request resolved: #490
Stack from ghstack (oldest at bottom):
NOTE
We can improve a bit more once this is fixed - NVIDIA/cutlass#674
USAGE
PERFORMANCE (A100 only)
FW
BW