Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hexagon] Codegen for 2d Load/Store #10586

Merged
merged 8 commits into from
Mar 15, 2022
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -815,12 +815,12 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(begin));
TypedPointer arg_tcode =
CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32));
CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(begin)}, DataType::Int(32));
llvm::Value* ret_value = builder_->CreateInBoundsGEP(
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(end));
TypedPointer ret_tcode =
CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32));
CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32));

#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
Expand Down
31 changes: 29 additions & 2 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class CodeGenHexagon final : public CodeGenLLVM {
llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};

private:
TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype) final;
TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind);

// Check if the call to packed function is successful
Expand Down Expand Up @@ -320,12 +322,12 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const Array<Pri
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(begin));
TypedPointer arg_tcode =
CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32));
CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(begin)}, DataType::Int(32));
llvm::Value* ret_value = builder_->CreateInBoundsGEP(
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(end));
TypedPointer ret_tcode =
CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32));
CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32));

#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
Expand Down Expand Up @@ -570,6 +572,31 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) {
return CodeGenLLVM::CreateIntrinsic(op);
}

/*!
 * \brief Build the typed pointer used to access a 1-d or 2-d physical buffer.
 *
 * A single index is handed straight to CodeGenLLVM::CreateBufferPtr.
 * With two indices, the first index addresses a handle-typed slot
 * relative to buffer_ptr; the pointer stored in that slot is loaded
 * and the second index is then resolved relative to the loaded
 * pointer by CodeGenLLVM::CreateBufferPtr.
 *
 * \param buffer_ptr The pointer to the start of the buffer.
 *
 * \param buffer_element_dtype The element type of the buffer.
 *
 * \param indices The indices at which the buffer is being accessed.
 * Must contain exactly one or two entries.
 *
 * \param value_dtype The datatype to be read from or written to the
 * buffer.
 *
 * \return The typed pointer for the requested element.
 */
CodeGenLLVM::TypedPointer CodeGenHexagon::CreateBufferPtr(llvm::Value* buffer_ptr,
                                                          DataType buffer_element_dtype,
                                                          llvm::ArrayRef<llvm::Value*> indices,
                                                          DataType value_dtype) {
  if (indices.size() != 1) {
    ICHECK_EQ(indices.size(), 2) << "CodegenHexagon supports 1-d and 2-d physical buffers, received "
                                 << indices.size() << "-d buffer indices";

    llvm::Value* chunk_index = indices[0];
    llvm::Value* index_within_chunk = indices[1];

    // Step 1: address the handle-typed slot selected by the first
    // index, and load the pointer stored there.
    const DataType handle_dtype = DataType::Handle();
    const auto chunk_slot =
        CodeGenLLVM::CreateBufferPtr(buffer_ptr, handle_dtype, {chunk_index}, handle_dtype);
    llvm::Value* chunk_base = builder_->CreateLoad(chunk_slot.type, chunk_slot.addr);

    // Step 2: the second index is an ordinary flat access relative to
    // the loaded pointer, so the base class handles it.
    return CodeGenLLVM::CreateBufferPtr(chunk_base, buffer_element_dtype, {index_within_chunk},
                                        value_dtype);
  }

  // Flat (1-d) access needs no Hexagon-specific handling.
  return CodeGenLLVM::CreateBufferPtr(buffer_ptr, buffer_element_dtype, indices, value_dtype);
}

CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf,
llvm::Value* index, int kind) {
static const std::map<int, int> field_index = {
Expand Down
86 changes: 57 additions & 29 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,11 @@ llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) {

CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr,
DataType buffer_element_dtype,
llvm::Value* index, DataType value_dtype) {
llvm::ArrayRef<llvm::Value*> indices,
DataType value_dtype) {
ICHECK_EQ(indices.size(), 1) << "CodeGenLLVM requires all buffers to be flat 1-d buffers.";
llvm::Value* index = indices[0];

llvm::PointerType* buffer_ptr_type = llvm::dyn_cast<llvm::PointerType>(buffer_ptr->getType());
ICHECK(buffer_ptr_type != nullptr);
auto address_space = buffer_ptr_type->getAddressSpace();
Expand Down Expand Up @@ -1010,7 +1014,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
index = r->base;
}
TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype,
MakeValue(index), load->dtype);
{MakeValue(index)}, load->dtype);
unsigned addrspace =
llvm::dyn_cast<llvm::PointerType>(buffer_ptr.addr->getType())->getAddressSpace();
return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace));
Expand Down Expand Up @@ -1274,39 +1278,56 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) {
}

void CodeGenLLVM::BufferAccessHelper(
Buffer buffer, PrimExpr index, DataType value_dtype,
Buffer buffer, Array<PrimExpr> indices, DataType value_dtype,
std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment,
bool is_volatile)>
make_instruction) {
DataType buffer_element_dtype = buffer->dtype;

ICHECK_EQ(value_dtype.lanes(), index.dtype().lanes() * buffer_element_dtype.lanes());
ICHECK_GE(indices.size(), 1)
<< "Buffer " << buffer->name << " is accessed with no indices. "
<< "0-d scalar buffers are expected to be flattened to 1-d buffers prior to codegen.";

// Only the last index is allowed to be multi-lane. All earlier
// indices must be scalar. This only matters for subclasses of
// CodeGenLLVM, because the default implementation of CreateBufferPtr
// requires 1-d indices.
std::vector<llvm::Value*> earlier_index_values;
for (size_t i = 0; i < indices.size() - 1; i++) {
ICHECK_EQ(indices[i].dtype().lanes(), 1)
<< "Buffer " << buffer->name << " is accessed with a multi-lane index at position " << i
<< ". Multi-lane indices are only supported as the last index.";
earlier_index_values.push_back(MakeValue(indices[i]));
}

PrimExpr last_index = indices[indices.size() - 1];
ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes());

