Skip to content

Commit

Permalink
Apply only krnl related part.
Browse files Browse the repository at this point in the history
Signed-off-by: Haruki Imai <imaihal@jp.ibm.com>
  • Loading branch information
imaihal committed Sep 13, 2024
1 parent bce4d6f commit 837217f
Show file tree
Hide file tree
Showing 12 changed files with 312 additions and 105 deletions.
4 changes: 3 additions & 1 deletion src/Conversion/KrnlToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

add_onnx_mlir_library(OMKrnlToLLVM
ConvertKrnlToLLVM.cpp
ConstantOpInterface.cpp
KrnlFindIndex.cpp
KrnlCall.cpp
KrnlEntryPoint.cpp
KrnlGlobal.cpp
# KrnlGlobal.cpp
KrnlInstrument.cpp
KrnlMemcpy.cpp
KrnlNone.cpp
Expand All @@ -21,6 +22,7 @@ add_onnx_mlir_library(OMKrnlToLLVM

LINK_LIBS PUBLIC
OMAccelerator
OMConstantOpInterface
OMSupport
MLIRAffineToStandard
MLIRArithTransforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,35 @@ namespace krnl {
/// This variable is initizalied inside ConvertKrnlToLLVMPass.
extern std::string EXTERNAL_CONSTANT_PREFIX;

class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
class ConstantOpInterfaceLowering
: public OpInterfaceConversionPattern<ConstantOpInterface> {
public:
explicit KrnlGlobalOpLowering(
explicit ConstantOpInterfaceLowering(
LLVMTypeConverter &typeConverter, MLIRContext *context)
: ConvertToLLVMPattern(
KrnlGlobalOp::getOperationName(), context, typeConverter) {}
: OpInterfaceConversionPattern(typeConverter, context) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
LogicalResult matchAndRewrite(ConstantOpInterface op,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto krnlGlobalOp = llvm::dyn_cast<KrnlGlobalOp>(op);
Location loc = krnlGlobalOp.getLoc();
MLIRContext *context = krnlGlobalOp.getContext();
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
const LLVMTypeConverter *llvmTypeConverter =
static_cast<const LLVMTypeConverter *>(getTypeConverter());

// Basic type.
Type llvmI8Ty = IntegerType::get(context, 8);
Type llvmI8PtrTy = getPointerType(context, llvmI8Ty);

// The element type of the array.
const Type type = op->getResult(0).getType();
const Type type = op.getResult().getType();
const MemRefType memRefTy = mlir::cast<mlir::MemRefType>(type);
const Type constantElementType =
typeConverter->convertType(memRefTy.getElementType());
llvmTypeConverter->convertType(memRefTy.getElementType());
Type globalType = constantElementType;

// The llvm type of the global (example: [2 x [8 x float]]).
const auto shape = mlir::dyn_cast<ArrayAttr>(krnlGlobalOp.getShape());
const auto shape = mlir::dyn_cast<ArrayAttr>(op.getShape());
if (shape.empty())
globalType = LLVM::LLVMArrayType::get(mlir::cast<Type>(globalType), 1);
else {
Expand All @@ -74,32 +76,32 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
LLVM::GlobalOp global;
// Pointer to the raw data of the global.
Value dataPtr;
// Update value attribute if needed.
op.updateBuffer();

if (krnlGlobalOp.getValue().has_value()) {
auto value = krnlGlobalOp.getValue().value();
if (op.getValue().has_value()) {
auto value = op.getValue().value();
TypeSwitch<Attribute>(value)
.Case<DenseResourceElementsAttr>([&](DenseResourceElementsAttr attr) {
global =
lowerDenseResourceConstant(krnlGlobalOp, globalType, rewriter);
global = lowerDenseResourceConstant(op, globalType, rewriter);
})
.Case<DenseElementsAttr>([&](DenseElementsAttr attr) {
global = lowerDenseConstant(krnlGlobalOp, globalType, rewriter);
global = lowerDenseConstant(op, globalType, rewriter);
})
.Default([&](Attribute attr) {
llvm_unreachable("Unsupported attribute type");
});
dataPtr = create.llvm.addressOf(global);
} else {
// Data are stored on files.
global = lowerGlobalOpWithExternalFiles(krnlGlobalOp, rewriter);
global = lowerGlobalOpWithExternalFiles(op, rewriter);
dataPtr = create.llvm.load(llvmI8PtrTy, create.llvm.addressOf(global));
}

// Set the global alignment based on the alignment attribute if it exists,
// otherwise use the module datalayout info.
krnl::setAlignment(global, krnlGlobalOp.getAlignmentAttr(),
krnlGlobalOp->getParentOfType<ModuleOp>(), rewriter,
*getTypeConverter());
krnl::setAlignment(global, op.getAlignmentAttr(),
op->getParentOfType<ModuleOp>(), rewriter, *llvmTypeConverter);

// Prepare data to be inserted into a MemRefDescriptor (a struct).
MemRefDescriptor memRefDescr =
Expand All @@ -115,55 +117,56 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
return mlir::cast<IntegerAttr>(a.getValue()[i]).getInt();
}

LLVM::GlobalOp lowerDenseResourceConstant(KrnlGlobalOp &krnlGlobalOp,
Type globalType, ConversionPatternRewriter &rewriter) const {
assert(krnlGlobalOp.getValue().has_value() &&
"Expecting KrnlGlobalOp with a valid value");
assert(
mlir::isa<DenseResourceElementsAttr>(krnlGlobalOp.getValue().value()) &&
"Expecting a global with an dense resource elements attribute");

MLIRContext *context = krnlGlobalOp.getContext();
Location loc = krnlGlobalOp.getLoc();
ModuleOp module = krnlGlobalOp->getParentOfType<ModuleOp>();
LLVM::GlobalOp lowerDenseResourceConstant(
ConstantOpInterface &constOpInterface, Type globalType,
ConversionPatternRewriter &rewriter) const {
assert(constOpInterface.getValue().has_value() &&
"Expecting ConstantOpInterface with a valid value");
assert(mlir::isa<DenseResourceElementsAttr>(
constOpInterface.getValue().value()) &&
"Expecting a global with an dense resource elements attribute");

MLIRContext *context = constOpInterface.getContext();
Location loc = constOpInterface.getLoc();
ModuleOp module = constOpInterface->getParentOfType<ModuleOp>();
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);

OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());

auto blob =
mlir::cast<DenseResourceElementsAttr>(krnlGlobalOp.getValue().value())
.getRawHandle()
.getBlob();
auto blob = mlir::cast<DenseResourceElementsAttr>(
constOpInterface.getValue().value())
.getRawHandle()
.getBlob();
assert(blob && "Expecting dense resource with a valid blob");
ArrayRef<char> rawData = blob->getData();

// Check data size.
uint64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp);
uint64_t sizeInBytes = computeSizeInBytes(constOpInterface);
assert(((uint64_t)rawData.size() == sizeInBytes) && "Data size mismatch.");

StringRef data(rawData.data(), rawData.size());
StringAttr llvmStringAttr = StringAttr::get(context, data);
auto llvmArrayI8Ty =
LLVM::LLVMArrayType::get(IntegerType::get(context, 8), sizeInBytes);
LLVM::GlobalOp global = create.llvm.globalOp(llvmArrayI8Ty,
/*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(),
llvmStringAttr);
/*isConstant=*/true, LLVM::Linkage::Internal,
constOpInterface.getName(), llvmStringAttr);

LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";);
return global;
}

LLVM::GlobalOp lowerDenseConstant(KrnlGlobalOp &krnlGlobalOp, Type globalType,
ConversionPatternRewriter &rewriter) const {
assert(krnlGlobalOp.getValue().has_value() &&
"Expecting KrnlGlobalOp with a valid value");
assert(mlir::isa<DenseElementsAttr>(krnlGlobalOp.getValue().value()) &&
LLVM::GlobalOp lowerDenseConstant(ConstantOpInterface &constOpInterface,
Type globalType, ConversionPatternRewriter &rewriter) const {
assert(constOpInterface.getValue().has_value() &&
"Expecting ConstantOpInterface with a valid value");
assert(mlir::isa<DenseElementsAttr>(constOpInterface.getValue().value()) &&
"Expecting a global with an dense elements attribute");

Location loc = krnlGlobalOp.getLoc();
ModuleOp module = krnlGlobalOp->getParentOfType<ModuleOp>();
MLIRContext *context = krnlGlobalOp.getContext();
Location loc = constOpInterface.getLoc();
ModuleOp module = constOpInterface->getParentOfType<ModuleOp>();
MLIRContext *context = constOpInterface.getContext();
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);

Type llvmI8Ty = IntegerType::get(context, 8);
Expand All @@ -172,9 +175,9 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
rewriter.setInsertionPointToStart(module.getBody());

DenseElementsAttr denseAttr =
mlir::cast<DenseElementsAttr>(krnlGlobalOp.getValue().value());
mlir::cast<DenseElementsAttr>(constOpInterface.getValue().value());

uint64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp);
uint64_t sizeInBytes = computeSizeInBytes(constOpInterface);
LLVM::GlobalOp global;
if (!(mlir::isa<StringType>(denseAttr.getElementType())) &&
!(denseAttr.getElementType().isInteger(1)) && (!denseAttr.isSplat()) &&
Expand All @@ -188,37 +191,39 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
StringRef data(rawData.data(), rawData.size());
StringAttr llvmStringAttr = StringAttr::get(context, data);
global = create.llvm.globalOp(llvmArrayI8Ty,
/*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(),
llvmStringAttr);
/*isConstant=*/true, LLVM::Linkage::Internal,
constOpInterface.getName(), llvmStringAttr);
} else {
if (mlir::isa<StringType>(denseAttr.getElementType()))
global = lowerStringLiteral(krnlGlobalOp, globalType, rewriter);
global = lowerStringLiteral(constOpInterface, globalType, rewriter);
else
global = create.llvm.globalOp(globalType,
/*isConstant=*/true, LLVM::Linkage::Internal,
krnlGlobalOp.getName(), krnlGlobalOp.getValue().value());
constOpInterface.getName(), constOpInterface.getValue().value());
}

LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";);
return global;
}

LLVM::GlobalOp lowerGlobalOpWithExternalFiles(
KrnlGlobalOp &krnlGlobalOp, ConversionPatternRewriter &rewriter) const {
Location loc = krnlGlobalOp.getLoc();
MLIRContext *context = krnlGlobalOp.getContext();
ModuleOp module = krnlGlobalOp.getOperation()->getParentOfType<ModuleOp>();
ConstantOpInterface &constOpInterface,
ConversionPatternRewriter &rewriter) const {
Location loc = constOpInterface.getLoc();
MLIRContext *context = constOpInterface.getContext();
ModuleOp module =
constOpInterface.getOperation()->getParentOfType<ModuleOp>();
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);

Type llvmI8Ty = IntegerType::get(context, 8);
Type llvmI8PtrTy = getPointerType(context, llvmI8Ty);
Type llvmI64Ty = IntegerType::get(context, 64);

auto offset = krnlGlobalOp.getOffset();
assert(offset.has_value() && "Missing offset value in KrnlGlobalOp");
auto offset = constOpInterface.getOffset();
assert(offset.has_value() && "Missing offset value in ConstantOpInterface");

// Data is store in `constants.bin` at offset.
std::string constantName = krnlGlobalOp.getName().str();
std::string constantName = constOpInterface.getName().str();

// Emit globals at the begining of the module.
OpBuilder::InsertionGuard insertGuard(rewriter);
Expand Down Expand Up @@ -246,14 +251,14 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
return global;
}

uint64_t computeSizeInBytes(KrnlGlobalOp &krnlGlobalOp) const {
uint64_t computeSizeInBytes(ConstantOpInterface &constOpInterface) const {
// Compute total number of elements.
const auto shape = mlir::dyn_cast<ArrayAttr>(krnlGlobalOp.getShape());
const auto shape = mlir::dyn_cast<ArrayAttr>(constOpInterface.getShape());
uint64_t numElements = 1;
for (unsigned int i = 0; i < shape.size(); ++i)
numElements *= ArrayAttrIntVal(shape, i);

const auto type = krnlGlobalOp.getResult().getType();
const auto type = constOpInterface.getResult().getType();
const auto memRefTy = mlir::cast<mlir::MemRefType>(type);

// Special handling for bool.
Expand All @@ -267,8 +272,9 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
MemRefDescriptor createMemRefDescriptor(Value address, MemRefType memRefType,
Location loc, OpBuilder &builder) const {
Type elementType = memRefType.getElementType();
const LLVMTypeConverter &typeConverter = *getTypeConverter();
Type llvmElemType = typeConverter.convertType(elementType);
const LLVMTypeConverter *llvmTypeConverter =
static_cast<const LLVMTypeConverter *>(getTypeConverter());
Type llvmElemType = llvmTypeConverter->convertType(elementType);
MLIRContext *context = builder.getContext();
MultiDialectBuilder<LLVMBuilder> create(builder, loc);

Expand All @@ -278,21 +284,21 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
Value bitCastOp = create.llvm.bitcast(ptrType, address);
// Create llvm MemRef from original MemRef and fill the data pointers.
return MemRefDescriptor::fromStaticShape(
builder, loc, typeConverter, memRefType, bitCastOp);
builder, loc, *llvmTypeConverter, memRefType, bitCastOp);
}

// Generate a global string for each krnlGlobalOp string value, and store
// Generate a global string for each constOpInterface string value, and store
// the address of the global strings into an array. Return the array address.
LLVM::GlobalOp lowerStringLiteral(
KrnlGlobalOp &krnlGlobalOp, Type globalType, OpBuilder &builder) const {
assert(mlir::isa<DenseElementsAttr>(krnlGlobalOp.getValue().value()) &&
LLVM::GlobalOp lowerStringLiteral(ConstantOpInterface &constOpInterface,
Type globalType, OpBuilder &builder) const {
assert(mlir::isa<DenseElementsAttr>(constOpInterface.getValue().value()) &&
"Expecting a dense value");

Location loc = krnlGlobalOp.getLoc();
Location loc = constOpInterface.getLoc();
MultiDialectBuilder<LLVMBuilder> create(builder, loc);

DenseElementsAttr denseAttr =
mlir::cast<DenseElementsAttr>(krnlGlobalOp.getValue().value());
mlir::cast<DenseElementsAttr>(constOpInterface.getValue().value());

Type i8PtrType = getI8PointerType(builder.getContext());

Expand Down Expand Up @@ -322,14 +328,14 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
auto llvmArrayI8Ty = LLVM::LLVMArrayType::get(i8Type, totalSize);
LLVM::GlobalOp globalStr = create.llvm.globalOp(llvmArrayI8Ty,
/*isConstant=*/true, LLVM::Linkage::Internal,
"om.strArray." + krnlGlobalOp.getName().str(), llvmStringAttr);
"om.strArray." + constOpInterface.getName().str(), llvmStringAttr);

// Generate an LLVM GlobalOps with an initializer region containing one
// block.
auto arrayType = LLVM::LLVMArrayType::get(i8PtrType, offsets.size());
auto global = create.llvm.globalOp(arrayType,
/*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.getName(),
Attribute());
/*isConstant=*/true, LLVM::Linkage::Internal,
constOpInterface.getName(), Attribute());
Region &region = global.getInitializerRegion();
Block *block = builder.createBlock(&region);

Expand All @@ -355,9 +361,10 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
}
};

void populateLoweringKrnlGlobalOpPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<KrnlGlobalOpLowering>(typeConverter, ctx);
void populateLoweringConstantOpInterfacePattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
MLIRContext *ctx) {
patterns.insert<ConstantOpInterfaceLowering>(typeConverter, ctx);
}

} // namespace krnl
Expand Down
Loading

0 comments on commit 837217f

Please sign in to comment.