From 3229cb329254764499dd672bb28fd9685ecd6a2e Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Tue, 12 Oct 2021 21:08:54 -0500 Subject: [PATCH] [LLVM] Treat scalars as single-lane vectors in CreateVecConcat (#9264) LLVM differentiates between `<1 x ty>` and `ty`, while TVM does not. Make sure that a bunch of TVM scalars can be concatenated into a vector when generating LLVM IR. --- src/target/llvm/codegen_llvm.cc | 14 ++++++++++++++ .../python/unittest/test_target_codegen_llvm.py | 16 ++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 12fbf2c3e42c..c94c5a685d1b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -626,6 +626,20 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { } llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { + // To allow creating vectors from scalars, convert any scalars in "vecs" to single-lane + // LLVM vector types. + for (size_t i = 0, e = vecs.size(); i != e; ++i) { + llvm::Value* v = vecs[i]; + if (!v->getType()->isVectorTy()) { +#if TVM_LLVM_VERSION >= 110 + llvm::Type* vec_ty = llvm::FixedVectorType::get(v->getType(), 1); +#else + llvm::Type* vec_ty = llvm::VectorType::get(v->getType(), 1); +#endif + vecs[i] = builder_->CreateInsertElement(llvm::UndefValue::get(vec_ty), v, ConstInt32(0)); + } + } + // concat vector, tree shape reduction int total_lanes = 0; diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 8c8d601672ac..5a1b33ae10b1 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -885,5 +885,21 @@ def check_llvm(use_file): check_llvm(use_file=False) +@tvm.testing.requires_llvm +def test_llvm_scalar_concat(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.decl_buffer((1,), "int32x2") + s = tvm.tir.Shuffle([x, y], [0, 1]) + f = tvm.tir.PrimFunc([x, y, z], z.vstore(0, s)) + + mod = tvm.ir.IRModule.from_expr(f.with_attr("global_symbol", "codegen_scalar_concat")) + + # This will crash in LLVM codegen if CodeGenLLVM::CreateVecConcat doesn't convert + # scalars to single-lane LLVM vectors. + with tvm.transform.PassContext(config={"tir.disable_assert": True}): + m = tvm.build(mod, [x, y, z], target="llvm") + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))