Skip to content

Commit

Permalink
added helper function
Browse files Browse the repository at this point in the history
  • Loading branch information
gitoleg committed Dec 10, 2024
1 parent 6bf3852 commit 892b92e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 8 deletions.
23 changes: 15 additions & 8 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,18 @@ static cir::CIRCallOpInterface emitCallLikeOp(
extraFnAttrs);
}

static RValue getRValueThroughMemory(mlir::Location loc,
CIRGenBuilderTy &builder,
mlir::Value val,
Address addr) {
auto ip = builder.saveInsertionPoint();
builder.setInsertionPointAfterValue(val);
builder.createStore(loc, val, addr);
builder.restoreInsertionPoint(ip);
auto load = builder.createLoad(loc, addr);
return RValue::get(load);
}

RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo,
const CIRGenCallee &Callee,
ReturnValueSlot ReturnValue,
Expand Down Expand Up @@ -890,19 +902,14 @@ RValue CIRGenFunction::emitCall(const CIRGenFunctionInfo &CallInfo,
assert(Results.size() <= 1 && "multiple returns NYI");
assert(Results[0].getType() == RetCIRTy && "Bitcast support NYI");

auto reg = builder.getBlock()->getParent();
if (reg != theCall->getParentRegion()) {
auto region = builder.getBlock()->getParent();
if (region != theCall->getParentRegion()) {
Address DestPtr = ReturnValue.getValue();

if (!DestPtr.isValid())
DestPtr = CreateMemTemp(RetTy, callLoc, "tmp");

auto ip = builder.saveInsertionPoint();
builder.setInsertionPointAfter(theCall);
builder.createStore(callLoc, Results[0], DestPtr);
builder.restoreInsertionPoint(ip);
auto load = builder.createLoad(callLoc, DestPtr);
return RValue::get(load);
return getRValueThroughMemory(callLoc, builder, Results[0], DestPtr);
}

return RValue::get(Results[0]);
Expand Down
44 changes: 44 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"

#include <iostream>

using namespace mlir;
using namespace cir;

Expand Down Expand Up @@ -910,6 +912,42 @@ void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
patterns.getContext());
}

void removeTempAllocas(DominanceInfo& dom, FuncOp fun) {

fun.walk([&](AllocaOp op) {
if (op.getName().str().find("tmp") == std::string::npos)
return;

StoreOp store;
LoadOp load;
int total = 0;

for (auto* u : op->getUsers()) {
total++;
if (auto ld = dyn_cast<LoadOp>(u))
load = ld;
if (auto st = dyn_cast<StoreOp>(u))
if (st.getAddr() == op.getResult())
store = st;
}

if (total == 2 && load && store && dom.dominates(store, load)) {
if (load->hasOneUse()) {
if (auto st = dyn_cast<StoreOp>(*load->user_begin())) {
if (auto al = dyn_cast<AllocaOp>(st.getAddr().getDefiningOp())) {
llvm::SmallVector<mlir::Value> vals;
vals.push_back(al.getResult());
op->replaceAllUsesWith(vals);
op->erase();
}
}
}
}

});

}

void FlattenCFGPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateFlattenCFGPatterns(patterns);
Expand All @@ -924,6 +962,12 @@ void FlattenCFGPass::runOnOperation() {
// Apply patterns.
if (applyOpPatternsAndFold(ops, std::move(patterns)).failed())
signalPassFailure();

auto &dom = getAnalysis<DominanceInfo>();

getOperation()->walk([&](FuncOp fun) {
removeTempAllocas(dom, fun);
});
}

} // namespace
Expand Down

0 comments on commit 892b92e

Please sign in to comment.