Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug where emitc constants wouldn't be directly emitted in subscripts. #411

Merged
merged 10 commits into from
Dec 13, 2024
Merged
74 changes: 42 additions & 32 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@ struct CppEmitter {
return operandExpression == emittedExpression;
};

/// Determine whether expression \p expressionOp should be emitted inline,
/// i.e. as part of its user. This function recommends inlining of any
/// expressions that can be inlined unless it is used by another expression,
/// under the assumption that any expression fusion/re-materialization was
/// taken care of by transformations run by the backend.
bool shouldBeInlined(ExpressionOp expressionOp);

/// This emitter will only emit translation units whos id matches this value.
StringRef willOnlyEmitTu() { return onlyTu; }

private:
using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
Expand Down Expand Up @@ -297,21 +307,22 @@ struct CppEmitter {
return lowestPrecedence();
return emittedExpressionPrecedence.back();
}

/// Determine whether expression \p op should be emitted in a deferred way.
bool hasDeferredEmission(Operation *op);
};
} // namespace

/// Determine whether expression \p op should be emitted in a deferred way.
static bool hasDeferredEmission(Operation *op) {
bool CppEmitter::hasDeferredEmission(Operation *op) {
if (llvm::isa_and_nonnull<emitc::ConstantOp>(op)) {
return !shouldUseConstantsAsVariables();
}

return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
}

/// Determine whether expression \p expressionOp should be emitted inline, i.e.
/// as part of its user. This function recommends inlining of any expressions
/// that can be inlined unless it is used by another expression, under the
/// assumption that any expression fusion/re-materialization was taken care of
/// by transformations run by the backend.
static bool shouldBeInlined(ExpressionOp expressionOp) {
bool CppEmitter::shouldBeInlined(ExpressionOp expressionOp) {
// Do not inline if expression is marked as such.
if (expressionOp.getDoNotInline())
return false;
Expand Down Expand Up @@ -373,6 +384,25 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
static LogicalResult printOperation(CppEmitter &emitter,
emitc::ConstantOp constantOp) {
if (!emitter.shouldUseConstantsAsVariables()) {
std::string out;
llvm::raw_string_ostream ss(out);

/// Temporary emitter object that writes to our stream instead of the output
/// allowing for the capture and caching of the produced string.
CppEmitter sniffer = CppEmitter(ss, emitter.shouldDeclareVariablesAtTop(),
emitter.willOnlyEmitTu(),
emitter.shouldUseConstantsAsVariables());

ss << "(";
if (failed(sniffer.emitType(constantOp.getLoc(), constantOp.getType())))
return failure();
ss << ") ";

if (failed(
sniffer.emitAttribute(constantOp.getLoc(), constantOp.getValue())))
return failure();

emitter.cacheDeferredOpResult(constantOp.getResult(), out);
return success();
}

Expand Down Expand Up @@ -838,7 +868,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {

static LogicalResult printOperation(CppEmitter &emitter,
emitc::ExpressionOp expressionOp) {
if (shouldBeInlined(expressionOp))
if (emitter.shouldBeInlined(expressionOp))
return success();

Operation &op = *expressionOp.getOperation();
Expand Down Expand Up @@ -892,7 +922,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
if (!expressionOp)
return false;
return shouldBeInlined(expressionOp);
return emitter.shouldBeInlined(expressionOp);
};

os << "for (";
Expand Down Expand Up @@ -1114,7 +1144,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
(isa<emitc::ExpressionOp>(op) &&
shouldBeInlined(cast<emitc::ExpressionOp>(op))))
emitter.shouldBeInlined(cast<emitc::ExpressionOp>(op))))
return WalkResult::skip();
for (OpResult result : op->getResults()) {
if (failed(emitter.emitVariableDeclaration(
Expand Down Expand Up @@ -1494,22 +1524,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {

LogicalResult CppEmitter::emitOperand(Value value) {
Operation *def = value.getDefiningOp();
if (!shouldUseConstantsAsVariables()) {
if (auto constant = dyn_cast_if_present<ConstantOp>(def)) {
os << "((";

if (failed(emitType(constant.getLoc(), constant.getType()))) {
return failure();
}
os << ") ";

if (failed(emitAttribute(constant.getLoc(), constant.getValue()))) {
return failure();
}
os << ")";
return success();
}
}

if (isPartOfCurrentExpression(value)) {
assert(def && "Expected operand to be defined by an operation");
Expand Down Expand Up @@ -1721,11 +1735,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
cacheDeferredOpResult(op.getResult(), op.getValue());
return success();
})
.Case<emitc::MemberOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
return success();
})
.Case<emitc::MemberOfPtrOp>([&](auto op) {
.Case<emitc::MemberOp, emitc::MemberOfPtrOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), createMemberAccess(op));
return success();
})
Expand Down
50 changes: 48 additions & 2 deletions mlir/test/Target/Cpp/emitc-constants-as-variables.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,55 @@ func.func @test() {

return
}
// CPP-DEFAULT-LABEL: void test() {
// CPP-DEFAULT-NEXT: for (size_t v1 = (size_t) 0; v1 < (size_t) 10; v1 += (size_t) 1) {
// CPP-DEFAULT-NEXT: }
// CPP-DEFAULT-NEXT: return;
// CPP-DEFAULT-NEXT: }

// -----

func.func @test_subscript(%arg0: !emitc.array<4xf32>) -> (!emitc.lvalue<f32>) {
%c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
%0 = emitc.subscript %arg0[%c0] : (!emitc.array<4xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
return %0 : !emitc.lvalue<f32>
}
// CPP-DEFAULT-LABEL: float test_subscript(float v1[4]) {
// CPP-DEFAULT-NEXT: return v1[(size_t) 0];
// CPP-DEFAULT-NEXT: }

// -----

// CPP-DEFAULT: void test() {
// CPP-DEFAULT-NEXT: for (size_t v1 = ((size_t) 0); v1 < ((size_t) 10); v1 += ((size_t) 1)) {
func.func @emitc_switch_ui64() {
%0 = "emitc.constant"(){value = 1 : ui64} : () -> ui64

emitc.switch %0 : ui64
default {
emitc.call_opaque "func2" (%0) : (ui64) -> ()
emitc.yield
}
return
}
// CPP-DEFAULT-LABEL: void emitc_switch_ui64() {
// CPP-DEFAULT: switch ((uint64_t) 1) {
// CPP-DEFAULT-NEXT: default: {
// CPP-DEFAULT-NEXT: func2((uint64_t) 1);
// CPP-DEFAULT-NEXT: break;
// CPP-DEFAULT-NEXT: }
// CPP-DEFAULT-NEXT: return;
// CPP-DEFAULT-NEXT: }

// -----

func.func @negative_values() {
%1 = "emitc.constant"() <{value = 10 : index}> : () -> !emitc.size_t
%2 = "emitc.constant"() <{value = -3000000000 : index}> : () -> !emitc.ssize_t

%3 = emitc.add %1, %2 : (!emitc.size_t, !emitc.ssize_t) -> !emitc.ssize_t

return
}
// CPP-DEFAULT-LABEL: void negative_values() {
// CPP-DEFAULT-NEXT: ssize_t v1 = (size_t) 10 + (ssize_t) -3000000000;
// CPP-DEFAULT-NEXT: return;
// CPP-DEFAULT-NEXT: }
Loading