Skip to content

Commit

Permalink
[LLVM] Make changes needed for opaque pointers (#9138)
Browse files Browse the repository at this point in the history
* [LLVM] Make changes needed for opaque pointers

- Pass value type to all Create.*Load and Create.*GEP functions.
- Create type TypedPointer to keep both the address and the pointee's
  type when buffer pointers etc. are created.
- Eliminate calls to getPointerElementType, except one in creating
  debug info (that seems necessary for the time being).

* Fix typo in CodeGenCPU::CreateStructRefPtr

* Fix type extraction in CodeGenLLVM::AddAliasInfo

* Fix types in ramp-1 vector loads/stores

* Fix getting intrinsic name in error message

* Return valid pointer from PackClosureData when no data to pack
  • Loading branch information
Krzysztof Parzyszek authored Sep 29, 2021
1 parent 86ce111 commit df50fa3
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 145 deletions.
173 changes: 113 additions & 60 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ std::unique_ptr<llvm::Module> CodeGenCPU::Finish() {
}
return CodeGenLLVM::Finish();
}
llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index,
int kind) {

CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf,
llvm::Value* index, int kind) {
if (kind < builtin::kArrKindBound_) {
if (buf->getType() == t_void_p_) {
buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
Expand All @@ -257,57 +258,87 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::
}
switch (kind) {
case builtin::kArrAddr: {
return builder_->CreateInBoundsGEP(buf, index);
return TypedPointer(t_tvm_array_, builder_->CreateInBoundsGEP(t_tvm_array_, buf, index));
}
case builtin::kArrData: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(0)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(0);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(0)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrShape: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(4)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(4);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(4)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrStrides: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(5)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(5);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(5)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrNDim: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(2);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(2)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrTypeCode: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(0);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(0)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrTypeBits: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(1);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(1)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrTypeLanes: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(3)->getStructElementType(2);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(3), ConstInt32(2)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrByteOffset: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(6);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(6)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrDeviceId: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(1);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(1)});
return TypedPointer(member_type, member_addr);
}
case builtin::kArrDeviceType: {
return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)});
llvm::Type* member_type = t_tvm_array_->getStructElementType(1)->getStructElementType(0);
llvm::Value* member_addr =
builder_->CreateInBoundsGEP(t_tvm_array_, buf, {index, ConstInt32(1), ConstInt32(0)});
return TypedPointer(member_type, member_addr);
}
case builtin::kTVMValueContent: {
ICHECK_EQ(t.lanes(), 1);
ICHECK(t.is_handle() || t.bits() == 64);
if (t.is_int()) {
buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo());
return builder_->CreateInBoundsGEP(buf, index);
return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index));
} else if (t.is_float()) {
buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo());
return builder_->CreateInBoundsGEP(buf, index);
return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index));
} else {
ICHECK(t.is_handle());
buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo());
buf = builder_->CreateInBoundsGEP(buf, index);
return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo());
buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index);
return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()));
}
}
default:
LOG(FATAL) << "unknown field code";
return nullptr;
return TypedPointer();
}
}

Expand Down Expand Up @@ -373,7 +404,10 @@ llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string
llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) {
ICHECK(gv != nullptr);
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, llvm::Align(gv->getAlignment()));
llvm::LoadInst* faddr =
builder_->CreateAlignedLoad(gv->getValueType(), gv, llvm::Align(gv->getAlignment()));
#elif TVM_LLVM_VERSION >= 80
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv->getValueType(), gv, gv->getAlignment());
#else
llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment());
#endif
Expand Down Expand Up @@ -490,34 +524,36 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
builder_->SetInsertPoint(compute_call_end);
}

llvm::Value* CodeGenCPU::PackClosureData(const Array<Var>& vfields, uint64_t* num_bytes) {
CodeGenLLVM::TypedPointer CodeGenCPU::PackClosureData(const Array<Var>& vfields,
uint64_t* num_bytes) {
if (vfields.size() == 0) {
*num_bytes = 0U;
return llvm::Constant::getNullValue(t_void_p_);
return TypedPointer(t_void_p_, llvm::Constant::getNullValue(t_void_p_));
}
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
ICHECK(it != var_map_.end());
fields.push_back(it->second->getType());
}
llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::StructType* ctype = llvm::StructType::create(fields);
llvm::Value* cvalue = builder_->CreateAlloca(ctype, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);
for (size_t i = 0; i < vfields.size(); ++i) {
builder_->CreateStore(var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
builder_->CreateInBoundsGEP(ctype, cvalue, {zero, ConstInt32(i)}));
}
*num_bytes = data_layout_->getTypeAllocSize(
llvm::cast<llvm::PointerType>(cdata->getType())->getElementType());
return cdata;
*num_bytes = data_layout_->getTypeAllocSize(ctype);
return TypedPointer(ctype, cvalue);
}

