Skip to content

Commit

Permalink
Clean up
Browse files Browse the repository at this point in the history
Signed-off-by: Haruki Imai <imaihal@jp.ibm.com>
  • Loading branch information
imaihal committed Aug 23, 2024
1 parent 5ac7b1f commit e8b935d
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 616 deletions.
132 changes: 15 additions & 117 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,28 +190,12 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));

// Create a ZHighStickifiedConstantOp.
// Set zero in value attribute later in lowering pass to LLVMIR.
ZHighStickifiedConstantOp stickifiedConstant =
rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
/*value=*/nullptr,
/*layout*/ StringAttr(),
/*alignment=*/rewriter.getI64IntegerAttr(4096));
// // Use an dense resource attribute to store stickified data.
// // Attribute type: tensor<sizeInBytes x i8>
// int64_t sizeInBytes =
// affine::getIntOrFloatMemRefSizeInBytes(resType).value();
// char *rawData = (char *)malloc(sizeInBytes);
// memset(rawData, 0, sizeInBytes);
// DenseResourceElementsAttr valueAttr =
// DenseUI8ResourceElementsAttr::get(
// RankedTensorType::get({sizeInBytes}, rewriter.getI8Type()),
// stickifiedConstant.getOperation()
// ->getDialect()
// ->getNamespace(), // use the dialect as the blob "hint"
// HeapAsmResourceBlob::allocateAndCopyWithAlign(
// llvm::ArrayRef(rawData, sizeInBytes), alignof(char)));
// stickifiedConstant.setValueAttr(valueAttr);
// free(rawData);

res = stickifiedConstant.getResult();
} else {
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
Expand Down Expand Up @@ -664,82 +648,6 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
ZHighStickifiedConstantOp zhighStickifiedConstOp =
cast<ZHighStickifiedConstantOp>(op);

// IndexExprBuilderForKrnl createKrnlIE(rewriter, loc);
// ZHighStickifiedConstantOpShapeHelper shapeHelper(op, operands,
// &createKrnlIE); shapeHelper.computeShapeAndAssertOnFailure();

// Convert ZTensor type to MemRefType.
ZMemRefType zMemRefType =
convertZTensorToMemRefType(*op->result_type_begin());

// Normalize MemRefType to get a static shape.
assert(mlir::cast<MemRefType>(zMemRefType.value).getNumDynamicDims() == 0 &&
"MemRefType has dynamic dimensions");
MemRefType normalizedType =
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
ArrayRef<int64_t> normalizedShape = normalizedType.getShape();
ZLowStickifiedConstantOp zlowStickifiedConstantOp;
if (zhighStickifiedConstOp.getValueAttr()) {
DenseElementsAttr dataAttr =
mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(
zhighStickifiedConstOp.getValue().value());

ArrayRef<int64_t> shape = dataAttr.getType().getShape();

// Create a ZLowStickifiedConstantOp.
zlowStickifiedConstantOp = rewriter.create<ZLowStickifiedConstantOp>(loc,
mlir::cast<MemRefType>(zMemRefType.value),
/*shape=*/
rewriter.getI64ArrayAttr(normalizedShape),
/*value=*/zhighStickifiedConstOp.getValueAttr(),
/*name=*/
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*layout=*/zhighStickifiedConstOp.getLayoutAttr(),
/*offset=*/rewriter.getI64IntegerAttr(0),
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
} else {
// Create a ZLowStickifiedConstantOp.
zlowStickifiedConstantOp = rewriter.create<ZLowStickifiedConstantOp>(loc,
mlir::cast<MemRefType>(zMemRefType.value),
/*shape=*/
rewriter.getI64ArrayAttr(normalizedShape),
/*value=*/nullptr,
/*name=*/
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*layout=*/zhighStickifiedConstOp.getLayoutAttr(),
/*offset=*/rewriter.getI64IntegerAttr(0),
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());
}
// Increment constant ID:
constantID++;

rewriter.replaceOp(op, zlowStickifiedConstantOp.getResult());
return success();
}
};

