Skip to content

Commit

Permalink
Use stickified attribute and zeroconst attribute
Browse files Browse the repository at this point in the history
Signed-off-by: Haruki Imai <imaihal@jp.ibm.com>
  • Loading branch information
imaihal committed Sep 2, 2024
1 parent 16773e7 commit 9e236c9
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 75 deletions.
155 changes: 84 additions & 71 deletions src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp"
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp"
#include "src/Accelerators/NNPA/Support/Stickify/Stickify.hpp"
// #include "src/Interface/ConstantOpInterface.hpp"

Expand Down Expand Up @@ -367,16 +367,26 @@ void ZLowBatchNormOp::getEffects(
}

/// Get raw data from a dense attribute.
static void getRawData(DenseElementsAttr denseAttr, std::vector<char> &data) {
if (!denseAttr.isSplat()) {
data = denseAttr.getRawData();
} else {
ShapedType denseShapeType = mlir::cast<ShapedType>(denseAttr.getType());
std::vector<char> rawData = denseAttr.getRawData();
int64_t numElements = denseShapeType.getNumElements();
for (int i = 0; i < numElements; i++)
data.insert(data.end(), rawData.begin(), rawData.end());
}
/// Copy the raw bytes backing a constant attribute into `data`.
///
/// Any previous content of `data` is discarded, whichever attribute kind is
/// handled:
///  - DenseElementsAttr, non-splat: `data` receives the attribute's raw buffer.
///  - DenseElementsAttr, splat: the raw buffer returned for the splat is
///    repeated once per element of the attribute's shaped type.
///  - DenseResourceElementsAttr: `data` receives the bytes of the resource
///    blob.
/// Any other attribute kind reaches llvm_unreachable.
static void getRawData(Attribute dataAttr, std::vector<char> &data) {
  TypeSwitch<Attribute>(dataAttr)
      .Case<DenseElementsAttr>([&](DenseElementsAttr denseAttr) {
        // View the attribute storage directly; no intermediate copy needed.
        ArrayRef<char> rawData = denseAttr.getRawData();
        if (!denseAttr.isSplat()) {
          data.assign(rawData.begin(), rawData.end());
        } else {
          ShapedType denseShapeType =
              mlir::cast<ShapedType>(denseAttr.getType());
          int64_t numElements = denseShapeType.getNumElements();
          // Start from an empty buffer so the splat path overwrites `data`
          // exactly like the other paths instead of appending to it.
          data.clear();
          data.reserve(rawData.size() * static_cast<size_t>(numElements));
          // 64-bit index: the element count is an int64_t and may not fit in
          // an int.
          for (int64_t i = 0; i < numElements; ++i)
            data.insert(data.end(), rawData.begin(), rawData.end());
        }
      })
      .Case<DenseResourceElementsAttr>(
          [&](DenseResourceElementsAttr denseResourceAttr) {
            ArrayRef<char> blobData =
                denseResourceAttr.getRawHandle().getBlob()->getData();
            data.assign(blobData.begin(), blobData.end());
          })
      .Default([](Attribute) { llvm_unreachable("Unsupported data type."); });
}

/// MLIR type to zDNN type.
Expand All @@ -401,67 +411,70 @@ ArrayRef<char> ZLowStickifiedConstantOp::getBuffer() {
StringAttr layout = onnx_mlir::zhigh::getZTensorLayoutAttr(
rewriter, zlowStickifiedConstantOp.getResult().getType());
ArrayRef<char> ret;
if (zlowStickifiedConstantOp.getValueAttr()) {
if (zlowStickifiedConstantOp.getValueAttr() &&
zlowStickifiedConstantOp.getStickifiedAttr()) {
auto dataAttr = zlowStickifiedConstantOp.getValue().value();
TypeSwitch<Attribute>(dataAttr)
.Case<DenseElementsAttr>([&](DenseElementsAttr denseAttr) {
ArrayRef<int64_t> shape = denseAttr.getType().getShape();
Type elementType = denseAttr.getType().getElementType();
int rank = shape.size();
// Read the attribute's raw data.
std::vector<char> attrData;
getRawData(denseAttr, attrData);
// Call stickify.
zdnn_tensor_desc pre_tfrmd_desc, tfrmd_desc;
// pre-transformed desc.
zdnn_data_layouts zDNNLayout =
convertLayoutAttrToZDNNDataLayout(rank, layout);
// If zDNNLayout is NHWC, we stickify directly from NCHW.
if (zDNNLayout == ZDNN_NHWC)
zDNNLayout = ZDNN_NCHW;
zdnn_data_types zDNNType = mlirTypeToZDNNType(elementType);
set_info_pre_transformed_desc(
&pre_tfrmd_desc, zDNNLayout, zDNNType, shape);
// transformed desc.
zdnn_status status =
generate_transformed_desc(&pre_tfrmd_desc, &tfrmd_desc);
assert(status == ZDNN_OK);
// Stick data using the software stickify.
zdnn_ztensor ztensor;
init_ztensor(&pre_tfrmd_desc, &tfrmd_desc, &ztensor);
status = allochelper_ztensor_alloc(&ztensor);
assert(status == ZDNN_OK);
status = stickify(&ztensor, attrData.data());
assert(status == ZDNN_OK);
std::vector<char>().swap(attrData);
zlowStickifiedConstantOp.removeValueAttr();
int64_t sizeInBytes = ztensor.buffer_size;
char *rawData = (char *)malloc(sizeInBytes);
memcpy(rawData, ztensor.buffer, sizeInBytes);
ret = llvm::ArrayRef(rawData, sizeInBytes);
allochelper_ztensor_free(&ztensor);
})
.Case<DenseResourceElementsAttr>(
[&](DenseResourceElementsAttr denseResourceAttr) {
ArrayRef<char> attrData =
denseResourceAttr.getRawHandle().getBlob()->getData();
int64_t sizeInBytes = affine::getIntOrFloatMemRefSizeInBytes(
zlowStickifiedConstantOp.getResult().getType())
.value();
char *rawData = (char *)malloc(sizeInBytes);
memcpy(rawData, attrData.data(), sizeInBytes);
ret = llvm::ArrayRef(rawData, sizeInBytes);
})
.Default([&](Attribute attr) {
llvm_unreachable("Unsupported data type.");
});
} else if (zlowStickifiedConstantOp.getZeroconstAttr()) {
int64_t sizeInBytes = affine::getIntOrFloatMemRefSizeInBytes(
zlowStickifiedConstantOp.getResult().getType())
.value();
char *rawData = (char *)malloc(sizeInBytes);
memset(rawData, 0, sizeInBytes);
ret = llvm::ArrayRef(rawData, sizeInBytes);
if (!zlowStickifiedConstantOp.getStickified().value()) {
// The case where the data in the value attribute has not been stickified yet.
DenseElementsAttr denseAttr = mlir::cast<DenseElementsAttr>(dataAttr);
ArrayRef<int64_t> shape = denseAttr.getType().getShape();
Type elementType = denseAttr.getType().getElementType();
int rank = shape.size();
// Read the attribute's raw data.
std::vector<char> attrData;
getRawData(denseAttr, attrData);
// Call stickify.
zdnn_tensor_desc pre_tfrmd_desc, tfrmd_desc;
// pre-transformed desc.
zdnn_data_layouts zDNNLayout =
convertLayoutAttrToZDNNDataLayout(rank, layout);
// If zDNNLayout is NHWC, we stickify directly from NCHW.
if (zDNNLayout == ZDNN_NHWC)
zDNNLayout = ZDNN_NCHW;
zdnn_data_types zDNNType = mlirTypeToZDNNType(elementType);
set_info_pre_transformed_desc(
&pre_tfrmd_desc, zDNNLayout, zDNNType, shape);
// transformed desc.
zdnn_status status =
generate_transformed_desc(&pre_tfrmd_desc, &tfrmd_desc);
assert(status == ZDNN_OK);
// Stick data using the software stickify.
zdnn_ztensor ztensor;
init_ztensor(&pre_tfrmd_desc, &tfrmd_desc, &ztensor);
status = allochelper_ztensor_alloc(&ztensor);
assert(status == ZDNN_OK);
status = stickify(&ztensor, attrData.data());
assert(status == ZDNN_OK);
std::vector<char>().swap(attrData);
int64_t sizeInBytes = ztensor.buffer_size;
char *rawData = (char *)malloc(sizeInBytes);
memcpy(rawData, ztensor.buffer, sizeInBytes);
ret = llvm::ArrayRef(rawData, sizeInBytes);
allochelper_ztensor_free(&ztensor);
} else {
DenseResourceElementsAttr denseResourceAttr =
mlir::cast<DenseResourceElementsAttr>(dataAttr);
ArrayRef<char> attrData =
denseResourceAttr.getRawHandle().getBlob()->getData();
int64_t sizeInBytes = affine::getIntOrFloatMemRefSizeInBytes(
zlowStickifiedConstantOp.getResult().getType())
.value();
char *rawData = (char *)malloc(sizeInBytes);
memcpy(rawData, attrData.data(), sizeInBytes);
ret = llvm::ArrayRef(rawData, sizeInBytes);
}
zlowStickifiedConstantOp.removeValueAttr();
zlowStickifiedConstantOp.removeStickifiedAttr();
} else if (auto zeroConst = zlowStickifiedConstantOp.getZeroconstAttr()) {
if (zeroConst) {
int64_t sizeInBytes = affine::getIntOrFloatMemRefSizeInBytes(
zlowStickifiedConstantOp.getResult().getType())
.value();
char *rawData = (char *)malloc(sizeInBytes);
memset(rawData, 0, sizeInBytes);
ret = llvm::ArrayRef(rawData, sizeInBytes);
}
zlowStickifiedConstantOp.removeZeroconstAttr();
}
return ret;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ ZHighStickifiedConstantOp emitZHighStickifiedConstant(PatternRewriter &rewriter,
ZHighStickifiedConstantOp stickifiedConstant =
rewriter.create<ZHighStickifiedConstantOp>(loc, outputType,
/*value=*/nullptr,
/*stickfied=*/BoolAttr(),
/*zeroconst=*/BoolAttr(),
/*stickfied=*/nullptr,
/*zeroconst=*/nullptr,
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// Use a dense resource attribute to store stickified data.
Expand All @@ -85,6 +85,7 @@ ZHighStickifiedConstantOp emitZHighStickifiedConstant(PatternRewriter &rewriter,
llvm::ArrayRef((char *)ztensor->buffer, sizeInBytes), alignof(char)));

stickifiedConstant.setValueAttr(valueAttr);
stickifiedConstant.setStickifiedAttr(rewriter.getBoolAttr(true));

return stickifiedConstant;
}
Expand All @@ -100,8 +101,8 @@ ZHighStickifiedConstantOp createConstantForStick(PatternRewriter &rewriter,
ZHighStickifiedConstantOp constantOp =
rewriter.create<ZHighStickifiedConstantOp>(loc, replacingValue.getType(),
/*value=*/dataAttr,
/*stickified=*/BoolAttr(),
/*zeroconst=*/BoolAttr(),
/*stickified=*/rewriter.getBoolAttr(false),
/*zeroconst=*/nullptr,
/*alignment=*/rewriter.getI64IntegerAttr(4096));
return constantOp;
}
Expand Down

0 comments on commit 9e236c9

Please sign in to comment.