void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, const Array<Var>& vfields,
void CodeGenCPU::UnpackClosureData(TypedPointer cdata, const Array<Var>& vfields,
std::unordered_map<const VarNode*, llvm::Value*>* vmap) {
for (size_t i = 0; i < vfields.size(); ++i) {
(*vmap)[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(cdata, {ConstInt32(0), ConstInt32(i)}));
llvm::Type* field_type = cdata.type->getStructElementType(i);
llvm::Value* field_addr =
builder_->CreateInBoundsGEP(cdata.type, cdata.addr, {ConstInt32(0), ConstInt32(i)});
(*vmap)[vfields[i].get()] = builder_->CreateLoad(field_type, field_addr);
}
}

Expand All @@ -530,21 +566,22 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
// allocate and setup the closure, call the closure.
Array<Var> vfields = tir::UndefinedVars(body, {});
uint64_t nbytes;
llvm::Value* cdata = PackClosureData(vfields, &nbytes);
TypedPointer cdata = PackClosureData(vfields, &nbytes);
#if TVM_LLVM_VERSION >= 90
auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch());
#else
auto launch_callee = RuntimeTVMParallelLaunch();
#endif
BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall(
launch_callee, {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)}));
launch_callee,
{f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(num_task)}));
// Setup the closure function.
BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
llvm::Value* task_id = &(*it++);
llvm::Value* penv = &(*it++);
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
Expand All @@ -553,8 +590,9 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) {
par_env.task_id = Var("task_id", DataType::Int(32));
par_env.num_task = Var("num_task", DataType::Int(32));
new_vmap[par_env.task_id.get()] = task_id;
new_vmap[par_env.num_task.get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(penv, {ConstInt32(0), ConstInt32(1)}));
new_vmap[par_env.num_task.get()] = builder_->CreateLoad(
t_int32_,
builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)}));
par_env.penv = penv;
auto new_analyzer = std::make_unique<arith::Analyzer>();
std::swap(function_, f);
Expand Down Expand Up @@ -600,14 +638,14 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod
// allocate and setup the closure, call the closure.
uint64_t nbytes;
Array<Var> vfields = tir::UndefinedVars(body, {});
llvm::Value* cdata = PackClosureData(vfields, &nbytes);
TypedPointer cdata = PackClosureData(vfields, &nbytes);
BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall(
finit, {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)}));
finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)}));
// Setup the closure function.
BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
cdata = builder_->CreatePointerCast(&(*it++), cdata->getType());
cdata.addr = builder_->CreatePointerCast(&(*it++), cdata.addr->getType());
// setup new variable map, swap it with current var context.
std::unordered_map<const VarNode*, llvm::Value*> new_vmap;
UnpackClosureData(cdata, vfields, &new_vmap);
Expand Down Expand Up @@ -655,7 +693,9 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_);
BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_);
#if TVM_LLVM_VERSION >= 110
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align));
llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, llvm::Align(align));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* handle = builder_->CreateAlignedLoad(hptr->getValueType(), hptr, align);
#else
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
#endif
Expand All @@ -667,8 +707,11 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
llvm::Value* out =
WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); });
#if TVM_LLVM_VERSION >= 110
llvm::LoadInst* ctx =
builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment()));
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_,
llvm::Align(gv_mod_ctx_->getAlignment()));
#elif TVM_LLVM_VERSION >= 80
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_->getValueType(), gv_mod_ctx_,
gv_mod_ctx_->getAlignment());
#else
llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment());
#endif
Expand All @@ -682,7 +725,10 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) {
llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out});
init_block = CheckCallSuccess(retcode);
#if TVM_LLVM_VERSION >= 110
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align));
llvm::Value* loaded_handle =
builder_->CreateAlignedLoad(t_tvm_func_handle_, out, llvm::Align(align));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(t_tvm_func_handle_, out, align);
#else
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
#endif
Expand All @@ -709,37 +755,44 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
llvm::Value* stack_value = MakeValue(args[1]);
llvm::Value* stack_tcode = MakeValue(args[2]);
llvm::Value* arg_value = builder_->CreateInBoundsGEP(
builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin));
llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(begin));
TypedPointer arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
llvm::Value* ret_value = builder_->CreateInBoundsGEP(
builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end));
llvm::Value* ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(end));
TypedPointer ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));

