Skip to content

Commit

Permalink
EmitC: Add emitc.global and emitc.get_global (#145)
Browse files Browse the repository at this point in the history
* EmitC: Add emitc.global and emitc.get_global

This adds
- `emitc.global` and `emitc.get_global` ops to model global variables
similar to how `memref.global` and `memref.get_global` work.
- translation of those ops to C++
- lowering of `memref.global` and `memref.get_global` into those ops
  • Loading branch information
mgehre-amd authored Mar 25, 2024
1 parent 13e8abe commit c98ce04
Show file tree
Hide file tree
Showing 9 changed files with 378 additions and 15 deletions.
69 changes: 69 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,75 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
let hasVerifier = 1;
}

def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
let summary = "A global variable";
let description = [{
The `emitc.global` operation declares or defines a named global variable.
The backing memory for the variable is allocated statically and is
described by the type of the variable.
Optionally, and `initial_value` can be provided.
Internal linkage can be specified using the `staticSpecifier` unit attribute
and external linkage can be specified using the `externSpecifier` unit attribute.
Note that the default linkage without those two keywords depends on whether
the target is C or C++ and whether the global variable is `const`.
The global variable can also be marked constant using the `constSpecifier`
unit attribute. Writing to such constant global variables is
undefined.

The global variable can be accessed by using the `emitc.get_global` to
retrieve the value for the global variable.

Example:

```mlir
// Global variable with an initial value.
emitc.global @x : emitc.array<2xf32> = dense<0.0, 2.0>
// External global variable
emitc.global extern @x : emitc.array<2xf32>
// Constant global variable with internal linkage
emitc.global static const @x : i32 = 0
```
}];

let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttr:$type,
OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
UnitAttr:$externSpecifier,
UnitAttr:$staticSpecifier,
UnitAttr:$constSpecifier);

let assemblyFormat = [{
(`extern` $externSpecifier^)?
(`static` $staticSpecifier^)?
(`const` $constSpecifier^)?
$sym_name
`:` custom<EmitCGlobalOpTypeAndInitialValue>($type, $initial_value)
attr-dict
}];

let hasVerifier = 1;
}

def EmitC_GetGlobalOp : EmitC_Op<"get_global",
[Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Obtain access to a global variable";
let description = [{
The `emitc.get_global` operation retrieves the lvalue of a
named global variable. If the global variable is marked constant, assigning
to that lvalue is undefined.

Example:

```mlir
%x = emitc.get_global @foo : !emitc.array<2xf32>
```
}];

let arguments = (ins FlatSymbolRefAttr:$name);
let results = (outs AnyType:$result);
let assemblyFormat = "$name `:` type($result) attr-dict";
}

def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
let summary = "Verbatim operation";
let description = [{
Expand Down
66 changes: 64 additions & 2 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,68 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};

struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform global with dynamic shape");
}

if (op.getAlignment().value_or(1) > 1) {
// TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
return rewriter.notifyMatchFailure(
op.getLoc(), "global variable with alignment requirement is "
"currently not supported");
}
auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}

SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
if (visibility != SymbolTable::Visibility::Public &&
visibility != SymbolTable::Visibility::Private) {
return rewriter.notifyMatchFailure(
op.getLoc(),
"only public and private visibility is currently supported");
}
// We are explicit in specifier the linkage because the default linkage
// for constants is different in C and C++.
bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
bool externSpecifier = !staticSpecifier;

rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
op, operands.getSymName(), resultTy, operands.getInitialValueAttr(),
externSpecifier, staticSpecifier, operands.getConstant());
return success();
}
};

struct ConvertGetGlobal final
: public OpConversionPattern<memref::GetGlobalOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cannot convert result type");
}
rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
operands.getNameAttr());
return success();
}
};

struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;

Expand Down Expand Up @@ -109,6 +171,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {

void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
patterns.getContext());
patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
ConvertStore>(converter, patterns.getContext());
}
117 changes: 110 additions & 7 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,13 +790,6 @@ LogicalResult emitc::SubscriptOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"

//===----------------------------------------------------------------------===//
// EmitC Enums
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -896,3 +889,113 @@ LogicalResult mlir::emitc::OpaqueType::verify(
}
return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
TypeAttr type,
Attribute initialValue) {
p << type;
if (initialValue) {
p << " = ";
p.printAttributeWithoutType(initialValue);
}
}

static Type getInitializerTypeForGlobal(Type type) {
if (auto array = llvm::dyn_cast<ArrayType>(type))
return RankedTensorType::get(array.getShape(), array.getElementType());
return type;
}

