forked from llvm/llvm-project
-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Lowering for 'tosa.scatter' #40
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This patch adds support for `tosa.scatter` lowering in the `--tosa-to-scf` pass. Here's an example for this lowering: ``` func.func @tosa( %valuesIn : tensor<3x7x5xi32>, %indices : tensor<3x6xi32>, %input : tensor<3x6x5xi32>) -> tensor<3x7x5xi32> { %0 = "tosa.scatter"(%valuesIn, %indices, %input) : (tensor<3x7x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> (tensor<3x7x5xi32>) return %0 : tensor<3x7x5xi32> } ``` translates to func.func @tosa(%arg0: tensor<3x7x5xi32>, %arg1: tensor<3x6xi32>, %arg2: tensor<3x6x5xi32>) -> tensor<3x7x5xi32> { %c0 = arith.constant 0 : index %c3 = arith.constant 3 : index %c1 = arith.constant 1 : index %c6 = arith.constant 6 : index %c2 = arith.constant 2 : index %c5 = arith.constant 5 : index %c0_0 = arith.constant 0 : index %c1_1 = arith.constant 1 : index %0 = scf.for %arg3 = %c0_0 to %c3 step %c1_1 iter_args(%arg4 = %arg0) -> (tensor<3x7x5xi32>) { %1 = scf.for %arg5 = %c0_0 to %c6 step %c1_1 iter_args(%arg6 = %arg4) -> (tensor<3x7x5xi32>) { %extracted = tensor.extract %arg1[%arg3, %arg5] : tensor<3x6xi32> %2 = arith.index_cast %extracted : i32 to index %extracted_slice = tensor.extract_slice %arg2[%arg3, %arg5, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<3x6x5xi32> to tensor<?x?x?xi32> %inserted_slice = tensor.insert_slice %extracted_slice into %arg6[%arg3, %2, %c0_0] [%c1_1, %c1_1, %c5] [%c1_1, %c1_1, %c1_1] : tensor<?x?x?xi32> into tensor<3x7x5xi32> scf.yield %inserted_slice : tensor<3x7x5xi32> } scf.yield %1 : tensor<3x7x5xi32> } return %0 : tensor<3x7x5xi32> } ``` We have attempted an alternative lowering pass that uses `tensor.scatter` as an intermediate step. However, we opted to aim straight at the `scf` dialect for the following reasons: - The `tensor.scatter` op doesn't seem to be used anywhere. There is no available lowering pass for this op (although we have one that we'll upstream soon). - The `tosa.scatter` and `tensor.scatter` op have different indexing semantics. The `indices` argument of `tosa.scatter` must be non-trivially modified and restructured (e.g. with a `linalg.generic` op) to adapt to the needs of `tensor.scatter`. While this overhead may be simplified and fused after a subsequent `tensor.scatter` lowering, it adds complex logic and an obscure intermediate state. Unless there is a good reason to go through the `tensor` dialect that we're missing, this additional complexity may not be justified. Reviewed By: eric-k256 Differential Revision: https://reviews.llvm.org/D151117
flemairen6
approved these changes
Jun 2, 2023
ttjost
approved these changes
Jun 2, 2023
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
mgehre-amd
pushed a commit
that referenced
this pull request
Feb 27, 2024
…lvm#80904)" This reverts commit b1ac052. This commit breaks coroutine splitting for non-swift calling convention functions. In this example: ```ll ; ModuleID = 'repro.ll' source_filename = "stdlib/test/runtime/test_llcl.mojo" target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" target triple = "x86_64-unknown-linux-gnu" @0 = internal constant { i32, i32 } { i32 trunc (i64 sub (i64 ptrtoint (ptr @craSH to i64), i64 ptrtoint (ptr getelementptr inbounds ({ i32, i32 }, ptr @0, i32 0, i32 1) to i64)) to i32), i32 64 } define dso_local void @af_suspend_fn(ptr %0, i64 %1, ptr %2) #0 { ret void } define dso_local void @craSH(ptr %0) #0 { %2 = call token @llvm.coro.id.async(i32 64, i32 8, i32 0, ptr @0) %3 = call ptr @llvm.coro.begin(token %2, ptr null) %4 = getelementptr inbounds { ptr, { ptr, ptr }, i64, { ptr, i1 }, i64, i64 }, ptr poison, i32 0, i32 0 %5 = call ptr @llvm.coro.async.resume() store ptr %5, ptr %4, align 8 %6 = call { ptr, ptr, ptr } (i32, ptr, ptr, ...) @llvm.coro.suspend.async.sl_p0p0p0s(i32 0, ptr %5, ptr @ctxt_proj_fn, ptr @af_suspend_fn, ptr poison, i64 -1, ptr poison) ret void } define dso_local ptr @ctxt_proj_fn(ptr %0) #0 { ret ptr %0 } ; Function Attrs: nomerge nounwind declare { ptr, ptr, ptr } @llvm.coro.suspend.async.sl_p0p0p0s(i32, ptr, ptr, ...) #1 ; Function Attrs: nounwind declare token @llvm.coro.id.async(i32, i32, i32, ptr) #2 ; Function Attrs: nounwind declare ptr @llvm.coro.begin(token, ptr writeonly) #2 ; Function Attrs: nomerge nounwind declare ptr @llvm.coro.async.resume() #1 attributes #0 = { "target-features"="+adx,+aes,+avx,+avx2,+bmi,+bmi2,+clflushopt,+clwb,+clzero,+crc32,+cx16,+cx8,+f16c,+fma,+fsgsbase,+fxsr,+invpcid,+lzcnt,+mmx,+movbe,+mwaitx,+pclmul,+pku,+popcnt,+prfchw,+rdpid,+rdpru,+rdrnd,+rdseed,+sahf,+sha,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+sse4a,+ssse3,+vaes,+vpclmulqdq,+wbnoinvd,+x87,+xsave,+xsavec,+xsaveopt,+xsaves" } attributes #1 = { nomerge nounwind } attributes #2 = { nounwind } ``` This verifier crashes after the `coro-split` pass with ``` cannot guarantee tail call due to mismatched parameter counts musttail call void @af_suspend_fn(ptr poison, i64 -1, ptr poison) LLVM ERROR: Broken function PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace. Stack dump: 0. Program arguments: opt ../../../reduced.ll -O0 #0 0x00007f1d89645c0e __interceptor_backtrace.part.0 /build/gcc-11-XeT9lY/gcc-11-11.4.0/build/x86_64-linux-gnu/libsanitizer/asan/../../../../src/libsanitizer/sanitizer_common/sanitizer_common_interceptors.inc:4193:28 #1 0x0000556d94d254f7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /home/ubuntu/modular/third-party/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:22 #2 0x0000556d94d19a2f llvm::sys::RunSignalHandlers() /home/ubuntu/modular/third-party/llvm-project/llvm/lib/Support/Signals.cpp:105:20 #3 0x0000556d94d1aa42 SignalHandler(int) /home/ubuntu/modular/third-party/llvm-project/llvm/lib/Support/Unix/Signals.inc:371:36 #4 0x00007f1d88e42520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520) #5 0x00007f1d88e969fc __pthread_kill_implementation ./nptl/pthread_kill.c:44:76 #6 0x00007f1d88e969fc __pthread_kill_internal ./nptl/pthread_kill.c:78:10 #7 0x00007f1d88e969fc pthread_kill ./nptl/pthread_kill.c:89:10 #8 0x00007f1d88e42476 gsignal ./signal/../sysdeps/posix/raise.c:27:6 #9 0x00007f1d88e287f3 abort ./stdlib/abort.c:81:7 #10 0x0000556d8944be01 std::vector<llvm::json::Value, std::allocator<llvm::json::Value>>::size() const /usr/include/c++/11/bits/stl_vector.h:919:40 #11 0x0000556d8944be01 bool std::operator==<llvm::json::Value, std::allocator<llvm::json::Value>>(std::vector<llvm::json::Value, std::allocator<llvm::json::Value>> const&, std::vector<llvm::json::Value, std::allocator<llvm::json::Value>> const&) /usr/include/c++/11/bits/stl_vector.h:1893:23 #12 0x0000556d8944be01 llvm::json::operator==(llvm::json::Array const&, llvm::json::Array const&) /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/Support/JSON.h:572:69 #13 0x0000556d8944be01 llvm::json::operator==(llvm::json::Value const&, llvm::json::Value const&) (.cold) /home/ubuntu/modular/third-party/llvm-project/llvm/lib/Support/JSON.cpp:204:28 #14 0x0000556d949ed2bd llvm::report_fatal_error(char const*, bool) /home/ubuntu/modular/third-party/llvm-project/llvm/lib/Support/ErrorHandling.cpp:82:70 #15 0x0000556d8e37e876 llvm::SmallVectorBase<unsigned int>::size() const /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:91:32 #16 0x0000556d8e37e876 llvm::SmallVectorTemplateCommon<llvm::DiagnosticInfoOptimizationBase::Argument, void>::end() /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:282:41 #17 0x0000556d8e37e876 llvm::SmallVector<llvm::DiagnosticInfoOptimizationBase::Argument, 4u>::~SmallVector() /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/ADT/SmallVector.h:1215:24 #18 0x0000556d8e37e876 llvm::DiagnosticInfoOptimizationBase::~DiagnosticInfoOptimizationBase() /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/IR/DiagnosticInfo.h:413:7 #19 0x0000556d8e37e876 llvm::DiagnosticInfoIROptimization::~DiagnosticInfoIROptimization() /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/IR/DiagnosticInfo.h:622:7 #20 0x0000556d8e37e876 llvm::OptimizationRemark::~OptimizationRemark() /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/IR/DiagnosticInfo.h:689:7 #21 0x0000556d8e37e876 operator() /home/ubuntu/modular/third-party/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp:2213:14 #22 0x0000556d8e37e876 emit<llvm::CoroSplitPass::run(llvm::LazyCallGraph::SCC&, llvm::CGSCCAnalysisManager&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&)::<lambda()> > /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/Analysis/OptimizationRemarkEmitter.h:83:12 #23 0x0000556d8e37e876 llvm::CoroSplitPass::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) /home/ubuntu/modular/third-party/llvm-project/llvm/lib/Transforms/Coroutines/CoroSplit.cpp:2212:13 #24 0x0000556d8c36ecb1 llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::CoroSplitPass, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/IR/PassManagerInternal.h:91:3 #25 0x0000556d91c1a84f llvm::PassManager<llvm::LazyCallGraph::SCC, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) /home/ubuntu/modular/third-party/llvm-project/llvm/lib/Analysis/CGSCCPassManager.cpp:90:12 #26 0x0000556d8c3690d1 llvm::detail::PassModel<llvm::LazyCallGraph::SCC, llvm::PassManager<llvm::LazyCallGraph::SCC, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&>::run(llvm::LazyCallGraph::SCC&, llvm::AnalysisManager<llvm::LazyCallGraph::SCC, llvm::LazyCallGraph&>&, llvm::LazyCallGraph&, llvm::CGSCCUpdateResult&) /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/IR/PassManagerInternal.h:91:3 #27 0x0000556d91c2162d llvm::ModuleToPostOrderCGSCCPassAdaptor::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) /home/ubuntu/modular/third-party/llvm-project/llvm/lib/Analysis/CGSCCPassManager.cpp:278:18 #28 0x0000556d8c369035 llvm::detail::PassModel<llvm::Module, llvm::ModuleToPostOrderCGSCCPassAdaptor, llvm::AnalysisManager<llvm::Module>>::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/IR/PassManagerInternal.h:91:3 #29 0x0000556d9457abc5 llvm::PassManager<llvm::Module, llvm::AnalysisManager<llvm::Module>>::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/IR/PassManager.h:247:20 #30 0x0000556d8e30979e llvm::CoroConditionalWrapper::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) /home/ubuntu/modular/third-party/llvm-project/llvm/lib/Transforms/Coroutines/CoroConditionalWrapper.cpp:19:74 #31 0x0000556d8c365755 llvm::detail::PassModel<llvm::Module, llvm::CoroConditionalWrapper, llvm::AnalysisManager<llvm::Module>>::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/IR/PassManagerInternal.h:91:3 #32 0x0000556d9457abc5 llvm::PassManager<llvm::Module, llvm::AnalysisManager<llvm::Module>>::run(llvm::Module&, llvm::AnalysisManager<llvm::Module>&) /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/IR/PassManager.h:247:20 #33 0x0000556d89818556 llvm::SmallPtrSetImplBase::isSmall() const /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/ADT/SmallPtrSet.h:196:33 #34 0x0000556d89818556 llvm::SmallPtrSetImplBase::~SmallPtrSetImplBase() /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/ADT/SmallPtrSet.h:84:17 #35 0x0000556d89818556 llvm::SmallPtrSetImpl<llvm::AnalysisKey*>::~SmallPtrSetImpl() /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/ADT/SmallPtrSet.h:321:7 #36 0x0000556d89818556 llvm::SmallPtrSet<llvm::AnalysisKey*, 2u>::~SmallPtrSet() /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/ADT/SmallPtrSet.h:427:7 #37 0x0000556d89818556 llvm::PreservedAnalyses::~PreservedAnalyses() /home/ubuntu/modular/third-party/llvm-project/llvm/include/llvm/IR/Analysis.h:109:7 #38 0x0000556d89818556 llvm::runPassPipeline(llvm::StringRef, llvm::Module&, llvm::TargetMachine*, llvm::TargetLibraryInfoImpl*, llvm::ToolOutputFile*, llvm::ToolOutputFile*, llvm::ToolOutputFile*, llvm::StringRef, llvm::ArrayRef<llvm::PassPlugin>, llvm::ArrayRef<std::function<void (llvm::PassBuilder&)>>, llvm::opt_tool::OutputKind, llvm::opt_tool::VerifierKind, bool, bool, bool, bool, bool, bool, bool) /home/ubuntu/modular/third-party/llvm-project/llvm/tools/opt/NewPMDriver.cpp:532:10 #39 0x0000556d897e3939 optMain /home/ubuntu/modular/third-party/llvm-project/llvm/tools/opt/optdriver.cpp:737:27 #40 0x0000556d89455461 main /home/ubuntu/modular/third-party/llvm-project/llvm/tools/opt/opt.cpp:25:33 #41 0x00007f1d88e29d90 __libc_start_call_main ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16 #42 0x00007f1d88e29e40 call_init ./csu/../csu/libc-start.c:128:20 #43 0x00007f1d88e29e40 __libc_start_main ./csu/../csu/libc-start.c:379:5 #44 0x0000556d897b6335 _start (/home/ubuntu/modular/.derived/third-party/llvm-project/build-relwithdebinfo-asan/bin/opt+0x150c335) Aborted (core dumped)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This is a cherry-pick of https://reviews.llvm.org/D151117.
This patch adds support for
tosa.scatter
lowering in the--tosa-to-scf
pass. Here's an example for this lowering:translates to
We have attempted an alternative lowering pass that uses
tensor.scatter
as an intermediate step. However, we opted to aim straight at thescf
dialect for the following reasons:tensor.scatter
op doesn't seem to be used anywhere. There is no available lowering pass for this op (although we have one that we'll upstream soon).tosa.scatter
andtensor.scatter
op have different indexing semantics. Theindices
argument oftosa.scatter
must be non-trivially modified and restructured (e.g. with alinalg.generic
op) to adapt to the needs oftensor.scatter
. While this overhead may be simplified and fused after a subsequenttensor.scatter
lowering, it adds complex logic and an obscure intermediate state. Unless there is a good reason to go through thetensor
dialect that we're missing, this additional complexity may not be justified.Reviewed By: eric-k256
Differential Revision: https://reviews.llvm.org/D151117
DO NOT FILE A PULL REQUEST
This repository does not accept pull requests. Please follow http://llvm.org/docs/Contributing.html#how-to-submit-a-patch for contribution to LLVM.
DO NOT FILE A PULL REQUEST