From df50fa3dcf0ba7993944389c2b6a5724b0f77730 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Wed, 29 Sep 2021 11:59:33 -0500 Subject: [PATCH] [LLVM] Make changes needed for opaque pointers (#9138) * [LLVM] Make changes needed for opaque pointers - Pass value type to all Create.*Load and Create.*GEP functions. - Create type TypedPointer to keep both the address and the pointee's type when buffer pointers etc. are created. - Eliminate calls to getPointerElementType, except one in creating debug info (that seems necessary for the time being). * Fix typo in CodeGenCPU::CreateStructRefPtr * Fix type extraction in CodeGenLLVM::AddAliasInfo * Fix types in ramp-1 vector loads/stores * Fix getting intrinsic name in error message * Return valid pointer from PackClosureData when no data to pack --- src/target/llvm/codegen_cpu.cc | 173 +++++++++++++++++++---------- src/target/llvm/codegen_cpu.h | 6 +- src/target/llvm/codegen_hexagon.cc | 92 +++++++++------ src/target/llvm/codegen_llvm.cc | 122 ++++++++++++-------- src/target/llvm/codegen_llvm.h | 11 +- 5 files changed, 259 insertions(+), 145 deletions(-) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index c98c23ae8c61..466f85393b1b 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -246,8 +246,9 @@ std::unique_ptr CodeGenCPU::Finish() { } return CodeGenLLVM::Finish(); } -llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, - int kind) { + +CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, + llvm::Value* index, int kind) { if (kind < builtin::kArrKindBound_) { if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); @@ -257,57 +258,87 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm:: } switch (kind) { case builtin::kArrAddr: { - return builder_->CreateInBoundsGEP(buf, index); + return 
TypedPointer(t_tvm_array_, builder_->CreateInBoundsGEP(t_tvm_array_, buf, index)); } case builtin::kArrData: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(0); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(0)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrShape: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(4); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(4)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrStrides: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(5); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(5)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrNDim: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(2); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(2)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrTypeCode: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(0); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(0)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrTypeBits: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(1); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, 
{index, ConstInt32(3), ConstInt32(1)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrTypeLanes: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(2); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(2)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrByteOffset: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(6); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(6)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrDeviceId: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(1); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(1)}); + return TypedPointer(member_type, member_addr); } case builtin::kArrDeviceType: { - return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)}); + llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(0); + llvm::Value* member_addr = + builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(0)}); + return TypedPointer(member_type, member_addr); } case builtin::kTVMValueContent: { ICHECK_EQ(t.lanes(), 1); ICHECK(t.is_handle() || t.bits() == 64); if (t.is_int()) { buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); } else if (t.is_float()) { buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + 
return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); } else { ICHECK(t.is_handle()); buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); - buf = builder_->CreateInBoundsGEP(buf, index); - return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()); + buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); + return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); } } default: LOG(FATAL) << "unknown field code"; - return nullptr; + return TypedPointer(); } } @@ -373,7 +404,10 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { ICHECK(gv != nullptr); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, llvm::Align(gv->getAlignment())); + llvm::LoadInst* faddr = + builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment()); #else llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment()); #endif @@ -490,10 +524,11 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { builder_->SetInsertPoint(compute_call_end); } -llvm::Value* CodeGenCPU::PackClosureData(const Array& vfields, uint64_t* num_bytes) { +CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array& vfields, + uint64_t* num_bytes) { if (vfields.size() == 0) { *num_bytes = 0U; - return llvm::Constant::getNullValue(t_void_p_); + return TypedPointer(t_void_p_, llvm::Constant::getNullValue(t_void_p_)); } std::vector fields; for (Var v : vfields) { @@ -501,23 +536,24 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array& vfields, uint64_t* nu ICHECK(it != var_map_.end()); fields.push_back(it->second->getType()); } - llvm::StructType* tcdata = llvm::StructType::create(fields); - llvm::Value* cdata = 
builder_->CreateAlloca(tcdata, ConstInt32(1)); + llvm::StructType* ctype = llvm::StructType::create(fields); + llvm::Value* cvalue = builder_->CreateAlloca(ctype, ConstInt32(1)); llvm::Value* zero = ConstInt32(0); for (size_t i = 0; i < vfields.size(); ++i) { builder_->CreateStore(var_map_.at(vfields[i].get()), - builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); + builder_->CreateInBoundsGEP(ctype, cvalue, {zero, ConstInt32(i)})); } - *num_bytes = data_layout_->getTypeAllocSize( - llvm::cast(cdata->getType())->getElementType()); - return cdata; + *num_bytes = data_layout_->getTypeAllocSize(ctype); + return TypedPointer(ctype, cvalue); } -void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, const Array& vfields, +void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const Array& vfields, std::unordered_map* vmap) { for (size_t i = 0; i < vfields.size(); ++i) { - (*vmap)[vfields[i].get()] = - builder_->CreateLoad(builder_->CreateInBoundsGEP(cdata, {ConstInt32(0), ConstInt32(i)})); + llvm::Type* field_type = cdata.type->getStructElementType(i); + llvm::Value* field_addr = + builder_->CreateInBoundsGEP(cdata.type, cdata.addr, {ConstInt32(0), ConstInt32(i)}); + (*vmap)[vfields[i].get()] = builder_->CreateLoad(field_type, field_addr); } } @@ -530,21 +566,22 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { // allocate and setup the closure, call the closure. 
Array vfields = tir::UndefinedVars(body, {}); uint64_t nbytes; - llvm::Value* cdata = PackClosureData(vfields, &nbytes); + TypedPointer cdata = PackClosureData(vfields, &nbytes); #if TVM_LLVM_VERSION >= 90 auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch()); #else auto launch_callee = RuntimeTVMParallelLaunch(); #endif BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall( - launch_callee, {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)})); + launch_callee, + {f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)})); // Setup the closure function. BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); llvm::Value* task_id = &(*it++); llvm::Value* penv = &(*it++); - cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); + cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType()); // setup new variable map, swap it with current var context. std::unordered_map new_vmap; UnpackClosureData(cdata, vfields, &new_vmap); @@ -553,8 +590,9 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { par_env.task_id = Var("task_id", DataType::Int(32)); par_env.num_task = Var("num_task", DataType::Int(32)); new_vmap[par_env.task_id.get()] = task_id; - new_vmap[par_env.num_task.get()] = - builder_->CreateLoad(builder_->CreateInBoundsGEP(penv, {ConstInt32(0), ConstInt32(1)})); + new_vmap[par_env.num_task.get()] = builder_->CreateLoad( + t_int32_, + builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)})); par_env.penv = penv; auto new_analyzer = std::make_unique(); std::swap(function_, f); @@ -600,14 +638,14 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod // allocate and setup the closure, call the closure. 
uint64_t nbytes; Array vfields = tir::UndefinedVars(body, {}); - llvm::Value* cdata = PackClosureData(vfields, &nbytes); + TypedPointer cdata = PackClosureData(vfields, &nbytes); BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( - finit, {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)})); + finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)})); // Setup the closure function. BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); - cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); + cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType()); // setup new variable map, swap it with current var context. std::unordered_map new_vmap; UnpackClosureData(cdata, vfields, &new_vmap); @@ -655,7 +693,9 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_); BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 - llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align)); + llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, align); #else llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align); #endif @@ -667,8 +707,11 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { llvm::Value* out = WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = - builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + llvm::Align(gv_mod_ctx_->getAlignment())); +#elif 
TVM_LLVM_VERSION >= 80 + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + gv_mod_ctx_->getAlignment()); #else llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif @@ -682,7 +725,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 - llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align)); + llvm::Value* loaded_handle = + builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align); #else llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align); #endif @@ -709,11 +755,13 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& llvm::Value* stack_value = MakeValue(args[1]); llvm::Value* stack_tcode = MakeValue(args[2]); llvm::Value* arg_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); - llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(begin)); + TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - llvm::Value* ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(end)); + TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); #if 
TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); @@ -721,15 +769,18 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& auto call_callee = RuntimeTVMFuncCall(); #endif llvm::Value* call = builder_->CreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, ret_tcode}); + call_callee, + {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr}); llvm::BasicBlock* end_block = CheckCallSuccess(call); // Load the return value and cast it to the designated type (r_type). DataType r_api_type = tir::APIType(r_type); - llvm::Value* load_ptr = - builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); + llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type); + llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo()); #if TVM_LLVM_VERSION >= 110 - llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8); #else llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); #endif @@ -737,9 +788,11 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array& // Load the return type code. 
#if TVM_LLVM_VERSION >= 110 - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8); #else - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, 8); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8); #endif pc.end_block = end_block; @@ -882,24 +935,24 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_struct_get())) { ICHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; - llvm::Value* ref = - this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); + TypedPointer ref = + CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == builtin::kArrAddr) { - return builder_->CreatePointerCast(ref, t_void_p_); + return builder_->CreatePointerCast(ref.addr, t_void_p_); } else { - return builder_->CreateLoad(ref); + return builder_->CreateLoad(ref.type, ref.addr); } } else if (op->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; llvm::Value* value = MakeValue(op->args[3]); - llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), - MakeValue(op->args[1]), kind); + TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), + MakeValue(op->args[1]), kind); ICHECK(kind != builtin::kArrAddr); if (value->getType()->isPointerTy()) { - value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); + value = builder_->CreatePointerCast(value, ref.type); } - builder_->CreateStore(value, ref); + builder_->CreateStore(value, ref.addr); return ConstInt32(0); } else if (op->op.same_as(builtin::tvm_stack_alloca())) { ICHECK_EQ(op->args.size(), 2U); diff --git 
a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 30e61ea63f12..402189eb374d 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -105,9 +105,9 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* RuntimeTVMParallelBarrier(); llvm::Value* CreateStaticHandle(); llvm::Value* GetPackedFuncHandle(const std::string& str); - llvm::Value* PackClosureData(const Array& fields, uint64_t* num_bytes); - llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); - void UnpackClosureData(llvm::Value* cdata, const Array& fields, + TypedPointer PackClosureData(const Array& fields, uint64_t* num_bytes); + TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); + void UnpackClosureData(TypedPointer cdata, const Array& fields, std::unordered_map* vmap); // Make packed call. struct PackedCall { diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index d8a64102f9cd..bffb620d49f9 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -75,7 +75,7 @@ class CodeGenHexagon final : public CodeGenLLVM { llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr}; private: - llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind); + TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind); // Check if the call to packed function is successful // if not directly finalize function and pass on return code. 
@@ -255,7 +255,10 @@ llvm::GlobalVariable* CodeGenHexagon::InitContextPtr(llvm::Type* p_type, std::st llvm::Value* CodeGenHexagon::GetContextPtr(llvm::GlobalVariable* gv) { ICHECK(gv != nullptr); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, llvm::Align(gv->getAlignment())); + llvm::LoadInst* faddr = + builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment()); #else llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment()); #endif @@ -313,11 +316,13 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const ArrayCreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); - llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(begin)); + TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); llvm::Value* ret_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); - llvm::Value* ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); + t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), + ConstInt32(end)); + TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); @@ -325,15 +330,18 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const ArrayCreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, ret_tcode}); + call_callee, + {handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, 
ret_tcode.addr}); llvm::BasicBlock* end_block = CheckCallSuccess(call); // Load the return value and cast it to the designated type (r_type). DataType r_api_type = tir::APIType(r_type); - llvm::Value* load_ptr = - builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); + llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type); + llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo()); #if TVM_LLVM_VERSION >= 110 - llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8); #else llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8); #endif @@ -341,9 +349,11 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const Array= 110 - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8)); +#elif TVM_LLVM_VERSION >= 80 + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8); #else - pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, 8); + pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8); #endif pc.end_block = end_block; @@ -380,7 +390,9 @@ llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_); BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 - llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align)); + llvm::Value* handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, hptr, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, hptr, align); #else llvm::Value* 
handle = builder_->CreateAlignedLoad(hptr, align); #endif @@ -392,8 +404,11 @@ llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { llvm::Value* out = WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = - builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + llvm::Align(gv_mod_ctx_->getAlignment())); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_, + gv_mod_ctx_->getAlignment()); #else llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif @@ -407,7 +422,10 @@ llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) { llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 - llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align)); + llvm::Value* loaded_handle = + builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align)); +#elif TVM_LLVM_VERSION >= 80 + llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align); #else llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align); #endif @@ -514,23 +532,23 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::tvm_struct_get())) { ICHECK_EQ(op->args.size(), 3); int kind = op->args[2].as()->value; - llvm::Value* ref = + TypedPointer ref = CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == builtin::kArrAddr) { - return builder_->CreatePointerCast(ref, t_void_p_); + return builder_->CreatePointerCast(ref.addr, t_void_p_); } - return builder_->CreateLoad(ref); + return 
builder_->CreateLoad(ref.type, ref.addr); } else if (op->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(op->args.size(), 4); int kind = op->args[2].as()->value; ICHECK(kind != builtin::kArrAddr); - llvm::Value* ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), + TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), MakeValue(op->args[1]), kind); llvm::Value* value = MakeValue(op->args[3]); if (value->getType()->isPointerTy()) { - value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); + value = builder_->CreatePointerCast(value, ref.type); } - builder_->CreateStore(value, ref); + builder_->CreateStore(value, ref.addr); return ConstInt32(0); } else if (op->op.same_as(builtin::tvm_stack_alloca())) { ICHECK_EQ(op->args.size(), 2); @@ -549,8 +567,8 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) { return CodeGenLLVM::CreateIntrinsic(op); } -llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, - int kind) { +CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, + llvm::Value* index, int kind) { static const std::map field_index = { {builtin::kArrData, 0}, {builtin::kArrDeviceType, 1}, {builtin::kArrDeviceId, 1}, {builtin::kArrNDim, 2}, {builtin::kArrTypeCode, 3}, {builtin::kArrTypeBits, 3}, @@ -581,12 +599,13 @@ llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, ll uint64_t byte_offset; kArrByteOffset } DLTensor; */ - llvm::Value* base_gep = builder_->CreateInBoundsGEP(buf, index, "base_gep"); + llvm::Value* base_gep = builder_->CreateInBoundsGEP(t_tvm_array_, buf, index, "base_gep"); if (kind == builtin::kArrAddr) { - return base_gep; + return TypedPointer(t_void_p_, base_gep); } llvm::Value* field_gep = builder_->CreateInBoundsGEP( - base_gep, {ConstInt32(0), ConstInt32(field_index.at(kind))}, "field_gep"); + t_tvm_array_, base_gep, {ConstInt32(0), 
ConstInt32(field_index.at(kind))}, "field_gep"); + llvm::Type* field_type = t_tvm_array_->getStructElementType(field_index.at(kind)); switch (kind) { // These fields have no sub-fields. case builtin::kArrData: @@ -594,10 +613,13 @@ llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, ll case builtin::kArrShape: case builtin::kArrStrides: case builtin::kArrByteOffset: - return field_gep; + return TypedPointer(field_type, field_gep); } - return builder_->CreateInBoundsGEP( - field_gep, {ConstInt32(0), ConstInt32(subfield_index.at(kind))}, "subfield_gep"); + llvm::Value* subfield_gep = builder_->CreateInBoundsGEP( + field_type, field_gep, {ConstInt32(0), ConstInt32(subfield_index.at(kind))}, + "subfield_gep"); + llvm::Type* subfield_type = field_type->getStructElementType(subfield_index.at(kind)); + return TypedPointer(subfield_type, subfield_gep); } if (kind == builtin::kTVMValueContent) { @@ -615,20 +637,20 @@ llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, ll ICHECK(t.is_handle() || t.bits() == 64); if (t.is_int()) { buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); } else if (t.is_float()) { buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); - return builder_->CreateInBoundsGEP(buf, index); + return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); } else { ICHECK(t.is_handle()); buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); - buf = builder_->CreateInBoundsGEP(buf, index); - return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()); + buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); + return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); } } assert(!"Unknown kind"); - return nullptr; + return TypedPointer(); } namespace { diff --git
a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6aabdc1bd804..12fbf2c3e42c 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -473,9 +473,16 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, P meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta); // Extract the underlying type of the allocated buffer. - llvm::Type* buf_type = GetVarValue(buffer)->getType()->getScalarType(); - if (buf_type->isPointerTy()) { - buf_type = buf_type->getPointerElementType(); + DataType dtype = buffer->dtype; + if (buffer->type_annotation.defined()) { + Type element_type = Downcast(buffer->type_annotation)->element_type; + if (auto* ptype = element_type.as()) { + dtype = ptype->dtype; + } + } + llvm::Type* buf_type = DTypeToLLVMType(dtype); + if (!buf_type) { + buf_type = t_void_p_; } std::string tmp; @@ -737,14 +744,17 @@ llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { return ptr; } -llvm::Value* CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index) { +CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, + llvm::Value* index) { llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); ICHECK(btype != nullptr); - llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace()); - if (btype != ptype) { - buffer = builder_->CreatePointerCast(buffer, ptype); + llvm::Type* llvm_type = DTypeToLLVMType(t); + llvm::PointerType* ttype = llvm_type->getPointerTo(btype->getAddressSpace()); + if (btype != ttype) { + buffer = builder_->CreatePointerCast(buffer, ttype); } - return builder_->CreateInBoundsGEP(buffer, index); + llvm::Value* ptr = builder_->CreateInBoundsGEP(llvm_type, buffer, index); + return TypedPointer(llvm_type, ptr); } llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { @@ -861,10 +871,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { 
: llvm::Type::getVoidTy(*ctx_); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " -#if TVM_LLVM_VERSION <= 130 - << llvm::Intrinsic::getName(id, {}); +#if TVM_LLVM_VERSION >= 130 + << llvm::Intrinsic::getBaseName(id).str(); #else - << llvm::Intrinsic::getName(id, return_type, {}); + << llvm::Intrinsic::getName(id, {}); #endif return builder_->CreateCall(f, arg_value); } else if (op->op.same_as(builtin::bitwise_and())) { @@ -888,18 +898,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::address_of())) { const LoadNode* l = op->args[0].as(); ICHECK(op->args.size() == 1 && l); - const RampNode* r = l->index.as(); - llvm::Value* ptr; - unsigned addrspace; - if (!r) { - ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); - addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); - } else { + TypedPointer buffer_ptr; + if (const RampNode* r = l->index.as()) { PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); - ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); - addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); + buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); + } else { + buffer_ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); } - return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace)); + unsigned addrspace = + llvm::dyn_cast(buffer_ptr.addr->getType())->getAddressSpace(); + return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace)); } else if (op->op.same_as(builtin::reinterpret()) && is_zero(op->args[0])) { return llvm::Constant::getNullValue(t_void_p_); } else if (op->op.same_as(builtin::isnullptr())) { @@ -1154,29 +1162,40 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { if (t.lanes() == 1) { int alignment, 
native_bits; GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); - llvm::Value* ptr = CreateBufferPtr(t, buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* load = + builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(load, op->buffer_var.get(), op->index); return load; } else { // vector load - unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); ICHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); + // The index argument is element-based, to create buffer pointer for t's element type. 
+ TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + unsigned addrspace = + llvm::dyn_cast(buffer->getType())->getAddressSpace(); + buffer_ptr.type = DTypeToLLVMType(t); + buffer_ptr.addr = + builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); +#elif TVM_LLVM_VERSION >= 80 llvm::LoadInst* load = - builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); + builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, alignment, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(load, op->buffer_var.get(), op->index); return load; @@ -1187,11 +1206,15 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { int basic_align = t.bits() / 8; llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t)); auto f = [&](int i, llvm::Value* index) { - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(basic_align), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, + llvm::Align(basic_align), is_volatile); +#elif TVM_LLVM_VERSION >= 80 + llvm::LoadInst* load = + builder_->CreateAlignedLoad(buffer_ptr.type, buffer_ptr.addr, basic_align, is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, basic_align, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(buffer_ptr.addr, basic_align, is_volatile); #endif ret = builder_->CreateInsertElement(ret, load, ConstInt32(i)); AddAliasInfo(load, 
op->buffer_var.get(), PrimExpr()); @@ -1271,30 +1294,36 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { if (t.lanes() == 1) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); - llvm::Value* ptr = CreateBufferPtr(t, buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = - builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile); + builder_->CreateAlignedStore(value, buffer_ptr.addr, llvm::Align(alignment), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); + llvm::StoreInst* store = + builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), op->index); return; } else { // vector store - unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); ICHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); + // The index argument is element-based, to create buffer pointer for t's element type. 
+ TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + unsigned addrspace = + llvm::dyn_cast(buffer->getType())->getAddressSpace(); + buffer_ptr.type = DTypeToLLVMType(t); + buffer_ptr.addr = + builder_->CreatePointerCast(buffer_ptr.addr, buffer_ptr.type->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = - builder_->CreateAlignedStore(value, ptr, llvm::Align(alignment), is_volatile); + llvm::StoreInst* store = builder_->CreateAlignedStore(value, buffer_ptr.addr, + llvm::Align(alignment), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile); + llvm::StoreInst* store = + builder_->CreateAlignedStore(value, buffer_ptr.addr, alignment, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), op->index); return; @@ -1305,13 +1334,14 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { // scalarized store. int basic_align = t.bits() / 8; auto f = [&](int i, llvm::Value* index) { - llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); + TypedPointer buffer_ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::StoreInst* store = builder_->CreateAlignedStore( - builder_->CreateExtractElement(value, i), ptr, llvm::Align(basic_align), is_volatile); + llvm::StoreInst* store = + builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), buffer_ptr.addr, + llvm::Align(basic_align), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), - ptr, basic_align, is_volatile); + llvm::StoreInst* store = builder_->CreateAlignedStore( + builder_->CreateExtractElement(value, i), buffer_ptr.addr, basic_align, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), PrimExpr()); }; diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index a4f007aeebed..177b53056354 100644 --- 
a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -181,6 +181,15 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; protected: + /*! + * \brief Address and type pair to assist in handling opaque pointers. + */ + struct TypedPointer { + TypedPointer() = default; + TypedPointer(llvm::Type* t, llvm::Value* a) : type(t), addr(a) {} + llvm::Type* type = nullptr; /*!< Type of the value pointed to. */ + llvm::Value* addr = nullptr; /*!< Address of the value. */ + }; /*! \brief The storage information */ struct StorageInfo { /*! \brief The alignment of allocation */ @@ -301,7 +310,7 @@ class CodeGenLLVM : public ExprFunctor, llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b); llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); - llvm::Value* CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index); + TypedPointer CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index); // Vector concatenation. llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent); llvm::Value* CreateVecFlip(llvm::Value* vec);