bool is_volatile = volatile_buf_.count(buffer->data.get());

// If the buffer index is a contiguous ramp node, we only need to
// access the first element, then cast to the value type.
if (const RampNode* ramp_index = index.as<RampNode>()) {
if (const RampNode* ramp_index = last_index.as<RampNode>()) {
if (ramp_index && is_one(ramp_index->stride)) {
index = ramp_index->base;
last_index = ramp_index->base;
}
}

// All TVM arrays are densely packed. If the vectorized LLVM type
// contains padding for alignment, we need to index based on the
// size of the scalar type to avoid introducing that padding.
if (index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) {
index = buffer_element_dtype.lanes() * index;
if (last_index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) {
last_index = buffer_element_dtype.lanes() * last_index;
buffer_element_dtype = buffer_element_dtype.element_of();
}

int alignment;
if (index.dtype().lanes() == 1) {
if (last_index.dtype().lanes() == 1) {
// If we are accessing with a single index, then the vectorized
// element being accessed may require more alignment than the
// underlying data type.
int native_bits;
GetAlignment(value_dtype, buffer->data.get(), index, &alignment, &native_bits);
GetAlignment(value_dtype, buffer->data.get(), last_index, &alignment, &native_bits);
} else {
// Otherwise, alignment is based on the return value's scalar
// type.
Expand All @@ -1315,35 +1336,35 @@ void CodeGenLLVM::BufferAccessHelper(
}

llvm::Value* cached_vector_index = nullptr;
for (int i = 0; i < index.dtype().lanes(); ++i) {
llvm::Value* index_value;
for (int i = 0; i < last_index.dtype().lanes(); ++i) {
llvm::Value* last_index_value;
int subelement_i = i;
if (const RampNode* ramp = index.as<RampNode>()) {
if (const RampNode* ramp = last_index.as<RampNode>()) {
PrimExpr offset = ramp->base + (ramp->stride * i);
index_value = MakeValue(offset);
} else if (index.dtype().lanes() > 1) {
last_index_value = MakeValue(offset);
} else if (last_index.dtype().lanes() > 1) {
if (i == 0) {
cached_vector_index = MakeValue(index);
cached_vector_index = MakeValue(last_index);
}
index_value = builder_->CreateExtractElement(cached_vector_index, i);
last_index_value = builder_->CreateExtractElement(cached_vector_index, i);
} else {
index_value = MakeValue(index);
last_index_value = MakeValue(last_index);
subelement_i = -1;
}

std::vector<llvm::Value*> all_index_values = earlier_index_values;
all_index_values.push_back(last_index_value);

TypedPointer buffer_ptr =
CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, index_value,
value_dtype.with_lanes(value_dtype.lanes() / index.dtype().lanes()));
CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
AddAliasInfo(instruction, buffer->data.get(), index);
AddAliasInfo(instruction, buffer->data.get(), last_index);
}
}

llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";

DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];

std::vector<llvm::Value*> loads;

Expand All @@ -1363,7 +1384,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
return load;
};

BufferAccessHelper(op->buffer, index, value_dtype, make_load);
// Pass all indices into BufferAccessHelper. In CodeGenLLVM,
// non-flat indices will result in an error in CreateBufferPtr, but
// a subclass may override CreateBufferPtr.
BufferAccessHelper(op->buffer, op->indices, value_dtype, make_load);

if (loads.size() == 1) {
return loads[0];
Expand Down Expand Up @@ -1441,11 +1465,8 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
}

void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";

DataType value_dtype = op->value.dtype();
Var buffer_var = op->buffer->data;
PrimExpr buffer_index = op->indices[0];

llvm::Value* value = MakeValue(op->value);

Expand All @@ -1463,7 +1484,10 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
#endif
};

BufferAccessHelper(op->buffer, buffer_index, value_dtype, make_store);
// Pass all indices into BufferAccessHelper. In CodeGenLLVM,
// non-flat indices will result in an error in CreateBufferPtr, but
// a subclass may override CreateBufferPtr.
BufferAccessHelper(op->buffer, op->indices, value_dtype, make_store);
}

void CodeGenLLVM::VisitStmt_(const ForNode* op) {
Expand Down Expand Up @@ -1528,6 +1552,10 @@ void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) {
}

void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
ICHECK_EQ(op->extents.size(), 1)
<< "LLVM codegen only supports flat 1-d buffer allocation, but allocation of "
<< op->buffer_var->name_hint << " is " << op->extents << "-d";

ICHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;

Expand Down
8 changes: 4 additions & 4 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*
* \param buffer The buffer being accessed
*
* \param index The index at which the buffer is being accessed.
* \param indices The indices at which the buffer is being accessed.
*
* \param value_dtype The datatype to be read from (BufferLoad) or
* written to (BufferStore) the buffer.
Expand All @@ -286,7 +286,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
* - Should return the generated expression.
*/
void BufferAccessHelper(
Buffer buffer, PrimExpr index, DataType value_dtype,
Buffer buffer, Array<PrimExpr> indices, DataType value_dtype,
std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment,
bool is_volatile)>
make_instruction);
Expand Down Expand Up @@ -372,8 +372,8 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
llvm::Value* index, DataType value_dtype);
virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype);
// Vector concatenation.
llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
llvm::Value* CreateVecFlip(llvm::Value* vec);
Expand Down
Loading