Skip to content

Commit

Permalink
[Hexagon] Codegen for 2d Load/Store (apache#10586)
Browse files Browse the repository at this point in the history
* Added unit tests for codegen of 2d physical buffers in Hexagon.

* Update IndexMap when buffers are updated.

* Extended CodeGenLLVM::BufferAccessHelper to support N-d

This way, a subclass can override GetBufferPtr, without needing to
reimplement all of the other indexing logic for
BufferLoad/BufferStore.

* Updated CodeGenHexagon to treat 2-d physical buffers as T**

* Moved indices size check earlier.

Previous location in `CodeGenLLVM::BufferAccessHelper` occurred after
possible integer wrapping in `indices.size()-1` loop bounds.

* Updated to use `llvm::ArrayRef` instead of `std::vector`.

* Resolve lint error.

* CI fix, contextlib.nullcontext not available on python3.6
  • Loading branch information
Lunderberg authored and pfk-beta committed Apr 11, 2022
1 parent 3e173f5 commit 125fc65
Show file tree
Hide file tree
Showing 6 changed files with 431 additions and 42 deletions.
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -815,12 +815,12 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(begin));
TypedPointer arg_tcode =
CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32));
CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(begin)}, DataType::Int(32));
llvm::Value* ret_value = builder_->CreateInBoundsGEP(
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(end));
TypedPointer ret_tcode =
CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32));
CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32));

#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
Expand Down
31 changes: 29 additions & 2 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class CodeGenHexagon final : public CodeGenLLVM {
llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};

private:
TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype) final;
TypedPointer CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind);

// Check if the call to packed function is successful
Expand Down Expand Up @@ -320,12 +322,12 @@ CodeGenHexagon::PackedCall CodeGenHexagon::MakeCallPackedLowered(const Array<Pri
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(begin));
TypedPointer arg_tcode =
CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(begin), DataType::Int(32));
CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(begin)}, DataType::Int(32));
llvm::Value* ret_value = builder_->CreateInBoundsGEP(
t_tvm_value_, builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()),
ConstInt32(end));
TypedPointer ret_tcode =
CreateBufferPtr(stack_tcode, DataType::Int(32), ConstInt32(end), DataType::Int(32));
CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, DataType::Int(32));

#if TVM_LLVM_VERSION >= 90
auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
Expand Down Expand Up @@ -570,6 +572,31 @@ llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) {
return CodeGenLLVM::CreateIntrinsic(op);
}

CodeGenLLVM::TypedPointer CodeGenHexagon::CreateBufferPtr(llvm::Value* buffer_ptr,
DataType buffer_element_dtype,
llvm::ArrayRef<llvm::Value*> indices,
DataType value_dtype) {
// Flat indices get delegated to the LLVM codegen.
if (indices.size() == 1) {
return CodeGenLLVM::CreateBufferPtr(buffer_ptr, buffer_element_dtype, indices, value_dtype);
}

ICHECK_EQ(indices.size(), 2) << "CodegenHexagon supports 1-d and 2-d physical buffers, received "
<< indices.size() << "-d buffer indices";

// Use the first index to identify the pointer.
DataType dtype_void_ptr = DataType::Handle();
CodeGenLLVM::TypedPointer buffer_chunk_ptr_ptr =
CodeGenLLVM::CreateBufferPtr(buffer_ptr, dtype_void_ptr, {indices[0]}, dtype_void_ptr);
llvm::Value* buffer_chunk_ptr =
builder_->CreateLoad(buffer_chunk_ptr_ptr.type, buffer_chunk_ptr_ptr.addr);

// Then delegate the CodeGenLLVM to find the value from the second
// index.
return CodeGenLLVM::CreateBufferPtr(buffer_chunk_ptr, buffer_element_dtype, {indices[1]},
value_dtype);
}

