From a1eb154421a00d62f3a25057d262e1cac747e266 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 14 Dec 2020 08:32:31 +0100 Subject: [PATCH 01/12] [flang] Use mlir::OpState::operator->() to get to methods of mlir::Operation. This is a preparation step to remove those methods from OpState. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D93194 --- .../include/flang/Optimizer/Dialect/FIROps.td | 83 ++++++++++--------- .../flang/Optimizer/Dialect/FIROpsSupport.h | 2 +- flang/lib/Lower/CharacterRuntime.cpp | 2 +- flang/lib/Lower/IO.cpp | 4 +- flang/lib/Lower/IntrinsicCall.cpp | 4 +- flang/lib/Lower/OpenACC.cpp | 40 ++++----- flang/lib/Optimizer/Dialect/FIROps.cpp | 80 ++++++++++-------- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp | 2 +- mlir/test/mlir-tblgen/op-attribute.td | 6 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 5 +- mlir/unittests/Pass/PassManagerTest.cpp | 14 ++-- 11 files changed, 126 insertions(+), 116 deletions(-) diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 4ffca03958042d..8d7a6d4af95076 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -249,7 +249,7 @@ class fir_AllocatableOp traits = []> : }]; let printer = [{ - p << getOperationName() << ' ' << getAttr(inType()); + p << getOperationName() << ' ' << (*this)->getAttr(inType()); if (hasLenParams()) { // print the LEN parameters to a derived type in parens p << '(' << getLenParams() << " : " << getLenParams().getTypes() << ')'; @@ -267,7 +267,7 @@ class fir_AllocatableOp traits = []> : static constexpr llvm::StringRef lenpName() { return "len_param_count"; } mlir::Type getAllocatedType(); - bool hasLenParams() { return bool{getAttr(lenpName())}; } + bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; } unsigned numLenParams() { if (auto val = (*this)->getAttrOfType(lenpName())) @@ -688,7 +688,7 @@ class fir_IntegralSwitchTerminatorOp(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumConditions(); for (decltype(count) i = 0; i != count; ++i) { if (i) @@ -711,7 +711,7 @@ class fir_IntegralSwitchTerminatorOp() || getSelector().getType().isa())) return emitOpError("must be an integer"); - auto cases = getAttrOfType(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumDest(); if (count == 0) return emitOpError("must have at least one successor"); @@ -810,7 +810,7 @@ def fir_SelectCaseOp : fir_SwitchTerminatorOp<"select_case"> { p << getOperationName() << ' '; p.printOperand(getSelector()); p << " : " << getSelector().getType() << " ["; - auto cases = getAttrOfType(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumConditions(); for (decltype(count) i = 0; i != count; ++i) { if (i) @@ -839,7 +839,7 @@ def fir_SelectCaseOp : fir_SwitchTerminatorOp<"select_case"> { getSelector().getType().isa() || getSelector().getType().isa())) return emitOpError("must be an integer, character, or logical"); - auto cases = getAttrOfType(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumDest(); if (count == 0) return emitOpError("must have at least one successor"); @@ -925,7 +925,7 @@ def fir_SelectTypeOp : fir_SwitchTerminatorOp<"select_type"> { p << getOperationName() << ' '; p.printOperand(getSelector()); p << " : " << getSelector().getType() << " ["; - auto cases = getAttrOfType(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumConditions(); for (decltype(count) i = 0; i != count; ++i) { if (i) @@ -941,7 +941,7 @@ def fir_SelectTypeOp : fir_SwitchTerminatorOp<"select_type"> { let verifier = [{ if (!(getSelector().getType().isa())) return emitOpError("must be a boxed type"); - auto cases = getAttrOfType(getCasesAttr()).getValue(); + auto cases = (*this)->getAttrOfType(getCasesAttr()).getValue(); auto count = getNumDest(); if (count == 0) return emitOpError("must have at least one successor"); @@ -1056,7 +1056,7 @@ def fir_EmboxOp : fir_Op<"embox", [NoSideEffect]> { if (getNumOperands() == 2) { p << ", "; p.printOperands(dims()); - } else if (auto map = getAttr(layoutName())) { + } else if (auto map = (*this)->getAttr(layoutName())) { p << " [" << map << ']'; } p.printOptionalAttrDict(getAttrs(), {layoutName(), lenpName()}); @@ -1097,9 +1097,9 @@ def fir_EmboxOp : fir_Op<"embox", [NoSideEffect]> { let extraClassDeclaration = [{ static constexpr llvm::StringRef layoutName() { return "layout_map"; } static constexpr llvm::StringRef lenpName() { return "len_param_count"; } - bool hasLenParams() { return bool{getAttr(lenpName())}; } + bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; } unsigned numLenParams() { - if (auto x = getAttrOfType(lenpName())) + if (auto x = (*this)->getAttrOfType(lenpName())) return x.getInt(); return 0; } @@ -1213,13 +1213,13 @@ def fir_EmboxProcOp : fir_Op<"emboxproc", [NoSideEffect]> { }]; let printer = [{ - p << getOperationName() << ' ' << getAttr("funcname"); + p << getOperationName() << ' ' << (*this)->getAttr("funcname"); auto h = host(); if (h) { p << ", "; p.printOperand(h); } - p << " : (" << getAttr("functype"); + p << " : (" << (*this)->getAttr("functype"); if (h) p << ", " << h.getType(); p << ") -> " << getType(); @@ -1587,7 +1587,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> { if (!ref().getType().dyn_cast()) return emitOpError("len_param_index must be used on box type"); } - if (auto attr = getAttr(CoordinateOp::baseType())) { + if (auto attr = (*this)->getAttr(CoordinateOp::baseType())) { if (!attr.isa()) return emitOpError("improperly constructed"); } else { @@ -1690,8 +1690,8 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> { let printer = [{ p << getOperationName() << ' ' - << getAttrOfType(fieldAttrName()).getValue() << ", " - << getAttr(typeAttrName()); + << (*this)->getAttrOfType(fieldAttrName()).getValue() + << ", " << (*this)->getAttr(typeAttrName()); if (getNumOperands()) { p << '('; p.printOperands(lenparams()); @@ -1826,8 +1826,8 @@ def fir_LenParamIndexOp : fir_OneResultOp<"len_param_index", [NoSideEffect]> { let printer = [{ p << getOperationName() << ' ' - << getAttrOfType(fieldAttrName()).getValue() << ", " - << getAttr(typeAttrName()); + << (*this)->getAttrOfType(fieldAttrName()).getValue() + << ", " << (*this)->getAttr(typeAttrName()); }]; let builders = [ @@ -1841,7 +1841,7 @@ def fir_LenParamIndexOp : fir_OneResultOp<"len_param_index", [NoSideEffect]> { static constexpr llvm::StringRef fieldAttrName() { return "field_id"; } static constexpr llvm::StringRef typeAttrName() { return "on_type"; } mlir::Type getOnType() { - return getAttrOfType(typeAttrName()).getValue(); + return (*this)->getAttrOfType(typeAttrName()).getValue(); } }]; } @@ -2166,7 +2166,7 @@ def fir_DispatchOp : fir_Op<"dispatch", }]; let printer = [{ - p << getOperationName() << ' ' << getAttr("method") << '('; + p << getOperationName() << ' ' << (*this)->getAttr("method") << '('; p.printOperand(object()); if (arg_operand_begin() != arg_operand_end()) { p << ", "; @@ -2250,7 +2250,7 @@ def fir_StringLitOp : fir_Op<"string_lit", [NoSideEffect]> { auto eleTy = getType().cast().getEleTy(); if (!eleTy.isa()) return emitOpError("must have !fir.char type"); - if (auto xl = getAttr(xlist())) { + if (auto xl = (*this)->getAttr(xlist())) { auto xList = xl.cast(); for (auto a : xList) if (!a.isa()) @@ -2265,12 +2265,12 @@ def fir_StringLitOp : fir_Op<"string_lit", [NoSideEffect]> { static constexpr const char *xlist() { return "xlist"; } // Get the LEN attribute of this character constant - mlir::Attribute getSize() { return getAttr(size()); } + mlir::Attribute getSize() { return (*this)->getAttr(size()); } // Get the string value of this character constant mlir::Attribute getValue() { - if (auto attr = getAttr(value())) + if (auto attr = (*this)->getAttr(value())) return attr; - return getAttr(xlist()); + return (*this)->getAttr(xlist()); } /// Is this a wide character literal (1 character > 8 bits) @@ -2381,7 +2381,7 @@ def fir_CmpfOp : fir_Op<"cmpf", static CmpFPredicate getPredicateByName(llvm::StringRef name); CmpFPredicate getPredicate() { - return (CmpFPredicate)getAttrOfType( + return (CmpFPredicate)(*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; @@ -2415,11 +2415,11 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> { let printer = [{ p << getOperationName() << " (0x"; - auto f1 = getAttr(realAttrName()).cast(); + auto f1 = (*this)->getAttr(realAttrName()).cast(); auto i1 = f1.getValue().bitcastToAPInt(); p.getStream().write_hex(i1.getZExtValue()); p << ", 0x"; - auto f2 = getAttr(imagAttrName()).cast(); + auto f2 = (*this)->getAttr(imagAttrName()).cast(); auto i2 = f2.getValue().bitcastToAPInt(); p.getStream().write_hex(i2.getZExtValue()); p << ") : "; @@ -2436,8 +2436,8 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> { static constexpr llvm::StringRef realAttrName() { return "real"; } static constexpr llvm::StringRef imagAttrName() { return "imaginary"; } - mlir::Attribute getReal() { return getAttr(realAttrName()); } - mlir::Attribute getImaginary() { return getAttr(imagAttrName()); } + mlir::Attribute getReal() { return (*this)->getAttr(realAttrName()); } + mlir::Attribute getImaginary() { return (*this)->getAttr(imagAttrName()); } }]; } @@ -2485,7 +2485,7 @@ def fir_CmpcOp : fir_Op<"cmpc", } CmpFPredicate getPredicate() { - return (CmpFPredicate)getAttrOfType( + return (CmpFPredicate)(*this)->getAttrOfType( getPredicateAttrName()).getInt(); } }]; @@ -2601,7 +2601,7 @@ def fir_GenTypeDescOp : fir_OneResultOp<"gentypedesc", [NoSideEffect]> { }]; let printer = [{ - p << getOperationName() << ' ' << getAttr("in_type"); + p << getOperationName() << ' ' << (*this)->getAttr("in_type"); p.printOptionalAttrDict(getAttrs(), {"in_type"}); }]; @@ -2623,7 +2623,7 @@ def fir_GenTypeDescOp : fir_OneResultOp<"gentypedesc", [NoSideEffect]> { let extraClassDeclaration = [{ mlir::Type getInType() { // get the type that the type descriptor describes - return getAttrOfType("in_type").getValue(); + return (*this)->getAttrOfType("in_type").getValue(); } }]; } @@ -2697,7 +2697,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { if (linkName().hasValue()) p << ' ' << linkName().getValue(); p << ' '; - p.printAttributeWithoutType(getAttr(symbolAttrName())); + p.printAttributeWithoutType((*this)->getAttr(symbolAttrName())); if (auto val = getValueOrNull()) p << '(' << val << ')'; if ((*this)->getAttr(constantAttrName())) @@ -2738,7 +2738,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { /// The printable type of the global mlir::Type getType() { - return getAttrOfType(typeAttrName()).getValue(); + return (*this)->getAttrOfType(typeAttrName()).getValue(); } /// The semantic type of the global @@ -2768,8 +2768,9 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { } mlir::FlatSymbolRefAttr getSymbol() { - return mlir::FlatSymbolRefAttr::get(getAttrOfType( - mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext()); + return mlir::FlatSymbolRefAttr::get( + (*this)->getAttrOfType( + mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext()); } }]; } @@ -2811,8 +2812,8 @@ def fir_GlobalLenOp : fir_Op<"global_len", []> { }]; let printer = [{ - p << getOperationName() << ' ' << getAttr(lenParamAttrName()) << ", " - << getAttr(intAttrName()); + p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName()) + << ", " << (*this)->getAttr(intAttrName()); }]; let extraClassDeclaration = [{ @@ -2865,7 +2866,7 @@ def fir_DispatchTableOp : fir_Op<"dispatch_table", }]; let printer = [{ - auto tableName = getAttrOfType( + auto tableName = (*this)->getAttrOfType( mlir::SymbolTable::getSymbolAttrName()).getValue(); p << getOperationName() << " @" << tableName; @@ -2946,8 +2947,8 @@ def fir_DTEntryOp : fir_Op<"dt_entry", []> { }]; let printer = [{ - p << getOperationName() << ' ' << getAttr(methodAttrName()) << ", " - << getAttr(procAttrName()); + p << getOperationName() << ' ' << (*this)->getAttr(methodAttrName()) << ", " + << (*this)->getAttr(procAttrName()); }]; let extraClassDeclaration = [{ diff --git a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h index 43588ff17962ed..2d9ad28981ab24 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h +++ b/flang/include/flang/Optimizer/Dialect/FIROpsSupport.h @@ -18,7 +18,7 @@ namespace fir { /// return true iff the Operation is a non-volatile LoadOp inline bool nonVolatileLoad(mlir::Operation *op) { if (auto load = dyn_cast(op)) - return !load.getAttr("volatile"); + return !load->getAttr("volatile"); return false; } diff --git a/flang/lib/Lower/CharacterRuntime.cpp b/flang/lib/Lower/CharacterRuntime.cpp index af95885f985d2a..4bfbf5824efbbe 100644 --- a/flang/lib/Lower/CharacterRuntime.cpp +++ b/flang/lib/Lower/CharacterRuntime.cpp @@ -62,7 +62,7 @@ static mlir::FuncOp getRuntimeFunc(mlir::Location loc, return func; auto funTy = getTypeModel()(builder.getContext()); func = builder.createFunction(loc, name, funTy); - func.setAttr("fir.runtime", builder.getUnitAttr()); + func->setAttr("fir.runtime", builder.getUnitAttr()); return func; } diff --git a/flang/lib/Lower/IO.cpp b/flang/lib/Lower/IO.cpp index 3f79b79e32ee9a..ab7387dd3ce6e1 100644 --- a/flang/lib/Lower/IO.cpp +++ b/flang/lib/Lower/IO.cpp @@ -123,8 +123,8 @@ static mlir::FuncOp getIORuntimeFunc(mlir::Location loc, return func; auto funTy = getTypeModel()(builder.getContext()); func = builder.createFunction(loc, name, funTy); - func.setAttr("fir.runtime", builder.getUnitAttr()); - func.setAttr("fir.io", builder.getUnitAttr()); + func->setAttr("fir.runtime", builder.getUnitAttr()); + func->setAttr("fir.io", builder.getUnitAttr()); return func; } diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp index 23b084eaf67d5f..0e0081ef664cfe 100644 --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -497,7 +497,7 @@ static mlir::FuncOp getFuncOp(mlir::Location loc, const RuntimeFunction &runtime) { auto function = builder.addNamedFunction( loc, runtime.symbol, runtime.typeGenerator(builder.getContext())); - function.setAttr("fir.runtime", builder.getUnitAttr()); + function->setAttr("fir.runtime", builder.getUnitAttr()); return function; } @@ -769,7 +769,7 @@ mlir::FuncOp IntrinsicLibrary::getWrapper(GeneratorType generator, if (!function) { // First time this wrapper is needed, build it. function = builder.createFunction(loc, wrapperName, funcType); - function.setAttr("fir.intrinsic", builder.getUnitAttr()); + function->setAttr("fir.intrinsic", builder.getUnitAttr()); function.addEntryBlock(); // Create local context to emit code into the newly created function diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp index 6f45bb623d7db9..12cf97869543aa 100644 --- a/flang/lib/Lower/OpenACC.cpp +++ b/flang/lib/Lower/OpenACC.cpp @@ -98,8 +98,8 @@ static Op createRegionOp(Fortran::lower::FirOpBuilder &builder, builder.setInsertionPointToStart(&block); builder.create(loc); - op.setAttr(Op::getOperandSegmentSizeAttr(), - builder.getI32VectorAttr(operandSegments)); + op->setAttr(Op::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr(operandSegments)); // Place the insertion point to the start of the first block. builder.setInsertionPointToStart(&block); @@ -114,8 +114,8 @@ static Op createSimpleOp(Fortran::lower::FirOpBuilder &builder, const SmallVectorImpl &operandSegments) { llvm::ArrayRef argTy; Op op = builder.create(loc, argTy, operands); - op.setAttr(Op::getOperandSegmentSizeAttr(), - builder.getI32VectorAttr(operandSegments)); + op->setAttr(Op::getOperandSegmentSizeAttr(), + builder.getI32VectorAttr(operandSegments)); return op; } @@ -231,8 +231,8 @@ static void genACC(Fortran::lower::AbstractConverter &converter, auto loopOp = createRegionOp( firOpBuilder, currentLocation, operands, operandSegments); - loopOp.setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(), - firOpBuilder.getI64IntegerAttr(executionMapping)); + loopOp->setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(), + firOpBuilder.getI64IntegerAttr(executionMapping)); // Lower clauses mapped to attributes for (const auto &clause : accClauseList.v) { @@ -241,19 +241,19 @@ static void genACC(Fortran::lower::AbstractConverter &converter, const auto *expr = Fortran::semantics::GetExpr(collapseClause->v); const auto collapseValue = Fortran::evaluate::ToInt64(*expr); if (collapseValue) { - loopOp.setAttr(mlir::acc::LoopOp::getCollapseAttrName(), - firOpBuilder.getI64IntegerAttr(*collapseValue)); + loopOp->setAttr(mlir::acc::LoopOp::getCollapseAttrName(), + firOpBuilder.getI64IntegerAttr(*collapseValue)); } } else if (std::get_if(&clause.u)) { - loopOp.setAttr(mlir::acc::LoopOp::getSeqAttrName(), - firOpBuilder.getUnitAttr()); + loopOp->setAttr(mlir::acc::LoopOp::getSeqAttrName(), + firOpBuilder.getUnitAttr()); } else if (std::get_if( &clause.u)) { - loopOp.setAttr(mlir::acc::LoopOp::getIndependentAttrName(), - firOpBuilder.getUnitAttr()); + loopOp->setAttr(mlir::acc::LoopOp::getIndependentAttrName(), + firOpBuilder.getUnitAttr()); } else if (std::get_if(&clause.u)) { - loopOp.setAttr(mlir::acc::LoopOp::getAutoAttrName(), - firOpBuilder.getUnitAttr()); + loopOp->setAttr(mlir::acc::LoopOp::getAutoAttrName(), + firOpBuilder.getUnitAttr()); } } } @@ -425,14 +425,14 @@ genACCParallelOp(Fortran::lower::AbstractConverter &converter, firOpBuilder, currentLocation, operands, operandSegments); if (addAsyncAttr) - parallelOp.setAttr(mlir::acc::ParallelOp::getAsyncAttrName(), - firOpBuilder.getUnitAttr()); + parallelOp->setAttr(mlir::acc::ParallelOp::getAsyncAttrName(), + firOpBuilder.getUnitAttr()); if (addWaitAttr) - parallelOp.setAttr(mlir::acc::ParallelOp::getWaitAttrName(), - firOpBuilder.getUnitAttr()); + parallelOp->setAttr(mlir::acc::ParallelOp::getWaitAttrName(), + firOpBuilder.getUnitAttr()); if (addSelfAttr) - parallelOp.setAttr(mlir::acc::ParallelOp::getSelfAttrName(), - firOpBuilder.getUnitAttr()); + parallelOp->setAttr(mlir::acc::ParallelOp::getSelfAttrName(), + firOpBuilder.getUnitAttr()); } static void genACCDataOp(Fortran::lower::AbstractConverter &converter, diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 32d04b79c058f7..4a6c8d50e2ae04 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -166,7 +166,7 @@ static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) { p << callee.getValue(); else p << op.getOperand(0); - p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')'; + p << '(' << op->getOperands().drop_front(isDirect ? 0 : 1) << ')'; p.printOptionalAttrDict(op.getAttrs(), {fir::CallOp::calleeAttrName()}); auto resultTypes{op.getResultTypes()}; llvm::SmallVector argTypes( @@ -240,7 +240,8 @@ template static void printCmpOp(OpAsmPrinter &p, OPTY op) { p << op.getOperationName() << ' '; auto predSym = mlir::symbolizeCmpFPredicate( - op.template getAttrOfType(OPTY::getPredicateAttrName()) + op->template getAttrOfType( + OPTY::getPredicateAttrName()) .getInt()); assert(predSym.hasValue() && "invalid symbol value for predicate"); p << '"' << mlir::stringifyCmpFPredicate(predSym.getValue()) << '"' << ", "; @@ -385,7 +386,10 @@ static mlir::ParseResult parseCoordinateOp(mlir::OpAsmParser &parser, } mlir::Type fir::CoordinateOp::getBaseType() { - return getAttr(CoordinateOp::baseType()).cast().getValue(); + return (*this) + ->getAttr(CoordinateOp::baseType()) + .cast() + .getValue(); } void fir::CoordinateOp::build(OpBuilder &, OperationState &result, @@ -412,7 +416,7 @@ void fir::CoordinateOp::build(OpBuilder &builder, OperationState &result, //===----------------------------------------------------------------------===// mlir::FunctionType fir::DispatchOp::getFunctionType() { - auto attr = getAttr("fn_type").cast(); + auto attr = (*this)->getAttr("fn_type").cast(); return attr.getValue().cast(); } @@ -745,7 +749,7 @@ static void print(mlir::OpAsmPrinter &p, fir::IterWhileOp op) { [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ") -> (" << op.getResultTypes().drop_front() << ')'; } - p.printOptionalAttrDictWithKeyword(op.getAttrs(), {}); + p.printOptionalAttrDictWithKeyword(op->getAttrs(), {}); p.printRegion(op.region(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } @@ -930,7 +934,7 @@ static void print(mlir::OpAsmPrinter &p, fir::LoopOp op) { p << ") -> (" << op.getResultTypes() << ')'; printBlockTerminators = true; } - p.printOptionalAttrDictWithKeyword(op.getAttrs(), + p.printOptionalAttrDictWithKeyword(op->getAttrs(), {fir::LoopOp::unorderedAttrName()}); p.printRegion(op.region(), /*printEntryBlockArgs=*/false, printBlockTerminators); @@ -963,9 +967,9 @@ mlir::OpFoldResult fir::MulfOp::fold(llvm::ArrayRef opnds) { //===----------------------------------------------------------------------===// static mlir::LogicalResult verify(fir::ResultOp op) { - auto parentOp = op.getParentOp(); + auto *parentOp = op->getParentOp(); auto results = parentOp->getResults(); - auto operands = op.getOperands(); + auto operands = op->getOperands(); if (parentOp->getNumResults() != op.getNumOperands()) return op.emitOpError() << "parent of result must have same arity"; @@ -1032,15 +1036,16 @@ fir::SelectOp::getMutableSuccessorOperands(unsigned oper) { llvm::Optional> fir::SelectOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - auto segments = - getAttrOfType(getOperandSegmentSizeAttr()); + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } unsigned fir::SelectOp::targetOffsetSize() { - return denseElementsSize( - getAttrOfType(getTargetOffsetAttr())); + return denseElementsSize((*this)->getAttrOfType( + getTargetOffsetAttr())); } //===----------------------------------------------------------------------===// @@ -1049,16 +1054,18 @@ unsigned fir::SelectOp::targetOffsetSize() { llvm::Optional fir::SelectCaseOp::getCompareOperands(unsigned cond) { - auto a = getAttrOfType(getCompareOffsetAttr()); + auto a = (*this)->getAttrOfType( + getCompareOffsetAttr()); return {getSubOperands(cond, compareArgs(), a)}; } llvm::Optional> fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef operands, unsigned cond) { - auto a = getAttrOfType(getCompareOffsetAttr()); - auto segments = - getAttrOfType(getOperandSegmentSizeAttr()); + auto a = (*this)->getAttrOfType( + getCompareOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; } @@ -1071,9 +1078,10 @@ fir::SelectCaseOp::getMutableSuccessorOperands(unsigned oper) { llvm::Optional> fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - auto segments = - getAttrOfType(getOperandSegmentSizeAttr()); + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } @@ -1152,13 +1160,13 @@ static mlir::ParseResult parseSelectCase(mlir::OpAsmParser &parser, } unsigned fir::SelectCaseOp::compareOffsetSize() { - return denseElementsSize( - getAttrOfType(getCompareOffsetAttr())); + return denseElementsSize((*this)->getAttrOfType( + getCompareOffsetAttr())); } unsigned fir::SelectCaseOp::targetOffsetSize() { - return denseElementsSize( - getAttrOfType(getTargetOffsetAttr())); + return denseElementsSize((*this)->getAttrOfType( + getTargetOffsetAttr())); } void fir::SelectCaseOp::build(mlir::OpBuilder &builder, @@ -1262,15 +1270,16 @@ fir::SelectRankOp::getMutableSuccessorOperands(unsigned oper) { llvm::Optional> fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - auto segments = - getAttrOfType(getOperandSegmentSizeAttr()); + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } unsigned fir::SelectRankOp::targetOffsetSize() { - return denseElementsSize( - getAttrOfType(getTargetOffsetAttr())); + return denseElementsSize((*this)->getAttrOfType( + getTargetOffsetAttr())); } //===----------------------------------------------------------------------===// @@ -1296,9 +1305,10 @@ fir::SelectTypeOp::getMutableSuccessorOperands(unsigned oper) { llvm::Optional> fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef operands, unsigned oper) { - auto a = getAttrOfType(getTargetOffsetAttr()); - auto segments = - getAttrOfType(getOperandSegmentSizeAttr()); + auto a = + (*this)->getAttrOfType(getTargetOffsetAttr()); + auto segments = (*this)->getAttrOfType( + getOperandSegmentSizeAttr()); return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; } @@ -1348,8 +1358,8 @@ static ParseResult parseSelectType(OpAsmParser &parser, } unsigned fir::SelectTypeOp::targetOffsetSize() { - return denseElementsSize( - getAttrOfType(getTargetOffsetAttr())); + return denseElementsSize((*this)->getAttrOfType( + getTargetOffsetAttr())); } //===----------------------------------------------------------------------===// @@ -1467,7 +1477,7 @@ static void print(mlir::OpAsmPrinter &p, fir::WhereOp op) { p.printRegion(otherReg, /*printEntryBlockArgs=*/false, printBlockTerminators); } - p.printOptionalAttrDict(op.getAttrs()); + p.printOptionalAttrDict(op->getAttrs()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp index 38b54ebadfc760..e0a8420e15b7a7 100644 --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -787,7 +787,7 @@ LogicalResult Importer::processFunction(llvm::Function *f) { convertLinkageFromLLVM(f->getLinkage())); if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f)) - fop.setAttr(b.getIdentifier("personality"), personality); + fop->setAttr(b.getIdentifier("personality"), personality); else if (f->hasPersonalityFn()) emitWarning(UnknownLoc::get(context), "could not deduce personality, skipping it"); diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index d5daebb57475b8..833f90dd28a2fb 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -48,13 +48,13 @@ def AOp : NS_Op<"a_op", []> { // --- // DEF: some-attr-kind AOp::aAttrAttr() -// DEF-NEXT: (*this)->getAttr("aAttr").cast() +// DEF-NEXT: (*this)->getAttr("aAttr").template cast() // DEF: some-return-type AOp::aAttr() { // DEF-NEXT: auto attr = aAttrAttr() // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AOp::bAttrAttr() -// DEF-NEXT: return (*this)->getAttr("bAttr").dyn_cast_or_null() +// DEF-NEXT: return (*this)->getAttr("bAttr").template dyn_cast_or_null() // DEF: some-return-type AOp::bAttr() { // DEF-NEXT: auto attr = bAttrAttr(); // DEF-NEXT: if (!attr) @@ -62,7 +62,7 @@ def AOp : NS_Op<"a_op", []> { // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AOp::cAttrAttr() -// DEF-NEXT: return (*this)->getAttr("cAttr").dyn_cast_or_null() +// DEF-NEXT: return (*this)->getAttr("cAttr").template dyn_cast_or_null() // DEF: ::llvm::Optional AOp::cAttr() { // DEF-NEXT: auto attr = cAttrAttr() // DEF-NEXT: return attr ? ::llvm::Optional(attr.some-convert-from-storage()) : (::llvm::None); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 5d756e708a434e..f2a57fbbb00fa7 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -90,8 +90,7 @@ const char *adapterSegmentSizeAttrInitCode = R"( auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>(); )"; const char *opSegmentSizeAttrInitCode = R"( - auto sizeAttr = - getOperation()->getAttrOfType<::mlir::DenseIntElementsAttr>("{0}"); + auto sizeAttr = (*this)->getAttrOfType<::mlir::DenseIntElementsAttr>("{0}"); )"; const char *attrSizedSegmentValueRangeCalcCode = R"( unsigned start = 0; @@ -521,7 +520,7 @@ void OpEmitter::genAttrGetters() { if (!method) return; auto &body = method->body(); - body << " return (*this)->getAttr(\"" << name << "\")."; + body << " return (*this)->getAttr(\"" << name << "\").template "; if (attr.isOptional() || attr.hasDefaultValue()) body << "dyn_cast_or_null<"; else diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp index d576f2b640a94e..6e4283d3a3e42f 100644 --- a/mlir/unittests/Pass/PassManagerTest.cpp +++ b/mlir/unittests/Pass/PassManagerTest.cpp @@ -34,13 +34,13 @@ struct AnnotateFunctionPass : public PassWrapper> { void runOnOperation() override { FuncOp op = getOperation(); - Builder builder(op.getParentOfType()); + Builder builder(op->getParentOfType()); auto &ga = getAnalysis(); auto &sa = getAnalysis(); - op.setAttr("isFunc", builder.getBoolAttr(ga.isFunc)); - op.setAttr("isSecret", builder.getBoolAttr(sa.isSecret)); + op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc)); + op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret)); } }; @@ -66,12 +66,12 @@ TEST(PassManagerTest, OpSpecificAnalysis) { // Verify that each function got annotated with expected attributes. for (FuncOp func : module->getOps()) { - ASSERT_TRUE(func.getAttr("isFunc").isa()); - EXPECT_TRUE(func.getAttr("isFunc").cast().getValue()); + ASSERT_TRUE(func->getAttr("isFunc").isa()); + EXPECT_TRUE(func->getAttr("isFunc").cast().getValue()); bool isSecret = func.getName() == "secret"; - ASSERT_TRUE(func.getAttr("isSecret").isa()); - EXPECT_EQ(func.getAttr("isSecret").cast().getValue(), isSecret); + ASSERT_TRUE(func->getAttr("isSecret").isa()); + EXPECT_EQ(func->getAttr("isSecret").cast().getValue(), isSecret); } } From 5a2d954671e91e63e2f944cce31bdcc232c8ecc2 Mon Sep 17 00:00:00 2001 From: Alina Sbirlea Date: Mon, 14 Dec 2020 11:19:01 -0800 Subject: [PATCH 02/12] [NFC] Remove stray comment. --- llvm/include/llvm/Support/GenericDomTree.h | 1 - 1 file changed, 1 deletion(-) diff --git a/llvm/include/llvm/Support/GenericDomTree.h b/llvm/include/llvm/Support/GenericDomTree.h index 4bed550f44c0f2..d2d7c8c4481d6d 100644 --- a/llvm/include/llvm/Support/GenericDomTree.h +++ b/llvm/include/llvm/Support/GenericDomTree.h @@ -554,7 +554,6 @@ class DominatorTreeBase { /// obtained PostViewCFG is the desired end state. void applyUpdates(ArrayRef Updates, ArrayRef PostViewUpdates) { - // GraphDiff *PostViewCFG = nullptr) { if (Updates.empty()) { GraphDiff PostViewCFG(PostViewUpdates); DomTreeBuilder::ApplyUpdates(*this, PostViewCFG, &PostViewCFG); From 55fc64bce08a30f1bf7f7ebf83df776a40700fbe Mon Sep 17 00:00:00 2001 From: Reid Kleckner Date: Mon, 14 Dec 2020 11:22:57 -0800 Subject: [PATCH 03/12] [Hexagon] Tweak _MSC_VER workaround version My bot runs VS 2019, but it could not compile this code. Message: [55/2465] Building CXX object lib\Target\Hexagon\CMakeFiles\LLVMHexagonCodeGen.dir\HexagonVectorCombine.cpp.obj FAILED: lib/Target/Hexagon/CMakeFiles/LLVMHexagonCodeGen.dir/HexagonVectorCombine.cpp.obj ... C:\Program Files (x86)\Microsoft Visual Studio\2019\Professional\VC\Tools\MSVC\14.23.28105\include\map(71): error C2976: 'std::map': too few template arguments C:\Program Files (x86)\Microsoft Visual Studio\2019\Professional\VC\Tools\MSVC\14.23.28105\include\map(71): note: see declaration of 'std::map' The version in the path, 14.23, corresponds to _MSC_VER 1923, so raise the version floor to 1924. I have not tested with versions between 1924 and 1928 (latest), but the latest works with the variadic version. --- llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp index 4c0c202be4bef9..2d90e37349e57b 100644 --- a/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp +++ b/llvm/lib/Target/Hexagon/HexagonVectorCombine.cpp @@ -292,7 +292,7 @@ template <> StoreInst *isCandidate(Instruction *In) { return getIfUnordered(dyn_cast(In)); } -#if !defined(_MSC_VER) || _MSC_VER >= 1920 +#if !defined(_MSC_VER) || _MSC_VER >= 1924 // VS2017 has trouble compiling this: // error C2976: 'std::map': too few template arguments template From 9c1765acabf10b7df7cf49456a06bbba2b33b364 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Mon, 14 Dec 2020 10:59:26 -0500 Subject: [PATCH 04/12] [VectorCombine] add test for load with offset; NFC --- llvm/test/Transforms/VectorCombine/X86/load.ll | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/llvm/test/Transforms/VectorCombine/X86/load.ll b/llvm/test/Transforms/VectorCombine/X86/load.ll index 824a507ed1036b..ba2bf3f37d7b67 100644 --- a/llvm/test/Transforms/VectorCombine/X86/load.ll +++ b/llvm/test/Transforms/VectorCombine/X86/load.ll @@ -535,3 +535,20 @@ define <8 x i32> @load_v1i32_extract_insert_v8i32_extra_use(<1 x i32>* align 16 %r = insertelement <8 x i32> undef, i32 %s, i32 0 ret <8 x i32> %r } + +; TODO: Can't safely load the offset vector, but can load+shuffle if it is profitable. + +define <8 x i16> @gep1_load_v2i16_extract_insert_v8i16(<2 x i16>* align 16 dereferenceable(16) %p) { +; CHECK-LABEL: @gep1_load_v2i16_extract_insert_v8i16( +; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds <2 x i16>, <2 x i16>* [[P:%.*]], i64 1 +; CHECK-NEXT: [[L:%.*]] = load <2 x i16>, <2 x i16>* [[GEP]], align 2 +; CHECK-NEXT: [[S:%.*]] = extractelement <2 x i16> [[L]], i32 0 +; CHECK-NEXT: [[R:%.*]] = insertelement <8 x i16> undef, i16 [[S]], i64 0 +; CHECK-NEXT: ret <8 x i16> [[R]] +; + %gep = getelementptr inbounds <2 x i16>, <2 x i16>* %p, i64 1 + %l = load <2 x i16>, <2 x i16>* %gep, align 2 + %s = extractelement <2 x i16> %l, i32 0 + %r = insertelement <8 x i16> undef, i16 %s, i64 0 + ret <8 x i16> %r +} From 0936655bac78f6e9cb84dc3feb30c32012100839 Mon Sep 17 00:00:00 2001 From: Artem Belevich Date: Tue, 8 Dec 2020 15:05:33 -0800 Subject: [PATCH 05/12] [CUDA] Do not diagnose host/device variable access in dependent types. `isCUDADeviceBuiltinSurfaceType()`/`isCUDADeviceBuiltinTextureType()` do not work on dependent types as they rely on specific type attributes. Differential Revision: https://reviews.llvm.org/D92893 --- clang/include/clang/Basic/Attr.td | 2 ++ clang/test/SemaCUDA/device-use-host-var.cu | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 51f654fc7613aa..79902c8f5b894e 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -1079,6 +1079,7 @@ def CUDADeviceBuiltinSurfaceType : InheritableAttr { let LangOpts = [CUDA]; let Subjects = SubjectList<[CXXRecord]>; let Documentation = [CUDADeviceBuiltinSurfaceTypeDocs]; + let MeaningfulToClassTemplateDefinition = 1; } def CUDADeviceBuiltinTextureType : InheritableAttr { @@ -1087,6 +1088,7 @@ def CUDADeviceBuiltinTextureType : InheritableAttr { let LangOpts = [CUDA]; let Subjects = SubjectList<[CXXRecord]>; let Documentation = [CUDADeviceBuiltinTextureTypeDocs]; + let MeaningfulToClassTemplateDefinition = 1; } def CUDAGlobal : InheritableAttr { diff --git a/clang/test/SemaCUDA/device-use-host-var.cu b/clang/test/SemaCUDA/device-use-host-var.cu index cf5514610a42ab..c8ef7dbbb18dca 100644 --- a/clang/test/SemaCUDA/device-use-host-var.cu +++ b/clang/test/SemaCUDA/device-use-host-var.cu @@ -158,3 +158,23 @@ void dev_lambda_capture_by_copy(int *out) { }); } +// Texture references are special. As far as C++ is concerned they are host +// variables that are referenced from device code. However, they are handled +// very differently by the compiler under the hood and such references are +// allowed. Compiler should produce no warning here, but it should diagnose the +// same case without the device_builtin_texture_type attribute. +template +struct __attribute__((device_builtin_texture_type)) texture { + static texture ref; + __device__ int c() { + auto &x = ref; + } +}; + +template +struct not_a_texture { + static not_a_texture ref; + __device__ int c() { + auto &x = ref; // dev-error {{reference to __host__ variable 'ref' in __device__ function}} + } +}; From c234b65cef07b38c91b9ab7dec6a35f8b390e658 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 14 Dec 2020 11:53:34 -0800 Subject: [PATCH 06/12] [mlir][OpFormat] Add support for emitting newlines from the custom format of an operation This revision adds a new `printNewline` hook to OpAsmPrinter that allows for printing a newline within the custom format of an operation, that is then indented to the start of the operation. Support for the declarative assembly format is also added, in the form of a `\n` literal. Differential Revision: https://reviews.llvm.org/D93151 --- mlir/docs/OpDefinitions.md | 24 +++++++++ mlir/include/mlir/IR/OpImplementation.h | 4 ++ mlir/lib/IR/AsmPrinter.cpp | 8 +++ mlir/test/lib/Dialect/Test/TestOps.td | 3 +- mlir/test/mlir-tblgen/op-format-spec.td | 2 +- mlir/test/mlir-tblgen/op-format.mlir | 6 ++- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 65 +++++++++++++++++++------ 7 files changed, 93 insertions(+), 19 deletions(-) diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md index a267a60adc3ef2..189cd0825af7eb 100644 --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -646,6 +646,30 @@ The following are the set of valid punctuation: `:`, `,`, `=`, `<`, `>`, `(`, `)`, `{`, `}`, `[`, `]`, `->`, `?`, `+`, `*` +The following are valid whitespace punctuation: + +`\n`, ` ` + +The `\n` literal emits a newline an indents to the start of the operation. An +example is shown below: + +```tablegen +let assemblyFormat = [{ + `{` `\n` ` ` ` ` `this_is_on_a_newline` `\n` `}` attr-dict +}]; +``` + +```mlir +%results = my.operation { + this_is_on_a_newline +} +``` + +An empty literal \`\` may be used to remove a space that is inserted implicitly +after certain literal elements, such as `)`/`]`/etc. For example, "`]`" may +result in an output of `]` it is not the last element in the format. "`]` \`\`" +would trim the trailing space in this situation. + #### Variables A variable is an entity that has been registered on the operation itself, i.e. diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index a7e87dc0ab06ff..31d3b42c84935f 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -36,6 +36,10 @@ class OpAsmPrinter { virtual ~OpAsmPrinter(); virtual raw_ostream &getStream() const = 0; + /// Print a newline and indent the printer to the start of the current + /// operation. + virtual void printNewline() = 0; + /// Print implementations for various things an operation contains. virtual void printOperand(Value value) = 0; virtual void printOperand(Value value, raw_ostream &os) = 0; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 49e7048cfb1375..1c2caa0bdfd686 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -429,6 +429,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter { /// The following are hooks of `OpAsmPrinter` that are not necessary for /// determining potential aliases. void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} + void printNewline() override {} void printOperand(Value) override {} void printOperand(Value, raw_ostream &os) override { // Users expect the output string to have at least the prefixed % to signal @@ -2218,6 +2219,13 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter { /// Return the current stream of the printer. raw_ostream &getStream() const override { return os; } + /// Print a newline and indent the printer to the start of the current + /// operation. + void printNewline() override { + os << newLine; + os.indent(currentIndent); + } + /// Print the given type. void printType(Type type) override { ModulePrinter::printType(type); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 6a7291abfec78f..9a7eb5940fb957 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1393,7 +1393,8 @@ def AsmDialectInterfaceOp : TEST_Op<"asm_dialect_interface_op"> { def FormatLiteralOp : TEST_Op<"format_literal_op"> { let assemblyFormat = [{ - `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)` `?` `+` `*` attr-dict + `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)` + `?` `+` `*` `{` `\n` `}` attr-dict }]; } diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td index 7817920f8955d4..424dbb83c27679 100644 --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -309,7 +309,7 @@ def LiteralInvalidB : TestFormat_Op<"literal_invalid_b", [{ }]>; // CHECK-NOT: error def LiteralValid : TestFormat_Op<"literal_valid", [{ - `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `*` ` ` `` `->` `abc$._` + `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `*` ` ` `` `->` `\n` `abc$._` attr-dict }]>; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir index 6286f76551466d..334313debda113 100644 --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -7,8 +7,10 @@ // CHECK: %[[MEMREF:.*]] = %memref = "foo.op"() : () -> (memref<1xf64>) -// CHECK: test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * {foo.some_attr} -test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * {foo.some_attr} +// CHECK: test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * { +// CHECK-NEXT: } {foo.some_attr} +test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * { +} {foo.some_attr} // CHECK: test.format_attr_op 10 // CHECK-NOT: {attr diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index e09cdd2ac6d42b..6cc7c75dc8a468 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -58,7 +58,8 @@ class Element { /// This element is a literal. Literal, - /// This element prints or omits a space. It is ignored by the parser. + /// This element is a whitespace. + Newline, Space, /// This element is an variable value. @@ -296,14 +297,35 @@ bool LiteralElement::isValidLiteral(StringRef value) { } //===----------------------------------------------------------------------===// -// SpaceElement +// WhitespaceElement namespace { +/// This class represents a whitespace element, e.g. newline or space. It's a +/// literal that is printed but never parsed. +class WhitespaceElement : public Element { +public: + WhitespaceElement(Kind kind) : Element{kind} {} + static bool classof(const Element *element) { + Kind kind = element->getKind(); + return kind == Kind::Newline || kind == Kind::Space; + } +}; + +/// This class represents an instance of a newline element. It's a literal that +/// prints a newline. It is ignored by the parser. +class NewlineElement : public WhitespaceElement { +public: + NewlineElement() : WhitespaceElement(Kind::Newline) {} + static bool classof(const Element *element) { + return element->getKind() == Kind::Newline; + } +}; + /// This class represents an instance of a space element. It's a literal that /// prints or omits printing a space. It is ignored by the parser. -class SpaceElement : public Element { +class SpaceElement : public WhitespaceElement { public: - SpaceElement(bool value) : Element{Kind::Space}, value(value) {} + SpaceElement(bool value) : WhitespaceElement(Kind::Space), value(value) {} static bool classof(const Element *element) { return element->getKind() == Kind::Space; } @@ -347,7 +369,8 @@ class OptionalElement : public Element { std::vector> elements; /// The index of the element that acts as the anchor for the optional group. unsigned anchor; - /// The index of the first element that is parsed (is not a SpaceElement). + /// The index of the first element that is parsed (is not a + /// WhitespaceElement). unsigned parseStart; }; } // end anonymous namespace @@ -1098,8 +1121,8 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body, genLiteralParser(literal->getLiteral(), body); body << ")\n return ::mlir::failure();\n"; - /// Spaces. - } else if (isa(element)) { + /// Whitespaces. + } else if (isa(element)) { // Nothing to parse. /// Arguments. @@ -1620,6 +1643,11 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body, return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace, lastWasPunctuation); + // Emit a whitespace element. + if (NewlineElement *newline = dyn_cast(element)) { + body << " p.printNewline();\n"; + return; + } if (SpaceElement *space = dyn_cast(element)) return genSpacePrinter(space->getValue(), body, shouldEmitSpace, lastWasPunctuation); @@ -2272,9 +2300,10 @@ LogicalResult FormatParser::verifyAttributes( for (auto &nextItPair : iteratorStack) { ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second; for (; nextIt != nextE; ++nextIt) { - // Skip any trailing spaces, attribute dictionaries, or optional groups. - if (isa(*nextIt) || isa(*nextIt) || - isa(*nextIt)) + // Skip any trailing whitespace, attribute dictionaries, or optional + // groups. + if (isa(*nextIt) || + isa(*nextIt) || isa(*nextIt)) continue; // We are only interested in `:` literals. @@ -2600,6 +2629,11 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr &element) { element = std::make_unique(!value.empty()); return ::mlir::success(); } + // The parsed literal is a newline element. + if (value == "\\n") { + element = std::make_unique(); + return ::mlir::success(); + } // Check that the parsed literal is valid. if (!LiteralElement::isValidLiteral(value)) @@ -2635,8 +2669,9 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr &element, // The first parsable element of the group must be able to be parsed in an // optional fashion. - auto parseBegin = llvm::find_if_not( - elements, [](auto &element) { return isa(element.get()); }); + auto parseBegin = llvm::find_if_not(elements, [](auto &element) { + return isa(element.get()); + }); Element *firstElement = parseBegin->get(); if (!isa(firstElement) && !isa(firstElement) && @@ -2718,9 +2753,9 @@ LogicalResult FormatParser::parseOptionalChildElement( // a check here. return ::mlir::success(); }) - // Literals, spaces, custom directives, and type directives may be used, - // but they can't anchor the group. - .Case([&](Element *) { if (isAnchor) From 6bc9439f59acbcc5e46a108c2f74a4d5ffe55a3b Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 14 Dec 2020 11:53:43 -0800 Subject: [PATCH 07/12] [mlir][OpAsmParser] Add support for parsing integer literals without going through IntegerAttr Some operations use integer literals as part of their custom format that don't necessarily map to an internal IntegerAttr. This revision exposes the same `parseInteger` functions as the DialectAsmParser to allow for these operations to parse integer literals without incurring the otherwise unnecessary roundtrip through IntegerAttr. Differential Revision: https://reviews.llvm.org/D93152 --- mlir/include/mlir/IR/OpImplementation.h | 29 ++++++++++++++++++++++ mlir/lib/Parser/DialectSymbolParser.cpp | 15 +---------- mlir/lib/Parser/Parser.cpp | 23 +++++++++++++++++ mlir/lib/Parser/Parser.h | 3 +++ mlir/test/IR/parser.mlir | 15 ++++++++--- mlir/test/lib/Dialect/Test/TestDialect.cpp | 28 ++++++++++++++++++--- mlir/test/lib/Dialect/Test/TestOps.td | 8 +++++- 7 files changed, 98 insertions(+), 23 deletions(-) diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 31d3b42c84935f..f74eb52aec6d6e 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -413,6 +413,35 @@ class OpAsmParser { /// Parse a `...` token if present; virtual ParseResult parseOptionalEllipsis() = 0; + /// Parse an integer value from the stream. + template ParseResult parseInteger(IntT &result) { + auto loc = getCurrentLocation(); + OptionalParseResult parseResult = parseOptionalInteger(result); + if (!parseResult.hasValue()) + return emitError(loc, "expected integer value"); + return *parseResult; + } + + /// Parse an optional integer value from the stream. + virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0; + + template + OptionalParseResult parseOptionalInteger(IntT &result) { + auto loc = getCurrentLocation(); + + // Parse the unsigned variant. + uint64_t uintResult; + OptionalParseResult parseResult = parseOptionalInteger(uintResult); + if (!parseResult.hasValue() || failed(*parseResult)) + return parseResult; + + // Try to convert to the provided integer type. + result = IntT(uintResult); + if (uint64_t(result) != uintResult) + return emitError(loc, "integer value too large"); + return success(); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp index 3bbc495cab658a..11e7e237a192a8 100644 --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -82,20 +82,7 @@ class CustomDialectAsmParser : public DialectAsmParser { /// Parse an optional integer value from the stream. OptionalParseResult parseOptionalInteger(uint64_t &result) override { - Token curToken = parser.getToken(); - if (curToken.isNot(Token::integer, Token::minus)) - return llvm::None; - - bool negative = parser.consumeIf(Token::minus); - Token curTok = parser.getToken(); - if (parser.parseToken(Token::integer, "expected integer value")) - return failure(); - - auto val = curTok.getUInt64IntegerValue(); - if (!val) - return emitError(curTok.getLoc(), "integer value too large"); - result = negative ? -*val : *val; - return success(); + return parser.parseOptionalInteger(result); } //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 47fef1ca393cf8..58ed9004c56c05 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -94,6 +94,24 @@ ParseResult Parser::parseToken(Token::Kind expectedToken, return emitError(message); } +/// Parse an optional integer value from the stream. +OptionalParseResult Parser::parseOptionalInteger(uint64_t &result) { + Token curToken = getToken(); + if (curToken.isNot(Token::integer, Token::minus)) + return llvm::None; + + bool negative = consumeIf(Token::minus); + Token curTok = getToken(); + if (parseToken(Token::integer, "expected integer value")) + return failure(); + + auto val = curTok.getUInt64IntegerValue(); + if (!val) + return emitError(curTok.getLoc(), "integer value too large"); + result = negative ? -*val : *val; + return success(); +} + //===----------------------------------------------------------------------===// // OperationParser //===----------------------------------------------------------------------===// @@ -1109,6 +1127,11 @@ class CustomOpAsmParser : public OpAsmParser { return success(parser.consumeIf(Token::star)); } + /// Parse an optional integer value from the stream. + OptionalParseResult parseOptionalInteger(uint64_t &result) override { + return parser.parseOptionalInteger(result); + } + //===--------------------------------------------------------------------===// // Attribute Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h index e32cca411e790f..0e9e4caff440ca 100644 --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -127,6 +127,9 @@ class Parser { /// output a diagnostic and return failure. ParseResult parseToken(Token::Kind expectedToken, const Twine &message); + /// Parse an optional integer value from the stream. + OptionalParseResult parseOptionalInteger(uint64_t &result); + //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 8fcb7863726f7c..ca61bf2a97a5f5 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1174,10 +1174,17 @@ func private @ptr_to_function() -> !unreg.ptr<() -> ()> // CHECK-LABEL: func private @escaped_string_char(i1 {foo.value = "\0A"}) func private @escaped_string_char(i1 {foo.value = "\n"}) -// CHECK-LABEL: func @wrapped_keyword_test -func @wrapped_keyword_test() { - // CHECK: test.wrapped_keyword foo.keyword - test.wrapped_keyword foo.keyword +// CHECK-LABEL: func @parse_integer_literal_test +func @parse_integer_literal_test() { + // CHECK: test.parse_integer_literal : 5 + test.parse_integer_literal : 5 + return +} + +// CHECK-LABEL: func @parse_wrapped_keyword_test +func @parse_wrapped_keyword_test() { + // CHECK: test.parse_wrapped_keyword foo.keyword + test.parse_wrapped_keyword foo.keyword return } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 933b59dc5b8f9e..5e9bae8933634f 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -564,8 +564,28 @@ static void print(OpAsmPrinter &p, AffineScopeOp op) { // Test parser. //===----------------------------------------------------------------------===// -static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, - OperationState &result) { +static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser, + OperationState &result) { + if (parser.parseOptionalColon()) + return success(); + uint64_t numResults; + if (parser.parseInteger(numResults)) + return failure(); + + IndexType type = parser.getBuilder().getIndexType(); + for (unsigned i = 0; i < numResults; ++i) + result.addTypes(type); + return success(); +} + +static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) { + p << ParseIntegerLiteralOp::getOperationName(); + if (unsigned numResults = op->getNumResults()) + p << " : " << numResults; +} + +static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser, + OperationState &result) { StringRef keyword; if (parser.parseKeyword(&keyword)) return failure(); @@ -573,8 +593,8 @@ static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, return success(); } -static void print(OpAsmPrinter &p, WrappedKeywordOp op) { - p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); +static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) { + p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 9a7eb5940fb957..1fc419cc375f34 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1293,7 +1293,13 @@ def TestMergeBlocksOp : TEST_Op<"merge_blocks"> { // Test parser. //===----------------------------------------------------------------------===// -def WrappedKeywordOp : TEST_Op<"wrapped_keyword"> { +def ParseIntegerLiteralOp : TEST_Op<"parse_integer_literal"> { + let results = (outs Variadic:$results); + let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(p, *this); }]; +} + +def ParseWrappedKeywordOp : TEST_Op<"parse_wrapped_keyword"> { let arguments = (ins StrAttr:$keyword); let parser = [{ return ::parse$cppClass(parser, result); }]; let printer = [{ return ::print(p, *this); }]; From 2aa43358060c6b34fb9cdc6c4321e958f62331e7 Mon Sep 17 00:00:00 2001 From: Michael Kruse Date: Mon, 14 Dec 2020 11:05:51 -0600 Subject: [PATCH 08/12] [flang] Fix copy elision assumption. Before this patch, the Restorer depended on copy elision to happen. Without copy elision, the function ScopedSet calls the move constructor before its dtor. The dtor will prematurely restore the reference to the original value. Instead of relying the compiler to not use the Restorer's copy constructor, delete its copy and assign operators. Hence, callers cannot move or copy a Restorer object anymore, and have to explicitly provide the reset state. ScopedSet avoids calling move/copy operations by relying on unnamed return value optimization, which is mandatory in C++17. Reviewed By: klausler Differential Revision: https://reviews.llvm.org/D88797 --- flang/include/flang/Common/restorer.h | 17 ++++++++++++----- flang/lib/Semantics/check-declarations.cpp | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/flang/include/flang/Common/restorer.h b/flang/include/flang/Common/restorer.h index 47e54237fe43cc..4d5f5e4e2c818d 100644 --- a/flang/include/flang/Common/restorer.h +++ b/flang/include/flang/Common/restorer.h @@ -22,9 +22,16 @@ namespace Fortran::common { template class Restorer { public: - explicit Restorer(A &p) : p_{p}, original_{std::move(p)} {} + explicit Restorer(A &p, A original) : p_{p}, original_{std::move(original)} {} ~Restorer() { p_ = std::move(original_); } + // Inhibit any recreation of this restorer that would result in two restorers + // trying to restore the same reference. + Restorer(const Restorer &) = delete; + Restorer(Restorer &&that) = delete; + const Restorer &operator=(const Restorer &) = delete; + const Restorer &operator=(Restorer &&that) = delete; + private: A &p_; A original_; @@ -32,15 +39,15 @@ template class Restorer { template common::IfNoLvalue, B> ScopedSet(A &to, B &&from) { - Restorer result{to}; + A original{std::move(to)}; to = std::move(from); - return result; + return Restorer{to, std::move(original)}; } template common::IfNoLvalue, B> ScopedSet(A &to, const B &from) { - Restorer result{to}; + A original{std::move(to)}; to = from; - return result; + return Restorer{to, std::move(original)}; } } // namespace Fortran::common #endif // FORTRAN_COMMON_RESTORER_H_ diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp index 0d2e2e86241c66..dd76f67100701e 100644 --- a/flang/lib/Semantics/check-declarations.cpp +++ b/flang/lib/Semantics/check-declarations.cpp @@ -1498,7 +1498,7 @@ void CheckHelper::CheckProcBinding( void CheckHelper::Check(const Scope &scope) { scope_ = &scope; - common::Restorer restorer{innermostSymbol_}; + common::Restorer restorer{innermostSymbol_, innermostSymbol_}; if (const Symbol * symbol{scope.symbol()}) { innermostSymbol_ = symbol; } From 6f271e921ba48f4c4fa54bbd2c7a4c548ca5e59e Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Mon, 14 Dec 2020 08:34:39 +0100 Subject: [PATCH 09/12] [mlir] Remove methods from mlir::OpState that just forward to mlir::Operation. All call sites have been converted in previous changes. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D93176 --- mlir/include/mlir/IR/OpDefinition.h | 51 ----------------------------- 1 file changed, 51 deletions(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index b200f7c2dc6cf2..2edc48dd099e6e 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -108,27 +108,6 @@ class OpState { /// Return the operation that this refers to. Operation *getOperation() { return state; } - /// Return the dialect that this refers to. - Dialect *getDialect() { return getOperation()->getDialect(); } - - /// Return the parent Region of this operation. - Region *getParentRegion() { return getOperation()->getParentRegion(); } - - /// Returns the closest surrounding operation that contains this operation - /// or nullptr if this is a top-level operation. - Operation *getParentOp() { return getOperation()->getParentOp(); } - - /// Return the closest surrounding parent operation that is of type 'OpTy'. - template OpTy getParentOfType() { - return getOperation()->getParentOfType(); - } - - /// Returns the closest surrounding parent operation with trait `Trait`. - template