static ParseResult
parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
Attribute &initialValue) {
Type type;
if (parser.parseType(type))
return failure();

typeAttr = TypeAttr::get(type);

if (parser.parseOptionalEqual())
return success();

if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
return failure();

if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr>(initialValue))
return parser.emitError(parser.getNameLoc())
<< "initial value should be a unit, integer, float or elements "
"attribute";
return success();
}

LogicalResult GlobalOp::verify() {
// Verify that the initial value, if present, is either a unit attribute or
// an elements attribute.
if (getInitialValue().has_value()) {
Attribute initValue = getInitialValue().value();
// Check that the type of the initial value is compatible with the type of
// the global variable.
if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
auto arrayType = llvm::dyn_cast<ArrayType>(getType());
if (!arrayType)
return emitOpError("expected array type, but got ") << getType();

Type initType = elementsAttr.getType();
Type tensorType = getInitializerTypeForGlobal(getType());
if (initType != tensorType) {
return emitOpError("initial value expected to be of type ")
<< getType() << ", but was of type " << initType;
}
} else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
if (intAttr.getType() != getType()) {
return emitOpError("initial value expected to be of type ")
<< getType() << ", but was of type " << intAttr.getType();
}
} else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
if (floatAttr.getType() != getType()) {
return emitOpError("initial value expected to be of type ")
<< getType() << ", but was of type " << floatAttr.getType();
}
} else {
return emitOpError(
"initial value should be a unit, integer, float or elements "
"attribute, but got ")
<< initValue;
}
}
if (getStaticSpecifier() && getExternSpecifier()) {
return emitOpError("cannot have both static and extern specifiers");
}
return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Verify that the type matches the type of the global variable.
auto global =
symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
if (!global)
return emitOpError("'")
<< getName() << "' does not reference a valid emitc.global";

Type resultType = getResult().getType();
if (global.getType() != resultType)
return emitOpError("result type ")
<< resultType << " does not match type " << global.getType()
<< " of the global @" << getName();
return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
55 changes: 49 additions & 6 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ struct CppEmitter {
/// any result type could not be converted.
LogicalResult emitAssignPrefix(Operation &op);

/// Emits a global variable declaration or definition.
LogicalResult emitGlobalVariable(GlobalOp op);

/// Emits a label for the block.
LogicalResult emitLabel(Block &block);

Expand Down Expand Up @@ -344,6 +347,12 @@ static LogicalResult printOperation(CppEmitter &emitter,
return printConstantOp(emitter, operation, value);
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::GlobalOp globalOp) {

return emitter.emitGlobalVariable(globalOp);
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::AssignOp assignOp) {
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
Expand All @@ -354,6 +363,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
return emitter.emitOperand(assignOp.getValue());
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::GetGlobalOp op) {
// Add name to cache so that `hasValueInScope` works.
emitter.getOrCreateName(op.getResult());
return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::SubscriptOp subscriptOp) {
// Add name to cache so that `hasValueInScope` works.
Expand Down Expand Up @@ -1120,6 +1136,9 @@ StringRef CppEmitter::getOrCreateName(Value val) {
if (auto subscript =
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
valueMapper.insert(val, getSubscriptName(subscript));
} else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
val.getDefiningOp())) {
valueMapper.insert(val, getGlobal.getName().str());
} else {
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
}
Expand Down Expand Up @@ -1385,6 +1404,30 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
return success();
}

LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
if (op.getExternSpecifier())
os << "extern ";
else if (op.getStaticSpecifier())
os << "static ";
if (op.getConstSpecifier())
os << "const ";

if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
op.getSymName()))) {
return failure();
}

std::optional<Attribute> initialValue = op.getInitialValue();
if (initialValue && !isa<UnitAttr>(*initialValue)) {
os << " = ";
if (failed(emitAttribute(op->getLoc(), *initialValue)))
return failure();
}

os << ";";
return success();
}

LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
// If op is being emitted as part of an expression, bail out.
if (getEmittedExpression())
Expand Down Expand Up @@ -1445,11 +1488,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
emitc::VerbatimOp>(
emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
Expand All @@ -1462,7 +1505,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (failed(status))
return failure();

if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
return success();

if (getEmittedExpression() ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,8 @@ func.func @zero_rank() {
%0 = memref.alloca() : memref<f32>
return
}

// -----

// expected-error@+1 {{failed to legalize operation 'memref.global'}}
memref.global "nested" constant @nested_global : memref<3x7xf32>
Loading

0 comments on commit c98ce04

Please sign in to comment.