CodeGenLLVM::TypedPointer CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf,
llvm::Value* index, int kind) {
static const std::map<int, int> field_index = {
Expand Down
86 changes: 57 additions & 29 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,11 @@ llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) {

CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr,
DataType buffer_element_dtype,
llvm::Value* index, DataType value_dtype) {
llvm::ArrayRef<llvm::Value*> indices,
DataType value_dtype) {
ICHECK_EQ(indices.size(), 1) << "CodeGenLLVM requires all buffers to be flat 1-d buffers.";
llvm::Value* index = indices[0];

llvm::PointerType* buffer_ptr_type = llvm::dyn_cast<llvm::PointerType>(buffer_ptr->getType());
ICHECK(buffer_ptr_type != nullptr);
auto address_space = buffer_ptr_type->getAddressSpace();
Expand Down Expand Up @@ -1010,7 +1014,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
index = r->base;
}
TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype,
MakeValue(index), load->dtype);
{MakeValue(index)}, load->dtype);
unsigned addrspace =
llvm::dyn_cast<llvm::PointerType>(buffer_ptr.addr->getType())->getAddressSpace();
return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace));
Expand Down Expand Up @@ -1274,39 +1278,56 @@ bool CodeGenLLVM::HasAlignmentPadding(DataType dtype) {
}

void CodeGenLLVM::BufferAccessHelper(
Buffer buffer, PrimExpr index, DataType value_dtype,
Buffer buffer, Array<PrimExpr> indices, DataType value_dtype,
std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment,
bool is_volatile)>
make_instruction) {
DataType buffer_element_dtype = buffer->dtype;

ICHECK_EQ(value_dtype.lanes(), index.dtype().lanes() * buffer_element_dtype.lanes());
ICHECK_GE(indices.size(), 1)
<< "Buffer " << buffer->name << " is accessed with no indices. "
<< "0-d scalar buffers are expected to be flattened to 1-d buffers prior to codegen.";

// Only the last index is allowed to be multi-lane. All earlier
// indices must be scalar. This only matters for subclasses of
// CodeGenLLVM, because the default implementation of GetBufferPtr
// requires 1-d indices.
std::vector<llvm::Value*> earlier_index_values;
for (size_t i = 0; i < indices.size() - 1; i++) {
ICHECK_EQ(indices[i].dtype().lanes(), 1)
<< "Buffer " << buffer->name << " is accessed with a multi-lane index at position " << i
<< ". Multi-lane indices are only supported as the last index.";
earlier_index_values.push_back(MakeValue(indices[i]));
}

PrimExpr last_index = indices[indices.size() - 1];
ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes());

bool is_volatile = volatile_buf_.count(buffer->data.get());