int ZHighToZLowStickifiedConstantOpLowering::constantID = 0;

//===----------------------------------------------------------------------===//
// Lower ZHigh Stickified Constant to KrnlGlobal (Original)
//===----------------------------------------------------------------------===//

struct ZHighToZLowStickifiedConstantOpLoweringOriginal
: public ConversionPattern {
static int constantID;
ZHighToZLowStickifiedConstantOpLoweringOriginal(
TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(typeConverter,
ZHighStickifiedConstantOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
ZHighStickifiedConstantOp stickifiedConstOp =
llvm::dyn_cast<ZHighStickifiedConstantOp>(op);

// Convert ZTensor type to MemRefType.
Expand All @@ -753,41 +661,33 @@ struct ZHighToZLowStickifiedConstantOpLoweringOriginal
affine::normalizeMemRefType(mlir::cast<MemRefType>(zMemRefType.value));
ArrayRef<int64_t> normalizedShape = normalizedType.getShape();

// Get dense resource attribute.
auto blob = mlir::cast<DenseResourceElementsAttr>(
stickifiedConstOp.getValue().value())
.getRawHandle()
.getBlob();
assert(blob && "Expecting dense resource with a valid blob");
ArrayRef<char> data = blob->getData();

// Validate the stickified tensor.
int64_t memRefSizeInBytes = getMemRefEltSizeInBytes(normalizedType);
memRefSizeInBytes *= normalizedType.getNumElements();
assert((data.size() == (uint64_t)memRefSizeInBytes) &&
"The stickified tensor's buffer size and MemRef's size mismatched");

// Create a KrnlGlobalOp.
KrnlGlobalOp constantGlobal =
rewriter.create<KrnlGlobalOp>(loc, zMemRefType.value,
auto valueAttr = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(
zhighStickifiedConstOp.getValueAttr());

// Create a ZLowStickifiedConstantOp.
// Set nullptr in the valueAttr when it is initialized with zero later.
ZLowStickifiedConstantOp zlowStickifiedConstantOp =
rewriter.create<ZLowStickifiedConstantOp>(loc,
mlir::cast<MemRefType>(zMemRefType.value),
/*shape=*/
rewriter.getI64ArrayAttr(normalizedShape),
/*value=*/valueAttr,
/*name=*/
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*value=*/stickifiedConstOp.getValueAttr(),
/*offset=*/nullptr,
/*alignment=*/stickifiedConstOp.getAlignmentAttr());
/*layout=*/zhighStickifiedConstOp.getLayoutAttr(),
/*offset=*/rewriter.getI64IntegerAttr(0),
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());

// Increment constant ID:
constantID++;

rewriter.replaceOp(op, constantGlobal.getResult());
rewriter.replaceOp(op, zlowStickifiedConstantOp.getResult());
return success();
}
};

int ZHighToZLowStickifiedConstantOpLoweringOriginal::constantID = 0;
int ZHighToZLowStickifiedConstantOpLowering::constantID = 0;

template <typename OP_TYPE>
struct ZLowOpFor {
Expand Down Expand Up @@ -1815,8 +1715,6 @@ void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx,
bool enableParallel) {
// Stickify and unstickify operations.
// patterns.insert<ZHighToZLowStickifiedConstantOpLoweringOriginal>(typeConverter,
// ctx);
patterns.insert<ZHighToZLowStickifiedConstantOpLowering>(typeConverter, ctx);
patterns.insert<ZHighToZLowStickOpLowering>(typeConverter, ctx);
patterns.insert<ZHighToZLowStickForLSTMOpLowering>(typeConverter, ctx);
Expand Down
13 changes: 0 additions & 13 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
Original file line number Diff line number Diff line change
Expand Up @@ -872,19 +872,6 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
let results = (outs AnyZTensor:$output);
}

//def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
// let summary = "ZHigh Stickified Constant operation";
// let description = [{
// This operator produces a constant tensor to store stickified data.
// Stickified data is opaque and must be 4K-aligned. One who produces
// the stickified data must make sure its size in bytes consistent with
// the output tensor's size.
// }];
// let arguments = (ins OptionalAttr<AnyAttr>:$value,
// DefaultValuedAttr<I64Attr, "4096">:$alignment);
// let results = (outs AnyZTensor:$output);
//}

def ZHighStickifiedConstantOfShapeOp:ZHigh_Op<"StickifiedConstantOfShape", [Pure,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
Expand Down
36 changes: 0 additions & 36 deletions src/Accelerators/NNPA/Transform/ZHigh/ZHighConstPropagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,46 +92,10 @@ ZHighStickifiedConstantOp createConstantForStick(PatternRewriter &rewriter,
Value replacingValue, Value input, StringAttr layout) {
Location loc = replacingValue.getLoc();
Operation *op = input.getDefiningOp();
// ArrayRef<int64_t> shape =
// mlir::cast<ShapedType>(input.getType()).getShape();
// Type elementType =
// mlir::cast<ShapedType>(input.getType()).getElementType();
// int rank = shape.size();

// Read dense attributes.
DenseElementsAttr dataAttr = mlir::dyn_cast_or_null<mlir::DenseElementsAttr>(
op->getAttrOfType<::mlir::Attribute>("value"));
assert(dataAttr && "Attribute is null");
// Read attributes's raw data.
// std::vector<char> rawData;
// getRawData(dataAttr, rawData);
// // assert((rawData.size() == (uint64_t)getMemRefSizeInBytes(input)) &&
// // "Data size mismatched");
//
// // Call stickify.
// zdnn_tensor_desc pre_tfrmd_desc, tfrmd_desc;
// // pre-transformed desc.
// zdnn_data_layouts zDNNLayout =
// convertLayoutAttrToZDNNDataLayout(rank, layout);
// // If zDNNLayout is NHWC, we stickify directly from NCHW.
// if (zDNNLayout == ZDNN_NHWC)
// zDNNLayout = ZDNN_NCHW;
// zdnn_data_types zDNNType = mlirTypeToZDNNType(elementType);
// set_info_pre_transformed_desc(&pre_tfrmd_desc, zDNNLayout, zDNNType,
// shape);
// // transformed desc.
// zdnn_status status = generate_transformed_desc(&pre_tfrmd_desc,
// &tfrmd_desc); assert(status == ZDNN_OK);
// // Stick data using the software stickify.
// zdnn_ztensor ztensor;
// init_ztensor(&pre_tfrmd_desc, &tfrmd_desc, &ztensor);
// status = allochelper_ztensor_alloc(&ztensor);
// assert(status == ZDNN_OK);
// status = stickify(&ztensor, rawData.data());
// assert(status == ZDNN_OK);
// Emit a constant global in ZHigh dialect.
// ZHighStickifiedConstantOp constantOp = emitZHighStickifiedConstant(
// rewriter, loc, &ztensor, replacingValue.getType());
ZHighStickifiedConstantOp constantOp =
rewriter.create<ZHighStickifiedConstantOp>(loc, replacingValue.getType(),
/*value=*/dataAttr,
Expand Down
3 changes: 1 addition & 2 deletions src/Conversion/KrnlToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

add_onnx_mlir_library(OMKrnlToLLVM
ConvertKrnlToLLVM.cpp
ConstantOpInterface.cpp
KrnlFindIndex.cpp
KrnlCall.cpp
KrnlEntryPoint.cpp
# KrnlGlobal.cpp
KrnlConstantOpInterface.cpp
KrnlInstrument.cpp
KrnlMemcpy.cpp
KrnlNone.cpp
Expand Down
Loading

0 comments on commit e8b935d

Please sign in to comment.