Skip to content

Commit

Permalink
[LLVM] Treat scalars as single-lane vectors in CreateVecConcat (#9264)
Browse files Browse the repository at this point in the history
LLVM differentiates between `<1 x ty>` and `ty`, while TVM does not.
Make sure that a bunch of TVM scalars can be concatenated into a
vector when generating LLVM IR.
  • Loading branch information
Krzysztof Parzyszek authored Oct 13, 2021
1 parent b5d863c commit 3229cb3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,20 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
}

llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
// To allow creating vectors from scalars, convert any scalars in "vecs" to single-lane
// LLVM vector types.
for (size_t i = 0, e = vecs.size(); i != e; ++i) {
llvm::Value* v = vecs[i];
if (!v->getType()->isVectorTy()) {
#if TVM_LLVM_VERSION >= 110
llvm::Type* vec_ty = llvm::FixedVectorType::get(v->getType(), 1);
#else
llvm::Type* vec_ty = llvm::VectorType::get(v->getType(), 1);
#endif
vecs[i] = builder_->CreateInsertElement(llvm::UndefValue::get(vec_ty), v, ConstInt32(0));
}
}

// concat vector, tree shape reduction
int total_lanes = 0;

Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,5 +885,21 @@ def check_llvm(use_file):
check_llvm(use_file=False)


@tvm.testing.requires_llvm
def test_llvm_scalar_concat():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.decl_buffer((1,), "int32x2")
s = tvm.tir.Shuffle([x, y], [0, 1])
f = tvm.tir.PrimFunc([x, y, z], z.vstore(0, s))

mod = tvm.ir.IRModule.from_expr(f.with_attr("global_symbol", "codegen_scalar_concat"))

# This will crash in LLVM codegen if CodeGenLLVM::CreateVecConcat doesn't convert
# scalars to single-lane LLVM vectors.
with tvm.transform.PassContext(config={"tir.disable_assert": True}):
m = tvm.build(mod, [x, y, z], target="llvm")


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 3229cb3

Please sign in to comment.