// If the buffer index is a contiguous ramp node, we only need to
// access the first element, then cast to the value type.
if (const RampNode* ramp_index = index.as<RampNode>()) {
if (const RampNode* ramp_index = last_index.as<RampNode>()) {
if (ramp_index && is_one(ramp_index->stride)) {
index = ramp_index->base;
last_index = ramp_index->base;
}
}

// All TVM arrays are densely packed. If the vectorized LLVM type
// contains padding for alignment, we need to index based on the
// size of the scalar type to avoid introducing that padding.
if (index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) {
index = buffer_element_dtype.lanes() * index;
if (last_index.dtype().lanes() == 1 && HasAlignmentPadding(buffer_element_dtype)) {
last_index = buffer_element_dtype.lanes() * last_index;
buffer_element_dtype = buffer_element_dtype.element_of();
}

int alignment;
if (index.dtype().lanes() == 1) {
if (last_index.dtype().lanes() == 1) {
// If we are accessing with a single index, then the vectorized
// element being accessed may require more alignment than the
// underlying data type.
int native_bits;
GetAlignment(value_dtype, buffer->data.get(), index, &alignment, &native_bits);
GetAlignment(value_dtype, buffer->data.get(), last_index, &alignment, &native_bits);
} else {
// Otherwise, alignment is based on the return value's scalar
// type.
Expand All @@ -1315,35 +1336,35 @@ void CodeGenLLVM::BufferAccessHelper(
}

llvm::Value* cached_vector_index = nullptr;
for (int i = 0; i < index.dtype().lanes(); ++i) {
llvm::Value* index_value;
for (int i = 0; i < last_index.dtype().lanes(); ++i) {
llvm::Value* last_index_value;
int subelement_i = i;
if (const RampNode* ramp = index.as<RampNode>()) {
if (const RampNode* ramp = last_index.as<RampNode>()) {
PrimExpr offset = ramp->base + (ramp->stride * i);
index_value = MakeValue(offset);
} else if (index.dtype().lanes() > 1) {
last_index_value = MakeValue(offset);
} else if (last_index.dtype().lanes() > 1) {
if (i == 0) {
cached_vector_index = MakeValue(index);
cached_vector_index = MakeValue(last_index);
}
index_value = builder_->CreateExtractElement(cached_vector_index, i);
last_index_value = builder_->CreateExtractElement(cached_vector_index, i);
} else {
index_value = MakeValue(index);
last_index_value = MakeValue(last_index);
subelement_i = -1;
}

std::vector<llvm::Value*> all_index_values = earlier_index_values;
all_index_values.push_back(last_index_value);

TypedPointer buffer_ptr =
CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, index_value,
value_dtype.with_lanes(value_dtype.lanes() / index.dtype().lanes()));
CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
AddAliasInfo(instruction, buffer->data.get(), index);
AddAliasInfo(instruction, buffer->data.get(), last_index);
}
}

llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";

DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];

std::vector<llvm::Value*> loads;

Expand All @@ -1363,7 +1384,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
return load;
};

BufferAccessHelper(op->buffer, index, value_dtype, make_load);
// Pass all indices into BufferAccessHelper. In CodeGenLLVM,
// non-flat indices will result in an error in CreateBufferPtr, but
// a subclass may override CreateBufferPtr.
BufferAccessHelper(op->buffer, op->indices, value_dtype, make_load);

if (loads.size() == 1) {
return loads[0];
Expand Down Expand Up @@ -1441,11 +1465,8 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
}

void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
ICHECK_EQ(op->indices.size(), 1) << "CodeGenLLVM expects flattened 1-d buffers.";

DataType value_dtype = op->value.dtype();
Var buffer_var = op->buffer->data;
PrimExpr buffer_index = op->indices[0];

llvm::Value* value = MakeValue(op->value);

Expand All @@ -1463,7 +1484,10 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
#endif
};

BufferAccessHelper(op->buffer, buffer_index, value_dtype, make_store);
// Pass all indices into BufferAccessHelper. In CodeGenLLVM,
// non-flat indices will result in an error in CreateBufferPtr, but
// a subclass may override CreateBufferPtr.
BufferAccessHelper(op->buffer, op->indices, value_dtype, make_store);
}

void CodeGenLLVM::VisitStmt_(const ForNode* op) {
Expand Down Expand Up @@ -1528,6 +1552,10 @@ void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) {
}

void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
ICHECK_EQ(op->extents.size(), 1)
<< "LLVM codegen only supports flat 1-d buffer allocation, but allocation of "
<< op->buffer_var->name_hint << " is " << op->extents << "-d";

ICHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;

Expand Down
8 changes: 4 additions & 4 deletions src/target/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
*
* \param buffer The buffer being accessed
*
* \param index The index at which the buffer is being accessed.
* \param indices The indices at which the buffer is being accessed.
*
* \param value_dtype The datatype to be read from (BufferLoad) or
* written to (BufferStore) the buffer.
Expand All @@ -286,7 +286,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
* - Should return the generated expression.
*/
void BufferAccessHelper(
Buffer buffer, PrimExpr index, DataType value_dtype,
Buffer buffer, Array<PrimExpr> indices, DataType value_dtype,
std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i, int alignment,
bool is_volatile)>
make_instruction);
Expand Down Expand Up @@ -372,8 +372,8 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
llvm::Value* index, DataType value_dtype);
virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype);
// Vector concatenation.
llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
llvm::Value* CreateVecFlip(llvm::Value* vec);
Expand Down
Loading

0 comments on commit 125fc65

Please sign in to comment.