#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
#else
auto call_callee = RuntimeTVMFuncCall();
#endif
llvm::Value* call = builder_->CreateCall(
call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, ret_tcode});
call_callee,
{handle, arg_value, arg_tcode.addr, ConstInt32(nargs), ret_value, ret_tcode.addr});
llvm::BasicBlock* end_block = CheckCallSuccess(call);

// Load the return value and cast it to the designated type (r_type).
DataType r_api_type = tir::APIType(r_type);
llvm::Value* load_ptr =
builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo());
llvm::Type* llvm_r_api_type = DTypeToLLVMType(r_api_type);
llvm::Value* load_ptr = builder_->CreatePointerCast(ret_value, llvm_r_api_type->getPointerTo());
#if TVM_LLVM_VERSION >= 110
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8));
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, llvm::Align(8));
#elif TVM_LLVM_VERSION >= 80
llvm::Value* rvalue = builder_->CreateAlignedLoad(llvm_r_api_type, load_ptr, 8);
#else
llvm::Value* rvalue = builder_->CreateAlignedLoad(load_ptr, 8);
#endif
pc.ret_value = CreateCast(r_api_type, r_type, rvalue);

// Load the return type code.
#if TVM_LLVM_VERSION >= 110
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8));
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, llvm::Align(8));
#elif TVM_LLVM_VERSION >= 80
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.type, ret_tcode.addr, 8);
#else
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode, 8);
pc.ret_tcode = builder_->CreateAlignedLoad(ret_tcode.addr, 8);
#endif

pc.end_block = end_block;
Expand Down Expand Up @@ -882,24 +935,24 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
} else if (op->op.same_as(builtin::tvm_struct_get())) {
ICHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* ref =
this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind);
TypedPointer ref =
CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind);
if (kind == builtin::kArrAddr) {
return builder_->CreatePointerCast(ref, t_void_p_);
return builder_->CreatePointerCast(ref.addr, t_void_p_);
} else {
return builder_->CreateLoad(ref);
return builder_->CreateLoad(ref.type, ref.addr);
}
} else if (op->op.same_as(builtin::tvm_struct_set())) {
ICHECK_EQ(op->args.size(), 4U);
int kind = op->args[2].as<IntImmNode>()->value;
llvm::Value* value = MakeValue(op->args[3]);
llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
TypedPointer ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]),
MakeValue(op->args[1]), kind);
ICHECK(kind != builtin::kArrAddr);
if (value->getType()->isPointerTy()) {
value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType());
value = builder_->CreatePointerCast(value, ref.type);
}
builder_->CreateStore(value, ref);
builder_->CreateStore(value, ref.addr);
return ConstInt32(0);
} else if (op->op.same_as(builtin::tvm_stack_alloca())) {
ICHECK_EQ(op->args.size(), 2U);
Expand Down
6 changes: 3 additions & 3 deletions src/target/llvm/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::Value* RuntimeTVMParallelBarrier();
llvm::Value* CreateStaticHandle();
llvm::Value* GetPackedFuncHandle(const std::string& str);
llvm::Value* PackClosureData(const Array<Var>& fields, uint64_t* num_bytes);
llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind);
void UnpackClosureData(llvm::Value* cdata, const Array<Var>& fields,
TypedPointer PackClosureData(const Array<Var>& fields, uint64_t* num_bytes);
TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind);
void UnpackClosureData(TypedPointer cdata, const Array<Var>& fields,
std::unordered_map<const VarNode*, llvm::Value*>* vmap);
// Make packed call.
struct PackedCall {
Expand Down
Loading

0 comments on commit df50fa3

Please sign in to comment.