Skip to content

Commit

Permalink
Fix the case without setting --store-constants-to-file option.
Browse files Browse the repository at this point in the history
Signed-off-by: Haruki Imai <imaihal@jp.ibm.com>
  • Loading branch information
imaihal committed Aug 29, 2024
1 parent 4cb46dd commit 434272a
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 35 deletions.
4 changes: 3 additions & 1 deletion src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ Value insertAllocOrEmitZeroConstant(ArrayRef<IndexExpr> dims,
ZHighStickifiedConstantOp stickifiedConstant =
rewriter.create<ZHighStickifiedConstantOp>(loc, resType,
/*value=*/nullptr,
/*layout*/ StringAttr(),
/*layout=*/StringAttr(),
/*init=*/rewriter.getBoolAttr(true),
/*alignment=*/rewriter.getI64IntegerAttr(4096));
res = stickifiedConstant.getResult();
} else {
Expand Down Expand Up @@ -710,6 +711,7 @@ struct ZHighToZLowStickifiedConstantOpLowering : public ConversionPattern {
rewriter.getStringAttr(
"constant_stickify_" + std::to_string(constantID)),
/*layout=*/zhighStickifiedConstOp.getLayoutAttr(),
/*init=*/zhighStickifiedConstOp.getInitAttr(),
/*offset=*/rewriter.getI64IntegerAttr(0),
/*alignment=*/zhighStickifiedConstOp.getAlignmentAttr());

Expand Down
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,7 @@ def ZHighStickifiedConstantOp:ZHigh_Op<"StickifiedConstant", [Pure]> {
}];
let arguments = (ins OptionalAttr<AnyAttr>:$value,
OptionalAttr<StrAttr>:$layout,
OptionalAttr<BoolAttr>:$init,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs AnyZTensor:$output);
}
Expand Down
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def ZLowStickifiedConstantOp:ZLow_Op<"stickifiedConstant", [MemRefsNormalizable,
OptionalAttr<AnyAttr>:$value,
StrAttr:$name,
OptionalAttr<StrAttr>:$layout,
OptionalAttr<BoolAttr>:$init,
OptionalAttr<I64Attr>:$offset,
DefaultValuedAttr<I64Attr, "4096">:$alignment);
let results = (outs ZMemRef:$output);
Expand Down
2 changes: 1 addition & 1 deletion src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ ArrayRef<char> ZLowStickifiedConstantOp::getBuffer() {
.Default([&](Attribute attr) {
llvm_unreachable("Unsupported data type.");
});
} else {
} else if (zlowStickifiedConstantOp.getInitAttr()) {
int64_t sizeInBytes = affine::getIntOrFloatMemRefSizeInBytes(
zlowStickifiedConstantOp.getResult().getType())
.value();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ ZHighStickifiedConstantOp emitZHighStickifiedConstant(PatternRewriter &rewriter,
ZHighStickifiedConstantOp stickifiedConstant =
rewriter.create<ZHighStickifiedConstantOp>(loc, outputType,
/*value=*/nullptr,
/*layout*/ StringAttr(),
/*layout=*/StringAttr(),
/*init=*/BoolAttr(),
/*alignment=*/rewriter.getI64IntegerAttr(4096));

// Use an dense resource attribute to store stickified data.
Expand Down Expand Up @@ -100,6 +101,7 @@ ZHighStickifiedConstantOp createConstantForStick(PatternRewriter &rewriter,
rewriter.create<ZHighStickifiedConstantOp>(loc, replacingValue.getType(),
/*value=*/dataAttr,
/*layout=*/layout,
/*init=*/BoolAttr(),
/*alignment=*/rewriter.getI64IntegerAttr(4096));
return constantOp;
}
Expand Down
5 changes: 5 additions & 0 deletions src/Conversion/KrnlToLLVM/ConstantOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ class ConstantOpInterfaceLowering
LLVM::GlobalOp global;
// Pointer to the raw data of the global.
Value dataPtr;
ArrayRef<char> rawData = op.getBuffer();
if (!rawData.empty()) {
op.setBuffer(rawData);
op.freeBuffer(rawData);
}

if (op.getValue().has_value()) {
auto value = op.getValue().value();
Expand Down
20 changes: 4 additions & 16 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,34 +488,22 @@ bool extractConstantsToFile(ModuleOp &module, std::string filepath,

// Get raw data from DenseElementsAttr or DenseResourceElementsAttr.
uint64_t bufferSize = op.getBufferSize();
if (bufferSize <= singleThreshold) {
ArrayRef<char> rawData = op.getBuffer();
op.setBuffer(rawData);
if (bufferSize <= singleThreshold)
return WalkResult::advance();
}

if (op.getValueAttr()) {
auto valueAttr = mlir::cast<ElementsAttr>(op.getValue().value());
if (valueAttr.isSplat()) {
ArrayRef<char> rawData = op.getBuffer();
op.setBuffer(rawData);
if (valueAttr.isSplat())
return WalkResult::advance();
}
}
globalOfInterest.emplace_back(op);
totalSize += bufferSize;
return WalkResult::advance();
});
// Do not use file if the total size of satisfied constants is <=
// totalThreshold.
if (totalSize <= totalThreshold) {
// Set buffer before return.
for (int64_t i = globalOfInterest.size() - 1; i >= 0; --i) {
ConstantOpInterface op = globalOfInterest[i];
ArrayRef<char> rawData = op.getBuffer();
op.setBuffer(rawData);
}
if (totalSize <= totalThreshold)
return false;
}

// Sort constants in the non-descending order of alignment values.
// Non-alignment is the smallest value (-1), the others are positive.
Expand Down
36 changes: 20 additions & 16 deletions src/Dialect/Krnl/KrnlOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -805,22 +805,26 @@ MutableOperandRange KrnlSpecializedKernel::getLoopRefs() {

ArrayRef<char> getRawData(KrnlGlobalOp &op) {
ArrayRef<char> rawData;
assert(op.getValue().has_value() && "Krnl Global must always have a value");
auto value = op.getValue().value();
TypeSwitch<Attribute>(value)
.Case<DenseResourceElementsAttr>([&](DenseResourceElementsAttr attr) {
auto blob = mlir::cast<DenseResourceElementsAttr>(value)
.getRawHandle()
.getBlob();
assert(blob && "Expecting dense resource with a valid blob");
rawData = blob->getData();
})
.Case<DenseElementsAttr>([&](DenseElementsAttr attr) {
DenseElementsAttr denseAttr =
mlir::dyn_cast_or_null<DenseElementsAttr>(value);
rawData = denseAttr.getRawData();
})
.Default([&](Attribute attr) { return; });
// llvm::dbgs() << "getRawData op:" << op << "\n";
// assert(op.getValue().has_value() && "Krnl Global must always have a
// value");
if (op.getValueAttr()) {
auto value = op.getValue().value();
TypeSwitch<Attribute>(value)
.Case<DenseResourceElementsAttr>([&](DenseResourceElementsAttr attr) {
auto blob = mlir::cast<DenseResourceElementsAttr>(value)
.getRawHandle()
.getBlob();
assert(blob && "Expecting dense resource with a valid blob");
rawData = blob->getData();
})
.Case<DenseElementsAttr>([&](DenseElementsAttr attr) {
DenseElementsAttr denseAttr =
mlir::dyn_cast_or_null<DenseElementsAttr>(value);
rawData = denseAttr.getRawData();
})
.Default([&](Attribute attr) { return; });
}
return rawData;
}

Expand Down

0 comments on commit 434272a

Please sign in to comment.