From 776c04b54055c7802c09ff7d0b3d3bbe76f16909 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 15 Apr 2022 19:46:44 +0900 Subject: [PATCH] Squashed commit of the following: commit f499e6048912437b2759fce68b1c66fc23071a2f Author: Masahiro Masuda Date: Fri Apr 15 04:11:02 2022 +0900 Squashed commit of the following: commit dcb628d7d548f4a0e3eb8f3a2004fdda759fb0bf Author: Masahiro Masuda Date: Thu Apr 14 17:10:27 2022 +0900 Squashed commit of the following: commit dd956ec636bb374480e57b09a980432b54860a94 Author: Masahiro Masuda Date: Thu Apr 14 16:53:34 2022 +0900 add conv2d relay test commit 7291e476757f5780efec04146b57fed0cdd0314b Author: Masahiro Masuda Date: Thu Apr 14 16:46:05 2022 +0900 add dense and bmm test commit a957dde49d8c665b3d0df090a353d6b46011ae5f Author: Masahiro Masuda Date: Thu Apr 14 16:32:43 2022 +0900 conv2d topi test working commit 6d53c502f14f0eeb38e2b77716584f57ea1890d2 Author: Masahiro Masuda Date: Thu Apr 14 11:33:38 2022 +0900 add mattr kind commit 3761bd7ef8f3cc4fd98951f777f48efed8825054 Author: Masahiro Masuda Date: Thu Apr 14 11:12:14 2022 +0900 update dot prod intrin commit e781ee1e84ae31b4b84b2b419a03ce2196f89780 Author: Masahiro Masuda Date: Thu Apr 14 11:02:43 2022 +0900 black commit b2208a7b5ef3ba03b9750f579022b2a1c699348c Author: Masahiro Masuda Date: Thu Apr 14 10:58:10 2022 +0900 cleanup commit f8bc306ca5eca0839c3e305418ebe9b58c818ecd Author: Masahiro Masuda Date: Thu Apr 14 10:35:02 2022 +0900 [ROCM] Support dp4a on AMDGPU by sdot4 intrinsic commit 0225f2bfe3f413cd4764c2dba6c922af2520146b Author: Masahiro Masuda Date: Thu Apr 14 08:56:10 2022 +0900 share op strategy between cuda and rocm commit 762c7e8611c9ec3cca3321428e2362c81fe89b9b Author: Masahiro Masuda Date: Thu Apr 14 08:28:34 2022 +0900 fixed rocm batch_matmul strategy for mixed i8i8i32 commit ce53e8d141f7f901303ec6a91674337cbf2b2384 Author: Masahiro Masuda Date: Thu Apr 14 06:17:30 2022 +0900 add rocm sdot4 TIR intrin commit f4562b991f9180b61be7339b2890de1584656c10 Author: Masahiro Masuda Date: Thu Apr 14 06:03:44 2022 +0900 rocm sdot4 works commit 6cc62805f82dd884a18a1c4c0e9bae5866e00da0 Author: Masahiro Masuda Date: Thu Apr 14 05:32:07 2022 +0900 more wip commit 0602f4a3157d4cb5a3f280a3a3c514bb6535aac8 Author: Masahiro Masuda Date: Thu Apr 14 03:47:37 2022 +0900 Squashed commit of the following: commit 65b8bcf955f44540d6a52c8416e60f3047c8366c Author: Masahiro Masuda Date: Wed Apr 13 20:36:49 2022 +0900 [WIP] adding DP4A support to rocm commit 4f8f308ab6bb85ef3bdcc2b8e846c2eea15f2167 Author: Masahiro Masuda Date: Wed Apr 13 14:03:25 2022 +0900 Squashed commit of the following: commit 1711be38a17e3b6171350009f1da05824cd0b340 Author: Masahiro Masuda Date: Wed Apr 13 13:11:40 2022 +0900 fixed condition for real commit 8a48fb5262e80e318cd81d5ff51bf95fd5eb576e Author: Masahiro Masuda Date: Wed Apr 13 09:57:42 2022 +0900 Revert "Skip applying sch_rule when both ann and sch_rule are defined" This reverts commit 4915c6a5a91ff87038e71f8aff9f31db684b4a95. 
commit daea033d2cb06388ef27ddadb80fc5bce72181d2 Author: Masahiro Masuda Date: Mon Apr 11 09:31:05 2022 +0900 [Metaschedule] Support rocm and spirv commit eb0cae2c779808cced074d189e8f487bf46ea89f Author: Masahiro Masuda Date: Wed Apr 13 07:25:04 2022 +0900 dp4a works commit 4915c6a5a91ff87038e71f8aff9f31db684b4a95 Author: Masahiro Masuda Date: Wed Apr 13 06:13:45 2022 +0900 Skip applying sch_rule when both ann and sch_rule are defined commit 7b3d71c6b21a9c5de9ef2b89d0a7db2800a5f3a2 Author: Masahiro Masuda Date: Wed Apr 13 04:40:31 2022 +0900 fixed intrin description commit 7666cd7a5b0ce182791662673fbe45944c84d0ae Author: Masahiro Masuda Date: Tue Apr 12 19:59:47 2022 +0900 add DP4A intrin commit 7086bdb75546a2680d12dc8f80c040cea23f729a Author: Masahiro Masuda Date: Tue Apr 12 19:03:44 2022 +0900 works commit db343974bfae86e51078e40e6170022a782d8e0a Author: Masahiro Masuda Date: Tue Apr 12 12:49:52 2022 +0900 more hack to tensorize loop mapping to make resnet50 e2e work commit 2409674a7884a60beb50d7aa3345c4b907b8cd13 Author: Masahiro Masuda Date: Mon Apr 11 13:40:59 2022 +0900 wip support pad + qnn.conv2d folding commit 613cb7ec33b6df41f1ebe0f0a0ac8eca7c73cff1 Author: Masahiro Masuda Date: Sun Apr 10 12:04:08 2022 +0900 hack to tensorize loop mapping to make conv2d work commit 9e4f9df6a409396a8a4a20d967c4f51accf5d210 Author: Masahiro Masuda Date: Sun Apr 10 11:34:13 2022 +0900 wrap tensorize with try/catch commit d4b496d858da0ae43063d47cb03a28b803d0269f Author: Masahiro Masuda Date: Sun Apr 10 11:33:39 2022 +0900 revert change in task_scheduler.cc commit 476129be7b286f5d109402280aea585e89f6dc1d Author: Masahiro Masuda Date: Sat Apr 9 05:54:10 2022 +0900 try / catch in ThreadedApply commit d8226ff26f25eba17d4000f25131822874bdc2cc Author: Masahiro Masuda Date: Fri Apr 8 17:17:59 2022 +0900 filter out invalid candidate commit 2632899a2759885d338e25f2a25ba0b2c555f0c3 Author: Masahiro Masuda Date: Fri Apr 8 10:09:48 2022 +0900 try graceful exit in parallel_for_dynamic commit 9d6741c3dd29c4dde861aa1d3b2ca85f560f5ac6 Author: Masahiro Masuda Date: Fri Apr 8 09:35:51 2022 +0900 [QNN] Fix broadcast for invalid axis commit 6ccde0959343ce4246ef99505b4f54de469a1a5c Author: Masahiro Masuda Date: Thu Apr 7 20:51:15 2022 +0900 refactor rewrite_tensorize commit 2ce206699f10b03b9611c4683018f7e0c70c7eb5 Author: Masahiro Masuda Date: Thu Apr 7 20:48:17 2022 +0900 allow missing schedule_rule in post order apply commit 3a69353a29abfc454e28d4e530d22a3e2043712e Author: Masahiro Masuda Date: Thu Apr 7 19:42:48 2022 +0900 refactor rewrite_tensorize commit 43e0b2f7f98299679807aaf1ffb13cce2b5f5ce3 Author: Masahiro Masuda Date: Thu Apr 7 18:25:14 2022 +0900 rewrite_vnni -> rewrite_tensorize commit 823797e2627a9bfa812b72019468569ee79eb4c6 Author: Masahiro Masuda Date: Thu Apr 7 18:12:12 2022 +0900 VNNI -> WithIntrin commit 4284a47e5933aa89c1c3362b15ad53b14782fc81 Author: Masahiro Masuda Date: Thu Apr 7 17:45:41 2022 +0900 introduce TileForIntrin commit b87ef32e30e1e71b3f39789f7289976a8cba4ab4 Author: Masahiro Masuda Date: Thu Apr 7 17:34:04 2022 +0900 move TilingwithTensorIntrin to auto_tensorize.cc commit 2fc118b3726586ba13f7de950beaa299b83a0af3 Author: Masahiro Masuda Date: Thu Apr 7 17:28:45 2022 +0900 clean up headers commit d8b2aa325c91b524bec22dc1ec2fc52c9f060fce Author: Masahiro Masuda Date: Thu Apr 7 17:09:32 2022 +0900 clean up using namespace commit eb05d25e2b71f4a1232a8796d1413011ec7629d3 Author: Masahiro Masuda Date: Thu Apr 7 17:03:05 2022 +0900 refactored init commit 5e6b0a08d447c0470c2c8a993e4bd62673e34fe3 
Author: Masahiro Masuda Date: Thu Apr 7 16:57:14 2022 +0900 compiled commit 2b8c430e2fec7ceb285eed7bc7aa73bb9a74a997 Author: Masahiro Masuda Date: Thu Apr 7 12:51:55 2022 +0900 wip MultiLevelTiling refactor commit 7c21a9fea0511c88bd82f49f799b5198252df40a Author: Masahiro Masuda Date: Thu Apr 7 11:58:33 2022 +0900 function doc string not supported by tvmscript commit 40f9742bc9c3aa11e8c2c0551d1827ad47fc0f39 Author: Masahiro Masuda Date: Thu Apr 7 11:56:45 2022 +0900 update vnni intrin name commit 4814f825a5315efd2a3da8c36d2ce6b5df5447cd Merge: e0c5eb84b 07bbb38f7 Author: Masahiro Masuda Date: Thu Apr 7 11:44:47 2022 +0900 Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni commit 07bbb38f7fb52db4a2ecde3d5c87cf4d5cd000a1 Author: Masahiro Masuda Date: Thu Apr 7 11:24:56 2022 +0900 more lint fix commit 15e60b42362cc64b1428b219c8eada414d1b8372 Author: Masahiro Masuda Date: Thu Apr 7 11:16:08 2022 +0900 black commit 7a757fe53758e06418ea1367b348b47c8cd2dcf9 Author: Masahiro Masuda Date: Thu Apr 7 11:12:54 2022 +0900 pylint commit 9a3e508b6f4529158e703b4617f2ddaa351a89eb Author: Masahiro Masuda Date: Thu Apr 7 10:58:52 2022 +0900 simplify import commit d8e43ecf1c0a79a2c195ff31e1e699a447a11335 Author: Masahiro Masuda Date: Thu Apr 7 10:52:50 2022 +0900 use vectorlow/high in arm intrin commit 625cd2774ec455307646b0c26bb3971d89613d1e Author: Masahiro Masuda Date: Thu Apr 7 10:34:57 2022 +0900 fixed offset factor commit 69e72b6b612588e670937e003435afa647030ceb Author: Masahiro Masuda Date: Thu Apr 7 10:12:02 2022 +0900 Add ARM intrin commit 1351fdea6b22f231a290a6c28e06732c9cf993cf Author: Masahiro Masuda Date: Thu Apr 7 08:27:27 2022 +0900 use buffer syntax sugar commit 0ced85fd097ed48aad8714912718d8735791e1fb Author: Masahiro Masuda Date: Thu Apr 7 08:17:43 2022 +0900 rename vnni.py to x86.py commit 38a5aca87ec438446593a3af17760339211f5ad9 Author: Masahiro Masuda Date: Thu Apr 7 07:24:44 2022 +0900 add VNNI unittest commit 88b763ec48c20cf68db8bc3bae3fa3ae78996ee8 Author: Masahiro Masuda Date: Thu Apr 7 07:10:06 2022 +0900 refactored existing test using VNNI intrin commit 711a0076d9be2b9aa80ada67e1edda5ba1fdf1fd Author: Masahiro Masuda Date: Thu Apr 7 07:04:58 2022 +0900 [TIR] Add VNNI dot product intrinsic for TIR commit e0c5eb84bf6a0ad2ba0cddc4bdf22a799dc4b8a0 Author: Masahiro Masuda Date: Thu Apr 7 11:42:26 2022 +0900 merge fix commit b171748139e53f0cf75ff4b6fde436f9d8a5fe91 Merge: 71fe3bdf0 82e152a3c Author: Masahiro Masuda Date: Thu Apr 7 11:33:59 2022 +0900 Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni commit 71fe3bdf02ae10ddbe090a4fd1020f545a05bb41 Author: Masahiro Masuda Date: Thu Apr 7 06:57:38 2022 +0900 move tensor intrin under tir commit 0c51badef45af2a1025ab42fe38d1b3f07ab493e Author: Masahiro Masuda Date: Thu Apr 7 06:12:39 2022 +0900 remove log commit fed910e03eb94c169d4a160b8f3cad406d04c6aa Author: Masahiro Masuda Date: Thu Apr 7 06:11:22 2022 +0900 more revert commit 7150aff9fba167d88dbfb40d48727de8a144b9c0 Author: Masahiro Masuda Date: Thu Apr 7 06:10:44 2022 +0900 revert stmt_functor change commit 155107b98b09c5e5cc7f19afbd327b0557a02843 Author: Masahiro Masuda Date: Thu Apr 7 06:10:09 2022 +0900 refactored RewriteVNNI a bit commit ca15255e3a882b89b05bb83079640c929fb63096 Author: Masahiro Masuda Date: Thu Apr 7 05:41:13 2022 +0900 add RewriteVNNI commit dc9f71d5e3122b50fa8ae6a4462f959f13870b05 Author: Masahiro Masuda Date: Thu Apr 7 05:38:56 2022 +0900 vectorized init loop commit fcc31ee20ddfafd47f566bf98ff40a9f684d12eb Author: Masahiro Masuda Date: Thu Apr 7 
04:55:36 2022 +0900 tensorize worked commit 2b534377a45b9ab84bf35c3d7c03ecae7616d17f Author: Masahiro Masuda Date: Wed Apr 6 19:11:05 2022 +0900 TilingwithTensorIntrin works commit 86baa31e773fc864f77dc113bc9a93b79f3fc652 Author: Masahiro Masuda Date: Wed Apr 6 08:58:27 2022 +0900 Ported auto-tensorization code commit 82e152a3c91144041ade783116a50565ebb48b89 Author: Masahiro Masuda Date: Thu Apr 7 11:24:56 2022 +0900 more lint fix commit 88d9bdd3b21302bc2dd068a990df15c375a1a8ef Author: Masahiro Masuda Date: Thu Apr 7 11:16:08 2022 +0900 black commit 31fe7eb8075445161d804d170772eac8e90d3425 Author: Masahiro Masuda Date: Thu Apr 7 11:12:54 2022 +0900 pylint commit 7876754effc40ad089349534dacd75df19d38fc4 Author: Masahiro Masuda Date: Thu Apr 7 10:58:52 2022 +0900 simplify import commit 56f2e9a85069426021e2872eb1da95bf134ac7e0 Author: Masahiro Masuda Date: Thu Apr 7 10:52:50 2022 +0900 use vectorlow/high in arm intrin commit 995cc8d6fcec70a3fadcfb1c6fee7b9f0b5a0951 Author: Masahiro Masuda Date: Thu Apr 7 10:34:57 2022 +0900 fixed offset factor commit 86bbd4955b34257d68d957cb4a2536aea3ef9bac Author: Masahiro Masuda Date: Thu Apr 7 10:12:02 2022 +0900 Add ARM intrin commit 120fd96e80307b4301ee3fc93e6793e0b40485f0 Author: Masahiro Masuda Date: Thu Apr 7 08:27:27 2022 +0900 use buffer syntax sugar commit 0f0682d00c3961afd1f492ae55f180c5b5502767 Author: Masahiro Masuda Date: Thu Apr 7 08:17:43 2022 +0900 rename vnni.py to x86.py commit f88c31ead1fa6db4bfd2c88eeaf5f665e4c6dddb Author: Masahiro Masuda Date: Thu Apr 7 07:24:44 2022 +0900 add VNNI unittest commit 6cc80094adac398762924b0b31a4c741417ba9dc Author: Masahiro Masuda Date: Thu Apr 7 07:10:06 2022 +0900 refactored existing test using VNNI intrin commit 11a29c704cdaad96aeeca39c9c753ef006d27a50 Author: Masahiro Masuda Date: Thu Apr 7 07:04:58 2022 +0900 [TIR] Add VNNI dot product intrinsic for TIR commit e370ed459739f5312e45a2fb3a446b120f8ec5d1 Author: Chris Sullivan Date: Wed Apr 13 15:19:41 2022 -0700 [Hexagon] Less aggressive adb state clean up (#10909) * Only remove port forwarding applied in a session to avoid affecting global adb state. * Send SIGINT to attempt to allow remote server to clean up and unbind the port in destruction * Only attempt to forward ports not in use by adb or the system. commit ce8f83e3c5c5bb7a021d675283e84ac319f19162 Author: Christian Convey Date: Wed Apr 13 16:25:39 2022 -0400 [hexagon] 'add_hvx' test to explore HVX usage. (#10604) Add a unit test named 'add_hvx' to explore how various scheduling choices, tensor sizes, etc. impact efficient usage of Hexagon HVX units.
---
 include/tvm/meta_schedule/schedule_rule.h        |  10 ++
 include/tvm/tir/stmt.h                           |   5 +
 python/tvm/meta_schedule/postproc/__init__.py    |   1 +
 .../postproc/rewrite_tensorize.py                |  33 ++++
 .../meta_schedule/schedule_rule/__init__.py      |   2 +-
 .../schedule_rule/multi_level_tiling.py          |  47 ++++++
 .../postproc/rewrite_tensorize.cc                | 104 ++++++++++++
 .../schedule_rule/auto_tensorize.cc              |  99 +++++++++++
 .../schedule_rule/auto_tensorize.h               |  35 ++++
 .../schedule_rule/multi_level_tiling.cc          |  26 +--
 .../schedule_rule/multi_level_tiling.h           |  30 ++++
 .../multi_level_tiling_with_intrin.cc            |  60 +++++++
 src/target/target_kind.cc                        |   1 +
 src/tir/schedule/analysis.h                      |  26 +++
 src/tir/schedule/analysis/analysis.cc            | 155 ++++++++++++++++++
 15 files changed, 611 insertions(+), 23 deletions(-)
 create mode 100644 python/tvm/meta_schedule/postproc/rewrite_tensorize.py
 create mode 100644 src/meta_schedule/postproc/rewrite_tensorize.cc
 create mode 100644 src/meta_schedule/schedule_rule/auto_tensorize.cc
 create mode 100644 src/meta_schedule/schedule_rule/auto_tensorize.h
 create mode 100644 src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
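A minimal sketch of how the two new components below are meant to be combined. The tuning wiring and the intrin name "dot_16x4_vnni" are assumptions for illustration based on the surrounding meta schedule API; only MultiLevelTilingWithIntrin and RewriteTensorize are introduced by this patch:

    from tvm.meta_schedule.postproc import RewriteTensorize
    from tvm.meta_schedule.schedule_rule import MultiLevelTilingWithIntrin
    import tvm.tir.tensor_intrin  # assumed to register the dot-product intrinsics

    # During search, the rule tiles a matching block for the intrinsic and
    # annotates it with "meta_schedule.auto_tensorize" ...
    rule = MultiLevelTilingWithIntrin(
        intrin_name="dot_16x4_vnni",  # assumed name of a registered VNNI intrin
        structure="SSRSRS",           # recommended CPU tiling structure
        max_innermost_factor=64,
    )
    # ... and the postprocessor later rewrites each annotated block into a call
    # to the intrinsic, optionally vectorizing the split-out init block.
    postproc = RewriteTensorize(vectorize_init_loop=False)
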
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 1675bcce05edb..d854efc4c75ca 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -150,6 +150,16 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Array<Integer>> vector_load_lens,    //
       Optional<Map<String, ObjectRef>> reuse_read,  //
       Optional<Map<String, ObjectRef>> reuse_write);
+
+  TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
+      String intrin_name,                           //
+      String structure,                             //
+      Optional<Array<String>> tile_binds,           //
+      Optional<Integer> max_innermost_factor,       //
+      Optional<Array<Integer>> vector_load_lens,    //
+      Optional<Map<String, ObjectRef>> reuse_read,  //
+      Optional<Map<String, ObjectRef>> reuse_write);
+
   /*!
    * \brief Create a rule: add-rfactor to some blocks if needed
    * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 9ccab50eced26..2cab81c74733a 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1509,6 +1509,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit";
 /*! \brief Mark auto-unroll setting on the block. */
 constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
 
+/*!
+ * \brief Mark that the block should be further rewritten using tensorization.
+ */
+constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
+
 /*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py
index 96361e739186d..39113bb90011a 100644
--- a/python/tvm/meta_schedule/postproc/__init__.py
+++ b/python/tvm/meta_schedule/postproc/__init__.py
@@ -22,3 +22,4 @@
 from .rewrite_reduction_block import RewriteReductionBlock
 from .rewrite_unbound_block import RewriteUnboundBlock
 from .verify_gpu_code import VerifyGPUCode
+from .rewrite_tensorize import RewriteTensorize
diff --git a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py
new file mode 100644
index 0000000000000..c45c319161b4e
--- /dev/null
+++ b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A postprocessor that tensorizes blocks annotated for auto-tensorization."""
+
+from tvm._ffi.registry import register_object
+
+from .. import _ffi_api
+from .postproc import Postproc
+import tvm.tir.tensor_intrin
+
+
+@register_object("meta_schedule.RewriteTensorize")
+class RewriteTensorize(Postproc):
+    """A postprocessor that tensorizes blocks annotated for auto-tensorization."""
+
+    def __init__(self, vectorize_init_loop=False) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.PostprocRewriteTensorize,  # type: ignore # pylint: disable=no-member
+            vectorize_init_loop,
+        )
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index f03c6de3df4bd..a958fdc39db1f 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -22,7 +22,7 @@
 from .add_rfactor import AddRFactor
 from .auto_inline import AutoInline
 from .cross_thread_reduction import CrossThreadReduction
-from .multi_level_tiling import MultiLevelTiling, ReuseType
+from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
 from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
 from .random_compute_location import RandomComputeLocation
 from .schedule_rule import PyScheduleRule, ScheduleRule
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index 2ff49168d0c66..acdb185293d6c 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -82,3 +82,50 @@ def __init__(
             reuse_read.as_dict() if reuse_read is not None else None,
             reuse_write.as_dict() if reuse_write is not None else None,
         )
+
+
+@register_object("meta_schedule.MultiLevelTilingWithIntrin")
+class MultiLevelTilingWithIntrin(ScheduleRule):
+    """Multi-level tiling with reuse, tensorized using the given tensor intrinsic.
+
+    Parameters
+    ----------
+    intrin_name : str
+        The name of a tensor intrinsic registered with TensorIntrin.
+    structure : str
+        The tiling structure. Recommended:
+        - 'SSRSRS' on CPU
+        - 'SSSRRSRS' on GPU
+    tile_binds : Optional[List[str]]
+        For each level of tiles, which thread axis it is bound to. Recommended:
+        - None on CPU
+        - [blockIdx.x, vthread.x, threadIdx.x] on GPU
+    max_innermost_factor : Optional[int]
+        The maximum size of the innermost factor. None means no limit.
+    vector_load_lens : Optional[List[int]]
+        The length of vector lane in vectorized cooperative fetching.
+        None means disable vectorization.
+    reuse_read : Optional[ReuseType]
+        Data reuse configuration for reading. None means no reuse.
+    reuse_write : Optional[ReuseType]
+        Data reuse configuration for writing. None means no reuse.
+    """
+
+    def __init__(
+        self,
+        intrin_name: str,
+        structure: str,
+        tile_binds: Optional[List[str]] = None,
+        max_innermost_factor: Optional[int] = None,
+        vector_load_lens: Optional[List[int]] = None,
+        reuse_read: Optional[ReuseType] = None,
+        reuse_write: Optional[ReuseType] = None,
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ScheduleRuleMultiLevelTilingWithIntrin,  # type: ignore # pylint: disable=no-member
+            intrin_name,
+            structure,
+            tile_binds,
+            max_innermost_factor,
+            vector_load_lens,
+            reuse_read.as_dict() if reuse_read is not None else None,
+            reuse_write.as_dict() if reuse_write is not None else None,
+        )
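For reference, the dot-product pattern that the VNNI, DP4A, and sdot4 intrinsics named in the commit log all implement has the semantics below. This NumPy sketch is purely illustrative and is not part of the patch:

    import numpy as np

    def dot_product_update(c, a, b):
        # c[i] += sum over k of a[k] * b[i, k], accumulating int8 products in int32.
        # VNNI computes 16 output lanes at once (i = 16, k = 4); DP4A/sdot4 compute one.
        c += np.sum(a.astype("int32") * b.astype("int32"), axis=1)
        return c

    c = np.zeros(16, dtype="int32")
    a = np.arange(4, dtype="int8")      # [0, 1, 2, 3]
    b = np.ones((16, 4), dtype="int8")
    print(dot_product_update(c, a, b))  # every lane accumulates 0+1+2+3 = 6
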
diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc
new file mode 100644
index 0000000000000..1735ceb10cf44
--- /dev/null
+++ b/src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/meta_schedule/postproc.h>
+
+#include "../utils.h"
+#include "tvm/runtime/container/base.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::BlockRV;
+using tir::LoopRV;
+
+void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
+                        const tir::PrimFuncNode* func, bool vectorize_init_loop) {
+  std::vector<std::pair<std::string, std::function<void(tir::BlockRV)>>> jobs;
+
+  tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) -> bool {
+    if (const auto* block = obj.as<tir::BlockNode>()) {
+      tir::StmtSRef block_sref = sch->GetSRef(block);
+      if (Optional<String> intrin_name =
+              tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
+        std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
+        if (block_name.find("init") == std::string::npos) {
+          jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block) {
+            try {
+              sch->Tensorize(block, intrin_name.value());
+            } catch (const std::exception& e) {
+              LOG(WARNING) << "Tensorize failed with error " << e.what();
+            }
+          });
+        } else if (vectorize_init_loop) {
+          jobs.emplace_back(block_name, [sch](tir::BlockRV block) {
+            Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
+            ICHECK(child_blocks.size() == 1);
+            Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
+            ICHECK(init_loops.size() == 1);
+            sch->Vectorize(init_loops[0]);
+          });
+        }
+      }
+    }
+    return true;
+  });
+
+  for (auto kv : jobs) {
+    tir::BlockRV block = sch->GetBlock(kv.first, func_name);
+    sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
+    kv.second(block);
+  }
+}
+
+class RewriteTensorizeNode : public PostprocNode {
+ public:
+  void InitializeWithTuneContext(const TuneContext& context) final {}
+
+  bool Apply(const tir::Schedule& sch) final;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  bool vectorize_init_loop = false;
+
+  static constexpr const char* _type_key = "meta_schedule.RewriteTensorize";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode);
+};
+
+bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
+  for (const auto& kv : sch->mod()->functions) {
+    GlobalVar g_var = kv.first;
+    BaseFunc base_func = kv.second;
+    if (const tir::PrimFuncNode* prim_func = base_func.as<tir::PrimFuncNode>()) {
+      ApplyTensorization(sch, g_var->name_hint, prim_func, vectorize_init_loop);
+    }
+  }
+  return true;
+}
+
+Postproc RewriteTensorize(bool vectorize_init_loop) {
+  ObjectPtr<RewriteTensorizeNode> n = make_object<RewriteTensorizeNode>();
+  n->vectorize_init_loop = vectorize_init_loop;
+  return Postproc(n);
+}
+
+TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode);
+TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize").set_body_typed(RewriteTensorize);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/schedule_rule/auto_tensorize.cc b/src/meta_schedule/schedule_rule/auto_tensorize.cc
new file mode 100644
index 0000000000000..21f6697b9fab2
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/auto_tensorize.cc
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "auto_tensorize.h"
+
+#include "../../tir/schedule/analysis.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::LoopRV;
+
+Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+                                        const String& intrin_name) {
+  Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
+      sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
+  if (!opt_tensorize_info) return NullOpt;
+  const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
+  // Construct a mapping from tir loops back to LoopRVs
+  Map<tir::StmtSRef, LoopRV> loop2rv;
+  {
+    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
+    for (const LoopRV& loop_rv : loop_rvs) {
+      loop2rv.Set(sch->GetSRef(loop_rv), loop_rv);
+    }
+  }
+  // Split the loops
+  arith::Analyzer analyzer;
+  std::unordered_set<const tir::StmtSRefNode*> inner_loops;
+  std::vector<LoopRV> reorder_suffix;
+  reorder_suffix.resize(info->loop_map.size());
+  for (const auto& kv : info->loop_map) {
+    // Extract mapping (block_loop => desc_loop)
+    const tir::StmtSRef& block_loop_sref = kv.first;
+    const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>();
+    const tir::ForNode* desc_loop = kv.second.get();
+    ICHECK(block_loop != nullptr && desc_loop != nullptr);
+    // Extract the loop extent
+    PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
+    PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
+    const auto* int_block_extent = block_extent.as<IntImmNode>();
+    const auto* int_desc_extent = desc_extent.as<IntImmNode>();
+    ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr);
+    // Check divisibility
+    int64_t total = int_block_extent->value;
+    int64_t inner = int_desc_extent->value;
+    ICHECK_EQ(total % inner, 0);
+    int64_t outer = int_block_extent->value / int_desc_extent->value;
+    // Do the split
+    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {Integer(outer), Integer(inner)});
+    ICHECK_EQ(split.size(), 2);
+    inner_loops.insert(sch->GetSRef(split[1]).operator->());
+    // The inner split will be reordered to the loop domain that is tensorized
+    int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop));
+    reorder_suffix[desc_loop_index] = split[1];
+  }
+  // Reorder the loops
+  std::vector<LoopRV> reorder_list;
+  bool meet = false;
+  Array<LoopRV> all_loops = sch->GetLoops(block_rv);
+  for (const LoopRV& loop : all_loops) {
+    if (inner_loops.count(sch->GetSRef(loop).operator->())) {
+      meet = true;
+    } else if (meet) {
+      reorder_list.push_back(loop);
+    }
+  }
+  reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());
+  sch->Reorder(reorder_list);
+  ICHECK(!reorder_suffix.empty());
+  return reorder_suffix[0];
+}
+
+tir::BlockRV TileForIntrin(tir::Schedule sch, tir::BlockRV block, const std::string& intrin_name) {
+  Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin(sch, block, intrin_name);
+  ICHECK(tiled_loop_rv.defined());
+  tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value());
+  sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
+  return outer_block;
+}
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/schedule_rule/auto_tensorize.h b/src/meta_schedule/schedule_rule/auto_tensorize.h
new file mode 100644
index 0000000000000..14674c2c737bf
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/auto_tensorize.h
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_
+#define TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_
+
+#include <tvm/tir/schedule/schedule.h>
+
+namespace tvm {
+namespace meta_schedule {
+
+Optional<tir::LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+                                             const String& intrin_name);
+
+tir::BlockRV TileForIntrin(tir::Schedule sch, tir::BlockRV block, const std::string& intrin_name);
+
+}  // namespace meta_schedule
+}  // namespace tvm
+
+#endif  // TVM_META_SCHEDULE_SCHEDULE_RULE_AUTO_TENSORIZE_H_
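Before the multi-level tiling changes below, here is a hand-written Python equivalent of the transformation that TilingwithTensorIntrin and TileForIntrin perform. The workload, the split factors, and the intrin name "dot_16x4_vnni" are assumptions chosen for illustration; the real code derives the factors from the intrin description:

    import tvm
    from tvm import te

    # Stand-in int8 matmul workload (names and extents are illustrative only).
    A = te.placeholder((1024, 512), name="A", dtype="int8")
    B = te.placeholder((1024, 512), name="B", dtype="int8")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute(
        (1024, 1024),
        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
        name="C",
    )

    sch = tvm.tir.Schedule(te.create_prim_func([A, B, C]))
    i, j, kr = sch.get_loops(sch.get_block("C"))
    # For a 16x4 dot-product desc, each mapped block loop is split so that the
    # inner extent equals the desc loop extent (j: 16, k: 4) ...
    j_o, j_i = sch.split(j, factors=[64, 16])
    k_o, k_i = sch.split(kr, factors=[128, 4])
    # ... the inner halves are reordered to be innermost ...
    sch.reorder(k_o, j_i, k_i)
    # ... and blockized into one outer block carrying the annotation that
    # RewriteTensorize later turns into sch.tensorize(block, intrin_name).
    outer = sch.blockize(j_i)
    sch.annotate(outer, "meta_schedule.auto_tensorize", "dot_16x4_vnni")
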
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 6b18b17867dc1..0dd477330a6ee 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -25,6 +25,7 @@
 #include <vector>
 
 #include "../utils.h"
+#include "tvm/meta_schedule/schedule_rule.h"
 
 namespace tvm {
 namespace tir {
@@ -260,28 +261,9 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
                                             Optional<Array<Integer>> vector_load_lens,
                                             Optional<Map<String, ObjectRef>> reuse_read,
                                             Optional<Map<String, ObjectRef>> reuse_write) {
-  ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>();
-  n->structure = structure;
-  n->tile_binds = tile_binds.value_or({});
-  n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
-  n->vector_load_lens = vector_load_lens.defined()
-                            ? support::AsVector<Integer, int>(vector_load_lens.value())
-                            : std::vector<int>();
-  n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
-  n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
-  for (int i = 0, len = structure.size(); i < len; ++i) {
-    char c = structure.data()[i];
-    if (c == 'S') {
-      n->s_indices_.push_back(i);
-    } else if (c == 'R') {
-      n->r_indices_.push_back(i);
-    } else {
-      LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
-    }
-  }
-  n->thread_warp_size_ = -1;
-  n->max_threads_per_block_ = -1;
-  return ScheduleRule(n);
+  auto node = MultiLevelTilingInitCommon<MultiLevelTilingNode>(
+      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
+  return ScheduleRule(node);
 }
 
 TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode);
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index b7712b5c1989f..f260c4856e364 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -181,6 +181,36 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   TVM_DECLARE_BASE_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode);
 };
 
+template <typename NodeType>
+ObjectPtr<NodeType> MultiLevelTilingInitCommon(String structure,
+                                               Optional<Array<String>> tile_binds,
+                                               Optional<Integer> max_innermost_factor,
+                                               Optional<Array<Integer>> vector_load_lens,
+                                               Optional<Map<String, ObjectRef>> reuse_read,
+                                               Optional<Map<String, ObjectRef>> reuse_write) {
+  ObjectPtr<NodeType> n = make_object<NodeType>();
+  n->structure = structure;
+  n->tile_binds = tile_binds.value_or({});
+  n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
+  n->vector_load_lens = vector_load_lens.defined()
+                            ? support::AsVector<Integer, int>(vector_load_lens.value())
+                            : std::vector<int>();
+  n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
+  n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
+  for (int i = 0, len = structure.size(); i < len; ++i) {
+    char c = structure.data()[i];
+    if (c == 'S') {
+      n->s_indices_.push_back(i);
+    } else if (c == 'R') {
+      n->r_indices_.push_back(i);
+    } else {
+      LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
+    }
+  }
+  n->thread_warp_size_ = -1;
+  n->max_threads_per_block_ = -1;
+  return n;
+}
+
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
new file mode 100644
index 0000000000000..ba85629ab8049
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../utils.h"
+#include "auto_tensorize.h"
+#include "multi_level_tiling.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
+ protected:
+  virtual std::vector<State> ApplySubRules(std::vector<State> states) {
+    states = SubRule(std::move(states), [&](State state) {
+      state.block_rv = TileForIntrin(state.sch, state.block_rv, intrin_name);
+      return std::vector<State>(1, state);
+    });
+    return MultiLevelTilingNode::ApplySubRules(states);
+  }
+
+ public:
+  String intrin_name;
+
+  static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode);
+};
+
+ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(
+    String intrin_name, String structure, Optional<Array<String>> tile_binds,
+    Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
+    Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
+  ICHECK(tir::TensorIntrin::Get(intrin_name).defined());
+  auto node = MultiLevelTilingInitCommon<MultiLevelTilingWithIntrinNode>(
+      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
+  node->intrin_name = intrin_name;
+  return ScheduleRule(node);
+}
+
+TVM_REGISTER_NODE_TYPE(MultiLevelTilingWithIntrinNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin")
+    .set_body_typed(ScheduleRule::MultiLevelTilingWithIntrin);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 2ad75259d69b6..a0c9df074b503 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -314,6 +314,7 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
     .add_attr_option<Integer>("max_threads_per_block", Integer(256))
     .add_attr_option<Integer>("max_shared_memory_per_block", Integer(65536))
     .add_attr_option<Integer>("thread_warp_size", Integer(64))
+    .add_attr_option<Integer>("max_shared_memory_per_block", Integer(64000))
     .set_default_keys({"rocm", "gpu"})
     .set_attrs_preprocessor(UpdateROCmAttrs);
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index b76d41326ff1d..cee45e1a4398e 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -656,6 +656,32 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const PrimExpr& predicate,
                                              const StmtSRef& dom_high_exclusive,
                                              arith::Analyzer* analyzer);
 
+/*! \brief Necessary information used for tensorization */
+class TensorizeInfoNode : public Object {
+ public:
+  /*! \brief Maps block loops to desc loops */
+  Map<tir::StmtSRef, tir::For> loop_map;
+  /*! \brief Maps loops in desc to its index, outer to inner */
+  Map<tir::For, Integer> desc_loop_indexer;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("loop_map", &loop_map);
+    v->Visit("desc_loop_indexer", &desc_loop_indexer);
+  }
+
+  static constexpr const char* _type_key = "tir.analysis.TensorizeInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
+};
+
+class TensorizeInfo : public ObjectRef {
+ public:
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
+};
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func);
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 4a7ac401dd600..74c754dde9169 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -2028,5 +2028,160 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
   }
 }
 
+TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);
+
+Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
+                                                const tir::StmtSRef& block_sref,
+                                                const tir::PrimFunc& desc_func) {
+  // Try to do tiling automatically if possible
+  // Now the heuristic is that if block's block var binding is constant + loop var,
+  // in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder
+  // i, j, k according to the loops outside desc_block
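+  //
+  // Illustrative (hypothetical) example: suppose the block binds
+  //   vi = Ci + i (i extent 128), vj = Cj + j (j extent 32), vk = Ck + k (k extent 64)
+  // and the desc block of a 16x4 dot-product intrin binds
+  //   v_j = j_d (j_d extent 16), v_k = k_d (k_d extent 4).
+  // Aligning the bindings from the right maps j -> j_d and k -> k_d, so the
+  // caller can split j into 2 x 16 and k into 16 x 4 and reorder the inner
+  // halves innermost to line up with the intrin description.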
+  // Collect the loops outside block
+  arith::Analyzer analyzer;
+  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
+  // Step 1. Analyze desc_func, extract its block, loops and loop vars
+  const tir::BlockRealizeNode* desc_block = nullptr;
+  std::vector<const tir::ForNode*> desc_loops;
+  std::unordered_set<const tir::VarNode*> desc_loop_vars;
+  const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
+  ICHECK(desc_scope_realize);
+  {
+    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
+                    &analyzer](const ObjectRef& obj) -> bool {
+      // Extract the block
+      if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
+        desc_block = block;
+        return false;
+      }
+      // Extract the loops
+      if (const auto* loop = obj.as<tir::ForNode>()) {
+        desc_loops.push_back(loop);
+        desc_loop_vars.insert(loop->loop_var.get());
+        if (!analyzer.CanProve(loop->min == 0)) {
+          return false;
+        }
+      }
+      return true;
+    };
+    tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
+    std::reverse(desc_loops.begin(), desc_loops.end());
+    ICHECK(desc_block);
+  }
+  // Step 2. Check if `desc_block` matches `block`
+  // Ignore the scope of buffers when comparing, since we can do cache_read/write
+  const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
+  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
+  // Step 3. Collect the loops on top of the block and their loop vars
+  std::vector<const tir::ForNode*> block_loops;
+  std::unordered_set<const tir::VarNode*> block_loop_vars;
+  {
+    for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
+      const auto* loop = loop_sref->StmtAs<tir::ForNode>();
+      if (loop == nullptr || loop->body->IsInstance<tir::BlockRealizeNode>()) {
+        break;
+      }
+      block_loops.push_back(loop);
+      block_loop_vars.insert(loop->loop_var.get());
+      if (!analyzer.CanProve(loop->min == 0)) {
+        return NullOpt;
+      }
+    }
+    std::reverse(block_loops.begin(), block_loops.end());
+  }
+  // Step 4. Map from block loops to desc block loops
+  ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
+  int n_block_vars = block->iter_values.size();
+  int n_desc_vars = desc_block->iter_values.size();
+  int offset = n_block_vars - n_desc_vars;
+  if (offset < 0) {
+    return NullOpt;
+  }
+  // We align the block and desc block's bindings from the right side
+  // block      (v0=..., v1=..., v2=...)
+  //                     ^ i_block
+  // desc_block (        v1=..., v2=...)
+  //                     ^ i_desc
+  std::vector<IterVarType> iter_types = GetBlockVarTypes(block_sref);
+  ICHECK(block_loops.size() == iter_types.size());
+
+  for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) {
+    // For each pair of aligned block var bindings, find the corresponding loops
+    const PrimExpr& block_bind = block->iter_values[i_block];
+    const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
+    // Step 4.1. Find the corresponding loop of the i-th block var of block
+    const tir::ForNode* block_loop = nullptr;
+    for (int i = block_loops.size() - 1; i >= 0; --i) {
+      // Skip reduction loops unless we are matching the innermost desc var
+      if (i_desc != n_desc_vars - 1 && iter_types[i] == IterVarType::kCommReduce) continue;
+      // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
+      PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
+      if (!tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) {
+            return block_loop_vars.count(var);
+          })) {
+        block_loop = block_loops[i];
+        break;
+      }
+    }
+    if (block_loop == nullptr) {
+      return NullOpt;
+    }
+    // Step 4.2. Find the corresponding loop of the i-th block var of desc
+    const tir::ForNode* desc_loop = nullptr;
+    for (int i = 0, n = desc_loops.size(); i < n; ++i) {
+      // Check if desc_bind = desc_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
+      PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
+      if (!tir::UsesVar(r, [&desc_loop_vars](const tir::VarNode* var) {
+            return desc_loop_vars.count(var);
+          })) {
+        desc_loop = desc_loops[i];
+        break;
+      }
+    }
+    if (desc_loop == nullptr) {
+      return NullOpt;
+    }
+    // Step 4.3. Check divisibility of loop extents
+    PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
+    PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
+    if (const auto* int_block_extent = block_extent.as<IntImmNode>()) {
+      if (const auto* int_desc_extent = desc_extent.as<IntImmNode>()) {
+        if (int_block_extent->value % int_desc_extent->value != 0) {
+          return NullOpt;
+        }
+      } else {
+        return NullOpt;
+      }
+    } else {
+      return NullOpt;
+    }
+    // Step 4.4. Maps the result of Step 4.1 to Step 4.2
+    const tir::StmtSRef& block_loop_sref = self->stmt2ref.at(block_loop);
+    auto it = ret->loop_map.find(block_loop_sref);
+    if (it == ret->loop_map.end()) {
+      ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
+    } else if ((*it).second.get() != desc_loop) {
+      return NullOpt;
+    }
+  }
+  for (int i = 0, n = desc_loops.size(); i < n; ++i) {
+    ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
+  }
+  return TensorizeInfo(ret);
+}
+
 }  // namespace tir
 }  // namespace tvm
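
As a closing illustration, the effect of RewriteTensorize on each annotated block can be reproduced by hand, continuing the earlier scheduling sketch; the block name "C_o" and the intrin name "dot_16x4_vnni" are again assumptions:

    # RewriteTensorize finds each block annotated with
    # "meta_schedule.auto_tensorize", removes the annotation, and calls
    # tensorize with the recorded intrinsic name.
    import tvm.tir.tensor_intrin  # assumed to register "dot_16x4_vnni"

    block = sch.get_block("C_o")  # the blockized outer block; name is illustrative
    sch.unannotate(block, "meta_schedule.auto_tensorize")
    sch.tensorize(block, "dot_16x4_vnni")
    print(sch.mod.script())  # the inner block body is now the intrinsic implementation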